Newer
Older
cortex-hub / poc-grpc-agent / orchestrator / services / grpc_server.py
import threading
import queue
import time
from protos import agent_pb2, agent_pb2_grpc
from orchestrator.core.registry import MemoryNodeRegistry
from orchestrator.core.journal import TaskJournal
from orchestrator.core.pool import GlobalWorkPool
from orchestrator.core.mirror import GhostMirrorManager
from orchestrator.services.assistant import TaskAssistant
from orchestrator.utils.crypto import sign_payload

class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer that coordinates agent nodes.

    Wires together the node registry, task journal, global work pool and
    ghost-mirror manager, and exposes the RPCs that nodes use to:

    * register and receive a sandbox policy (``SyncConfiguration``),
    * exchange work and results over a long-lived stream (``TaskStream``),
    * send heartbeats (``ReportHealth``).
    """

    def __init__(self):
        self.registry = MemoryNodeRegistry()
        self.journal = TaskJournal()
        self.pool = GlobalWorkPool()
        self.mirror = GhostMirrorManager()
        self.assistant = TaskAssistant(self.registry, self.journal, self.pool, self.mirror)
        # Push a work-pool update to every registered node whenever the pool
        # reports new work.
        self.pool.on_new_work = self._broadcast_work

        # Mesh observation: background thread printing an aggregated health
        # dashboard. daemon=True so it does not keep the process alive.
        threading.Thread(target=self._monitor_mesh, daemon=True, name="MeshMonitor").start()

    def _monitor_mesh(self, interval=10):
        """Print a status dashboard of all registered nodes every *interval* seconds.

        Runs forever on the ``MeshMonitor`` daemon thread. An error while
        rendering one snapshot is logged and the loop continues, so a single
        bad or vanished node entry cannot permanently stop the dashboard.
        """
        while True:
            time.sleep(interval)
            try:
                self._print_mesh_snapshot()
            except Exception as e:
                print(f"[!] Mesh monitor error: {e}", flush=True)

    def _print_mesh_snapshot(self):
        """Render one dashboard snapshot of the node registry to stdout."""
        active_nodes = self.registry.list_nodes()
        print("\n" + "="*50)
        print(f"📡 CORTEX MESH DASHBOARD | {len(active_nodes)} Nodes Online")
        print("-" * 50)
        if not active_nodes:
            print("  No nodes currently connected.")
        for nid in active_nodes:
            node = self.registry.get_node(nid)
            if not node:
                # get_node() can return a falsy value (TaskStream checks for
                # it too); skip instead of crashing on node.get().
                continue
            stats = node.get("stats", {})
            tasks = stats.get("running", [])
            capability = node.get("metadata", {}).get("caps", {})
            print(f"  🟢 {nid:15} | Workers: {stats.get('active_worker_count', 0)} | Running: {len(tasks)} tasks")
            print(f"      Capabilities: {capability}")
        print("="*50 + "\n", flush=True)

    def _broadcast_work(self, _):
        """Queue a work-pool update (available task ids) for every registered node.

        Installed as ``pool.on_new_work``; the callback's argument is ignored.
        """
        # Snapshot once, outside the registry lock, so every node receives the
        # same list and the pool is not re-queried per node while the lock is
        # held. list() guards against list_available() returning a one-shot
        # iterable.
        available = list(self.pool.list_available())
        with self.registry.lock:
            for node_id, node in self.registry.nodes.items():
                print(f"    [📢] Broadcasting availability to {node_id}")
                node["queue"].put(agent_pb2.ServerTaskMessage(
                    work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=available)
                ))

    def SyncConfiguration(self, request, context):
        """Standard handshake: register the node and return its sandbox policy.

        Registers ``request.node_id`` with a fresh outbound message queue and
        metadata (description and capabilities), then always answers with a
        successful ``RegistrationResponse`` carrying a STRICT sandbox policy.
        """
        # Pre-registration for metadata search
        self.registry.register(request.node_id, queue.Queue(), {
            "desc": request.node_description,
            "caps": dict(request.capabilities)
        })

        # 12-Factor Sandbox Policy (Standardized Mode)
        return agent_pb2.RegistrationResponse(
            success=True,
            policy=agent_pb2.SandboxPolicy(
                mode=agent_pb2.SandboxPolicy.STRICT,
                allowed_commands=["ls", "uname", "echo", "sleep"]
            )
        )

    def TaskStream(self, request_iterator, context):
        """Persistent bi-directional stream for command & control.

        The first client message must carry an ``announce`` payload for a
        node already registered through ``SyncConfiguration``; otherwise the
        stream ends immediately. After that, a background thread consumes
        client messages while this generator yields messages from the node's
        outbound queue. While the pool reports available work, it also yields
        a work-pool refresh at most once every 10 seconds.
        """
        # Bound before the try so the error handler below can always refer to
        # it, even if reading the very first message fails.
        node_id = None
        try:
            # 1. Blocking wait for Node Identity
            first_msg = next(request_iterator)
            if first_msg.WhichOneof('payload') != 'announce':
                print("[!] Stream rejected: No NodeAnnounce")
                return

            node_id = first_msg.announce.node_id
            node = self.registry.get_node(node_id)
            if not node:
                print(f"[!] Stream rejected: Node {node_id} not registered")
                return

            print(f"[📶] Stream Online for {node_id}")

            # Phase 5: Automatic Reconciliation on Reconnect
            self.assistant.reconcile_node(node_id)

            # 2. Results Listener (Read Thread)
            def _read_results():
                try:
                    for msg in request_iterator:
                        try:
                            self._handle_client_message(msg, node_id, node)
                        except Exception as e:
                            # One bad message must not stop result processing
                            # for the rest of the stream.
                            print(f"[!] Failed to handle message from {node_id}: {e}", flush=True)
                except Exception as e:
                    # Errors while iterating the inbound stream (e.g. the
                    # client going away) end this reader; log them instead of
                    # leaking an unhandled thread exception.
                    print(f"[!] Results stream closed for {node_id}: {e}", flush=True)

            threading.Thread(target=_read_results, daemon=True, name=f"Results-{node_id}").start()

            # 3. Work Dispatcher (Main Stream)
            last_keepalive = 0
            while context.is_active():
                try:
                    # Bounded wait so context.is_active() is re-checked
                    # roughly once per second.
                    msg = node["queue"].get(timeout=1.0)
                except queue.Empty:
                    # Occasional broadcast to nodes to ensure pool sync
                    now = time.time()
                    if (now - last_keepalive) > 10.0:
                        last_keepalive = now
                        if self.pool.available:
                            yield agent_pb2.ServerTaskMessage(
                                work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=self.pool.list_available())
                            )
                else:
                    yield msg

        except StopIteration:
            # Client closed the stream before sending any message.
            pass
        except Exception as e:
            print(f"[!] TaskStream Error for {node_id}: {e}")

    def _handle_client_message(self, msg, node_id, node):
        """Dispatch one inbound client message by its ``payload`` oneof kind.

        Handles ``task_claim``, ``task_response``, ``browser_event`` and
        ``file_sync``; any other kind is ignored. Replies destined for the
        sender are put on ``node["queue"]``.
        """
        kind = msg.WhichOneof('payload')
        if kind == 'task_claim':
            task_id = msg.task_claim.task_id
            success, payload = self.pool.claim(task_id, node_id)

            # Send status response back to the node first
            node["queue"].put(agent_pb2.ServerTaskMessage(
                claim_status=agent_pb2.TaskClaimResponse(
                    task_id=task_id,
                    granted=success,
                    reason="Task successfully claimed" if success else "Task already claimed by another node"
                )
            ))

            if success:
                # Only granted claims receive the signed task payload.
                sig = sign_payload(payload)
                node["queue"].put(agent_pb2.ServerTaskMessage(
                    task_request=agent_pb2.TaskRequest(
                        task_id=task_id,
                        payload_json=payload,
                        signature=sig
                    )
                ))

        elif kind == 'task_response':
            # Condense the response into a plain dict for the journal.
            res_obj = {"stdout": msg.task_response.stdout, "status": msg.task_response.status}
            if msg.task_response.HasField("browser_result"):
                br = msg.task_response.browser_result
                res_obj["browser"] = {
                    "url": br.url, "title": br.title, "has_snapshot": len(br.snapshot) > 0,
                    # First 100 chars of the accessibility tree only.
                    "a11y": br.a11y_tree[:100] + "..." if br.a11y_tree else None,
                    "eval": br.eval_result
                }
            self.journal.fulfill(msg.task_response.task_id, res_obj)

        elif kind == 'browser_event':
            # Live browser telemetry is only logged, not stored.
            e = msg.browser_event
            prefix = "[🖥️] Live Console" if e.HasField("console_msg") else "[🌐] Net Inspect"
            content = e.console_msg.text if e.HasField("console_msg") else f"{e.network_req.method} {e.network_req.url}"
            print(f"    {prefix}: {content}", flush=True)

        elif kind == 'file_sync':
            # Handle inbound file data from nodes (Node-Primary model)
            fs = msg.file_sync
            if fs.HasField("file_data"):
                print(f"    [📁📥] Mirroring {fs.file_data.path} (chunk {fs.file_data.chunk_index})")
                self.mirror.write_file_chunk(fs.session_id, fs.file_data)
                # BROADCAST to other nodes in the mesh
                self.assistant.broadcast_file_chunk(fs.session_id, node_id, fs.file_data)
            elif fs.HasField("status"):
                print(f"    [📁] Sync Status from {node_id}: {fs.status.message}")
                if fs.status.code == agent_pb2.SyncStatus.RECONCILE_REQUIRED:
                    print(f"    [📁🔄] Server triggering recovery sync for {len(fs.status.reconcile_paths)} files to {node_id}")
                    for path in fs.status.reconcile_paths:
                        self.assistant.push_file(node_id, fs.session_id, path)

    def ReportHealth(self, request_iterator, context):
        """Record node heartbeats and acknowledge each with the server time.

        For every heartbeat, stores the node's active worker count and running
        task ids in the registry, then yields a ``HealthCheckResponse`` whose
        ``server_time_ms`` is the current wall-clock time in milliseconds.
        """
        for hb in request_iterator:
            self.registry.update_stats(hb.node_id, {
                "active_worker_count": hb.active_worker_count,
                "running": list(hb.running_task_ids)
            })
            yield agent_pb2.HealthCheckResponse(server_time_ms=int(time.time()*1000))