# cortex-hub / poc-grpc-agent / orchestrator / services / grpc_server.py
import threading
import queue
import time
import os
try:
    import requests as _requests  # optional; only needed for M4 token validation
except ImportError:
    _requests = None
from protos import agent_pb2, agent_pb2_grpc
from orchestrator.core.registry import MemoryNodeRegistry
from orchestrator.core.journal import TaskJournal
from orchestrator.core.pool import GlobalWorkPool
from orchestrator.core.mirror import GhostMirrorManager
from orchestrator.services.assistant import TaskAssistant
from orchestrator.utils.crypto import sign_payload

# M4: Hub HTTP API for invite-token validation.
# SyncConfiguration calls POST {HUB_API_URL}/nodes/validate-token before accepting a node.
# Set HUB_API_URL=http://localhost:8000 to enable validation; leave it unset, empty,
# or set it to "0" to skip validation (dev mode).
_hub_api_url_raw = os.getenv("HUB_API_URL", "").strip()
# "0" is an explicit "disabled" value; trailing slashes are dropped so that
# f"{HUB_API_URL}{HUB_API_PATH}" never contains "//".
HUB_API_URL = "" if _hub_api_url_raw == "0" else _hub_api_url_raw.rstrip("/")
HUB_API_PATH = "/nodes/validate-token"

