# cortex-hub / poc-grpc-agent / orchestrator / services / grpc_server.py
import threading
import queue
import time
from protos import agent_pb2, agent_pb2_grpc
from orchestrator.core.registry import MemoryNodeRegistry
from orchestrator.core.journal import TaskJournal
from orchestrator.core.pool import GlobalWorkPool
from orchestrator.services.assistant import TaskAssistant
from orchestrator.utils.crypto import sign_payload

class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer coordinating task distribution across agent nodes.

    Wires together the node registry, task journal, and global work pool,
    and implements the registration, bi-directional task streaming, and
    health-reporting RPCs. Each registered node gets an outbound
    ``queue.Queue`` that ``TaskStream`` drains to push server messages.
    """

    # Seconds of outbound-queue inactivity before an idle stream pushes a
    # work-pool refresh to its node.
    KEEPALIVE_INTERVAL_S = 10.0

    def __init__(self):
        self.registry = MemoryNodeRegistry()   # node metadata + per-node outbound queues
        self.journal = TaskJournal()           # task result bookkeeping
        self.pool = GlobalWorkPool()           # unclaimed work items
        self.assistant = TaskAssistant(self.registry, self.journal, self.pool)
        # Fan new work out to every connected node as soon as it is enqueued.
        self.pool.on_new_work = self._broadcast_work

    def _pool_update_message(self):
        """Build a ServerTaskMessage advertising the currently claimable task ids."""
        return agent_pb2.ServerTaskMessage(
            work_pool_update=agent_pb2.WorkPoolUpdate(
                available_task_ids=self.pool.list_available()
            )
        )

    def _broadcast_work(self, _):
        """Push a fresh work-pool snapshot onto every registered node's queue."""
        with self.registry.lock:
            for node_id, node in self.registry.nodes.items():
                print(f"    [📢] Broadcasting availability to {node_id}")
                node["queue"].put(self._pool_update_message())

    def SyncConfiguration(self, request, context):
        """Standard Handshake: Authenticate and Send Policy.

        Registers the node (with a fresh outbound queue) so that a later
        TaskStream can look it up, then returns the sandbox policy.
        NOTE(review): re-registering an already-connected node replaces its
        queue, orphaning any stream still draining the old one — confirm
        callers only re-sync after a disconnect.
        """
        # Pre-registration for metadata search.
        self.registry.register(request.node_id, queue.Queue(), {
            "desc": request.node_description,
            "caps": dict(request.capabilities)
        })

        # 12-Factor Sandbox Policy (Standardized Mode).
        return agent_pb2.RegistrationResponse(
            success=True,
            policy=agent_pb2.SandboxPolicy(
                mode=agent_pb2.SandboxPolicy.STRICT,
                allowed_commands=["ls", "uname", "echo", "sleep"]
            )
        )

    def TaskStream(self, request_iterator, context):
        """Persistent Bi-directional Stream for Command & Control.

        Protocol: the first client message must be a NodeAnnounce naming a
        node previously registered via SyncConfiguration. A daemon thread
        then consumes further client messages while this generator yields
        the node's queued server messages, emitting periodic pool refreshes
        when idle.
        """
        node_id = None  # bound before use in the except handler below
        try:
            # 1. Blocking wait for Node Identity.
            first_msg = next(request_iterator)
            if first_msg.WhichOneof('payload') != 'announce':
                print("[!] Stream rejected: No NodeAnnounce")
                return

            node_id = first_msg.announce.node_id
            node = self.registry.get_node(node_id)
            if not node:
                print(f"[!] Stream rejected: Node {node_id} not registered")
                return

            print(f"[📶] Stream Online for {node_id}")

            # 2. Results Listener (Read Thread) — exits when the client
            # closes its send side of the stream.
            def _read_results():
                for msg in request_iterator:
                    self._handle_client_message(msg, node_id, node)

            threading.Thread(
                target=_read_results, daemon=True, name=f"Results-{node_id}"
            ).start()

            # 3. Work Dispatcher (Main Stream).
            last_keepalive = 0.0
            while context.is_active():
                try:
                    # Bounded wait so context.is_active() is re-checked ~1/s.
                    yield node["queue"].get(timeout=1.0)
                except queue.Empty:
                    # Occasional broadcast to nodes to ensure pool sync.
                    now = time.time()
                    if (now - last_keepalive) > self.KEEPALIVE_INTERVAL_S:
                        last_keepalive = now
                        if self.pool.available:
                            yield self._pool_update_message()

        except StopIteration:
            # Client closed the stream before announcing — nothing to clean up.
            pass
        except Exception as e:
            # node_id stays None if the failure happened pre-announce; the
            # original code raised NameError here in that case.
            print(f"[!] TaskStream Error for {node_id or '<unannounced>'}: {e}")

    def _handle_client_message(self, msg, node_id, node):
        """Dispatch one inbound client message (claim / response / event)."""
        kind = msg.WhichOneof('payload')
        if kind == 'task_claim':
            # Atomically claim the task; on success send the signed payload back.
            success, payload = self.pool.claim(msg.task_claim.task_id, node_id)
            if success:
                sig = sign_payload(payload)
                node["queue"].put(agent_pb2.ServerTaskMessage(
                    task_request=agent_pb2.TaskRequest(
                        task_id=msg.task_claim.task_id,
                        payload_json=payload,
                        signature=sig)))

        elif kind == 'task_response':
            res_obj = {"stdout": msg.task_response.stdout, "status": msg.task_response.status}
            if msg.task_response.HasField("browser_result"):
                br = msg.task_response.browser_result
                # Only mark the a11y preview as truncated when it actually is.
                if br.a11y_tree:
                    a11y = br.a11y_tree[:100] + "..." if len(br.a11y_tree) > 100 else br.a11y_tree
                else:
                    a11y = None
                res_obj["browser"] = {
                    "url": br.url, "title": br.title, "has_snapshot": len(br.snapshot) > 0,
                    "a11y": a11y,
                    "eval": br.eval_result
                }
            self.journal.fulfill(msg.task_response.task_id, res_obj)

        elif kind == 'browser_event':
            # Live telemetry from the node's embedded browser: console or network.
            e = msg.browser_event
            prefix = "[🖥️] Live Console" if e.HasField("console_msg") else "[🌐] Net Inspect"
            content = e.console_msg.text if e.HasField("console_msg") else f"{e.network_req.method} {e.network_req.url}"
            print(f"    {prefix}: {content}", flush=True)

    def ReportHealth(self, request_iterator, context):
        """Collect Health Metrics and Feed Policy Updates.

        For each heartbeat, records the node's worker stats in the registry
        and echoes the server clock (ms) so nodes can detect drift.
        """
        for hb in request_iterator:
            self.registry.update_stats(hb.node_id, {
                "active_worker_count": hb.active_worker_count,
                "running": list(hb.running_task_ids)
            })
            yield agent_pb2.HealthCheckResponse(server_time_ms=int(time.time() * 1000))