class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer that brokers registration, tasks, health and file sync for agent nodes.

    Wires together the in-memory node registry, task journal, global work pool,
    ghost-mirror manager and task assistant, and implements the
    ``SyncConfiguration``, ``TaskStream`` and ``ReportHealth`` RPCs.
    """

    # Shell commands a node may run when the shell skill is enabled (the default).
    DEFAULT_ALLOWED_COMMANDS = ("ls", "cat", "echo", "pwd", "uname", "curl", "python3", "git")
    # Seconds between mesh dashboard prints.
    MESH_MONITOR_INTERVAL_S = 10.0
    # Seconds of dispatcher idleness after which the pool state is re-advertised to a node.
    POOL_RESYNC_INTERVAL_S = 10.0
    # Max characters of the accessibility tree kept in a journaled browser result.
    A11Y_PREVIEW_CHARS = 100

    def __init__(self):
        self.registry = MemoryNodeRegistry()
        self.journal = TaskJournal()
        self.pool = GlobalWorkPool()
        self.mirror = GhostMirrorManager()
        self.assistant = TaskAssistant(self.registry, self.journal, self.pool, self.mirror)
        # Push a pool update to every registered node whenever the pool signals new work.
        self.pool.on_new_work = self._broadcast_work

        # 4. Mesh Observation (Aggregated Health Dashboard)
        threading.Thread(target=self._monitor_mesh, daemon=True, name="MeshMonitor").start()

    def _monitor_mesh(self):
        """Print the mesh dashboard every MESH_MONITOR_INTERVAL_S seconds, forever.

        Runs on a daemon thread. A failure in one iteration is logged and the loop
        continues, so the dashboard does not silently stop for the process lifetime.
        """
        while True:
            time.sleep(self.MESH_MONITOR_INTERVAL_S)
            try:
                self._print_mesh_dashboard()
            except Exception as e:  # keep the monitor thread alive
                print(f"[!] MeshMonitor error: {e}", flush=True)

    def _print_mesh_dashboard(self):
        """Print one snapshot of every registered node's stats and capabilities."""
        active_nodes = self.registry.list_nodes()
        print("\n" + "=" * 50)
        print(f"📡 CORTEX MESH DASHBOARD | {len(active_nodes)} Nodes Online")
        print("-" * 50)
        if not active_nodes:
            print("  No nodes currently connected.")
        for nid in active_nodes:
            node = self.registry.get_node(nid)
            if not node:
                # The node may have gone away between list_nodes() and get_node().
                continue
            stats = node.get("stats") or {}
            tasks = stats.get("running") or []
            capability = (node.get("metadata") or {}).get("caps", {})
            print(f"  🟢 {nid:15} | Workers: {stats.get('active_worker_count', 0)} | Running: {len(tasks)} tasks")
            print(f"      Capabilities: {capability}")
        print("=" * 50 + "\n", flush=True)

    def _broadcast_work(self, _):
        """Queue a WorkPoolUpdate with the currently available task ids for every registered node.

        Installed as ``pool.on_new_work``; the callback argument is ignored.
        """
        # Snapshot the pool once, before taking the registry lock, rather than once per node.
        available_ids = self.pool.list_available()
        with self.registry.lock:
            for node_id, node in self.registry.nodes.items():
                print(f"    [📢] Broadcasting availability to {node_id}")
                node["queue"].put(agent_pb2.ServerTaskMessage(
                    work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=available_ids)
                ))

    def SyncConfiguration(self, request, context):
        """M4 authenticated handshake: validate the invite token, register the node, send its policy.

        Args:
            request: RegistrationRequest carrying node_id, auth_token (the invite token),
                node_description and capabilities.
            context: gRPC servicer context (unused).

        Returns:
            RegistrationResponse with success=False and a reason when the hub rejects the
            token; otherwise success=True and a STRICT SandboxPolicy whose allow-list
            depends on the hub-supplied shell skill config.
        """
        node_id = request.node_id
        invite_token = request.auth_token  # field in RegistrationRequest proto

        accepted, reason, skill_cfg = self._validate_invite_token(node_id, invite_token)
        if not accepted:
            print(f"[🔒] SyncConfiguration REJECTED {node_id}: {reason}")
            return agent_pb2.RegistrationResponse(
                success=False,
                message=reason,
            )

        allowed_commands = self._allowed_commands_for(skill_cfg)

        # Register the node in the local in-memory registry with a fresh outbound queue.
        self.registry.register(node_id, queue.Queue(), {
            "desc": request.node_description,
            "caps": dict(request.capabilities),
        })

        return agent_pb2.RegistrationResponse(
            success=True,
            policy=agent_pb2.SandboxPolicy(
                mode=agent_pb2.SandboxPolicy.STRICT,
                allowed_commands=allowed_commands,
            ),
        )

    def _validate_invite_token(self, node_id, invite_token):
        """Validate an invite token against the Hub API.

        Validation is skipped (token accepted) when HUB_API_URL is empty, when the
        optional ``requests`` package is missing, or when the hub cannot be reached
        (transport error) — the dev-mode fall-through. A hub that answers with
        anything other than a JSON object whose "valid" field is truthy rejects the node.

        Returns:
            tuple (accepted, reason, skill_cfg): ``reason`` explains a rejection;
            ``skill_cfg`` is the hub's per-skill config dict, or {} when unavailable.
        """
        if not HUB_API_URL:
            # Dev mode: skip validation
            print(f"[⚠️] HUB_API_URL not set — skipping invite_token validation for {node_id}")
            return True, "", {}
        if _requests is None:
            print(f"[⚠️] 'requests' not installed — skipping invite_token validation for {node_id}")
            return True, "", {}

        try:
            # NOTE(review): the token travels as a URL query parameter and may end up in
            # hub access logs; moving it into the body depends on the hub's API contract.
            resp = _requests.post(
                f"{HUB_API_URL}{HUB_API_PATH}",
                params={"node_id": node_id, "token": invite_token},
                timeout=5,
            )
        except _requests.RequestException as e:
            # Hub unreachable (dev): fall through with a warning.
            print(f"[⚠️] Hub token validation unavailable ({e}); proceeding without auth.")
            return True, "", {}

        # A reachable hub must give a well-formed approval; anything else fails closed.
        try:
            payload = resp.json()
        except ValueError:
            payload = None
        if not isinstance(payload, dict):
            return False, f"Unreadable token-validation response from hub (HTTP {resp.status_code})", {}
        if not payload.get("valid"):
            return False, str(payload.get("reason") or "Token rejected"), {}

        print(f"[🔑] Token validated for {node_id} (display: {payload.get('display_name')})")
        skill_cfg = payload.get("skill_config")
        return True, "", skill_cfg if isinstance(skill_cfg, dict) else {}

    def _allowed_commands_for(self, skill_cfg):
        """Return the shell allow-list for a node from its hub skill config.

        The shell skill counts as enabled unless ``skill_cfg["shell"]["enabled"]`` is
        present and falsy; a disabled shell gets an empty allow-list.
        """
        shell_cfg = skill_cfg.get("shell") or {}
        if not isinstance(shell_cfg, dict):
            shell_cfg = {}
        if shell_cfg.get("enabled", True):
            return list(self.DEFAULT_ALLOWED_COMMANDS)
        return []  # Shell disabled by admin

    def TaskStream(self, request_iterator, context):
        """Persistent bi-directional command & control stream for one node.

        The first inbound message must be a NodeAnnounce from a node already registered
        via SyncConfiguration; otherwise the stream ends immediately. After that, inbound
        messages are handled on a background thread while this generator yields the
        node's queued ServerTaskMessages, re-advertising available pool work after
        idle periods.
        """
        node_id = None  # bound up front so the error handler below can always log it
        try:
            # 1. Blocking wait for Node Identity
            first_msg = next(request_iterator)
            if first_msg.WhichOneof('payload') != 'announce':
                print("[!] Stream rejected: No NodeAnnounce")
                return

            node_id = first_msg.announce.node_id
            node = self.registry.get_node(node_id)
            if not node:
                print(f"[!] Stream rejected: Node {node_id} not registered")
                return

            print(f"[📶] Stream Online for {node_id}")

            # Phase 5: Automatic Reconciliation on Reconnect
            self.assistant.reconcile_node(node_id)

            # 2. Results Listener (Read Thread)
            threading.Thread(
                target=self._read_results,
                args=(request_iterator, node_id, node),
                daemon=True,
                name=f"Results-{node_id}",
            ).start()

            # 3. Work Dispatcher (Main Stream)
            last_keepalive = 0.0
            while context.is_active():
                try:
                    # Time-bounded wait so context.is_active() is re-checked at least once a second.
                    msg = node["queue"].get(timeout=1.0)
                except queue.Empty:
                    # Occasional re-advertisement so the node's view of the pool stays in sync.
                    now = time.time()
                    if (now - last_keepalive) > self.POOL_RESYNC_INTERVAL_S:
                        last_keepalive = now
                        if self.pool.available:
                            yield agent_pb2.ServerTaskMessage(
                                work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=self.pool.list_available())
                            )
                    continue
                yield msg

        except StopIteration:
            # Inbound stream closed before the node announced itself.
            pass
        except Exception as e:
            print(f"[!] TaskStream Error for {node_id}: {e}")

    def _read_results(self, request_iterator, node_id, node):
        """Consume inbound stream messages from ``node_id`` and dispatch each one.

        A failure while handling a single message is logged and does not stop the
        reader; the loop ends when the inbound iterator finishes or raises.
        """
        try:
            for msg in request_iterator:
                try:
                    self._handle_client_message(msg, node_id, node)
                except Exception as e:
                    print(f"[!] Failed to handle message from {node_id}: {e}", flush=True)
        except Exception as e:
            print(f"[!] Result stream for {node_id} closed: {e}", flush=True)

    def _handle_client_message(self, msg, node_id, node):
        """Dispatch one inbound stream message from ``node_id`` by its ``payload`` oneof.

        Handles ``task_claim``, ``task_response``, ``browser_event`` and ``file_sync``;
        any other payload kind is ignored.
        """
        kind = msg.WhichOneof('payload')
        if kind == 'task_claim':
            task_id = msg.task_claim.task_id
            success, payload = self.pool.claim(task_id, node_id)

            # Send the claim status back to the node before any task payload.
            node["queue"].put(agent_pb2.ServerTaskMessage(
                claim_status=agent_pb2.TaskClaimResponse(
                    task_id=task_id,
                    granted=success,
                    reason="Task successfully claimed" if success else "Task already claimed by another node"
                )
            ))

            if success:
                sig = sign_payload(payload)
                node["queue"].put(agent_pb2.ServerTaskMessage(
                    task_request=agent_pb2.TaskRequest(
                        task_id=task_id,
                        payload_json=payload,
                        signature=sig
                    )
                ))

        elif kind == 'task_response':
            res_obj = {"stdout": msg.task_response.stdout, "status": msg.task_response.status}
            if msg.task_response.HasField("browser_result"):
                br = msg.task_response.browser_result
                a11y = br.a11y_tree
                # Keep only a short preview; mark it with "..." only when actually truncated.
                if a11y and len(a11y) > self.A11Y_PREVIEW_CHARS:
                    a11y = a11y[:self.A11Y_PREVIEW_CHARS] + "..."
                res_obj["browser"] = {
                    "url": br.url, "title": br.title, "has_snapshot": len(br.snapshot) > 0,
                    "a11y": a11y if a11y else None,
                    "eval": br.eval_result
                }
            self.journal.fulfill(msg.task_response.task_id, res_obj)

        elif kind == 'browser_event':
            evt = msg.browser_event
            if evt.HasField("console_msg"):
                prefix, content = "[🖥️] Live Console", evt.console_msg.text
            else:
                prefix, content = "[🌐] Net Inspect", f"{evt.network_req.method} {evt.network_req.url}"
            print(f"    {prefix}: {content}", flush=True)

        elif kind == 'file_sync':
            # Handle inbound file data from nodes (Node-Primary model)
            fs = msg.file_sync
            if fs.HasField("file_data"):
                print(f"    [📁📥] Mirroring {fs.file_data.path} (chunk {fs.file_data.chunk_index})")
                self.mirror.write_file_chunk(fs.session_id, fs.file_data)
                # Relay the chunk to the other nodes in the mesh.
                self.assistant.broadcast_file_chunk(fs.session_id, node_id, fs.file_data)
            elif fs.HasField("status"):
                print(f"    [📁] Sync Status from {node_id}: {fs.status.message}")
                if fs.status.code == agent_pb2.SyncStatus.RECONCILE_REQUIRED:
                    print(f"    [📁🔄] Server triggering recovery sync for {len(fs.status.reconcile_paths)} files to {node_id}")
                    for path in fs.status.reconcile_paths:
                        self.assistant.push_file(node_id, fs.session_id, path)

    def ReportHealth(self, request_iterator, context):
        """Consume a node's heartbeat stream and acknowledge each heartbeat.

        For every heartbeat, records the node's active worker count and running task
        ids in the registry, then yields a HealthCheckResponse stamped with the current
        server time in milliseconds.
        """
        for hb in request_iterator:
            self.registry.update_stats(hb.node_id, {
                "active_worker_count": hb.active_worker_count,
                "running": list(hb.running_task_ids)
            })
            yield agent_pb2.HealthCheckResponse(server_time_ms=int(time.time() * 1000))