# cortex-hub / poc-grpc-agent / server.py
import hashlib
import hmac
import itertools
import json
import os
import queue
import threading
import time
from concurrent import futures

import grpc
import jwt

import agent_pb2
import agent_pb2_grpc

# Shared HMAC signing secret. POC default kept for backward compatibility;
# override via CORTEX_SECRET_KEY — a secret committed to source control must
# be treated as compromised in any real deployment.
SECRET_KEY = os.environ.get("CORTEX_SECRET_KEY", "cortex-secret-shared-key")

class TaskJournal:
    """Thread-safe registry of in-flight tasks awaiting completion."""

    def __init__(self):
        self.lock = threading.Lock()
        # task_id -> {"event": Event, "result": None, "node_id": str}
        self.tasks = {}

    def register(self, task_id, node_id=None):
        """Create a waitable entry for task_id; return the Event to wait on."""
        done = threading.Event()
        entry = {"event": done, "result": None, "node_id": node_id}
        with self.lock:
            self.tasks[task_id] = entry
        return done

    def fulfill(self, task_id, result):
        """Attach result and wake the waiter; False for unknown task ids."""
        with self.lock:
            entry = self.tasks.get(task_id)
            if entry is None:
                return False
            entry["result"] = result
            entry["event"].set()
            return True

    def get_result(self, task_id):
        """Return the stored result, or None when the task is unknown."""
        with self.lock:
            entry = self.tasks.get(task_id)
        return entry["result"] if entry else None

    def pop(self, task_id):
        """Remove and return the whole task entry (None if absent)."""
        with self.lock:
            return self.tasks.pop(task_id, None)

class AbstractNodeRegistry:
    """Interface for finding and tracking Manager nodes.

    Implementations are expected to be safe for concurrent use; see
    MemoryNodeRegistry for the in-memory reference implementation.
    """

    def register(self, node_id, q, metadata):
        """Record a node with its outbound message queue and metadata.

        Signature matches the concrete implementation (node_id, q, metadata);
        the previous (node_id, data) form disagreed with every caller.
        """
        raise NotImplementedError

    def update_stats(self, node_id, stats):
        """Merge the latest heartbeat stats for node_id."""
        raise NotImplementedError

    def get_best(self):
        """Return the id of the least-loaded node, or None when empty."""
        raise NotImplementedError

    def get_node(self, node_id):
        """Return the stored node record, or None when unknown."""
        raise NotImplementedError

class MemoryNodeRegistry(AbstractNodeRegistry):
    """In-memory, lock-guarded implementation of the node registry."""

    def __init__(self):
        self.lock = threading.Lock()
        # node_id -> {"stats": {}, "queue": Queue, "metadata": {}}
        self.nodes = {}

    def register(self, node_id, q, metadata):
        """Add (or replace) a node record and announce it."""
        record = {"stats": {}, "queue": q, "metadata": metadata}
        with self.lock:
            self.nodes[node_id] = record
            print(f"[📋] Registered: {node_id}")

    def update_stats(self, node_id, stats):
        """Merge stats into an existing node; unknown ids are ignored."""
        with self.lock:
            node = self.nodes.get(node_id)
            if node is not None:
                node["stats"].update(stats)

    def get_best(self):
        """Return the node id with the fewest active workers, or None.

        Nodes that have not reported stats yet default to 999 so that any
        node with a real count wins.
        """
        with self.lock:
            if not self.nodes:
                return None
            return min(
                self.nodes,
                key=lambda nid: self.nodes[nid]["stats"].get("active_worker_count", 999),
            )

    def get_node(self, node_id):
        """Return the raw record for node_id, or None."""
        with self.lock:
            return self.nodes.get(node_id)

class GlobalWorkPool:
    """Lock-protected pool of unclaimed broadcast tasks (task_id -> payload JSON)."""

    def __init__(self):
        self.lock = threading.Lock()
        # Seed tasks any connected node may claim on a first-come basis.
        self.available = {"shared-001": '{"command": "uname -a"}', "shared-002": '{"command": "uptime"}'}

    def claim(self, task_id, node_id):
        """Atomically take task_id out of the pool on behalf of node_id.

        Returns (True, payload_json) on success, or (False, None) when the
        task was already claimed or never existed.
        """
        with self.lock:
            try:
                return True, self.available.pop(task_id)
            except KeyError:
                return False, None

class TaskAssistant:
    """The High-Level AI API.

    Wraps the registry/journal/pool trio: builds HMAC-signed TaskRequests,
    queues them to a node's outbound queue, and blocks the caller until the
    node responds or the timeout elapses.
    """

    def __init__(self, registry, journal, pool):
        self.registry = registry
        self.journal = journal
        self.pool = pool
        # Monotonic counter appended to task ids. Millisecond timestamps
        # alone collide when two dispatches land in the same ms, which would
        # overwrite the journal entry and cross-wire results.
        self._seq = itertools.count()

    def _new_task_id(self, prefix):
        """Return a unique task id of the form <prefix>-<epoch_ms>-<n>."""
        return f"{prefix}-{int(time.time() * 1000)}-{next(self._seq)}"

    def _dispatch(self, node_id, tid, req, timeout):
        """Register tid, queue req to node_id, and await fulfillment.

        Returns the node's result dict, {"error": "Offline"} if the node
        vanished, or {"error": "Timeout"} when the deadline passes. The
        journal entry is always cleaned up, success or not.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}
        event = self.journal.register(tid, node_id)
        node["queue"].put(req)
        try:
            if event.wait(timeout):
                return self.journal.get_result(tid)
            return {"error": "Timeout"}
        finally:
            self.journal.pop(tid)

    def dispatch_single(self, node_id, cmd, timeout=30):
        """Run one shell command on node_id; blocks up to timeout seconds."""
        if not self.registry.get_node(node_id):
            return {"error": "Offline"}
        tid = self._new_task_id("task")
        payload = json.dumps({"command": cmd})
        sig = hmac.new(SECRET_KEY.encode(), payload.encode(), hashlib.sha256).hexdigest()
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, payload_json=payload, signature=sig))
        return self._dispatch(node_id, tid, req, timeout)

    def dispatch_browser(self, node_id, action, timeout=60):
        """Run a BrowserAction on node_id; blocks up to timeout seconds.

        POC signature covers only the action enum, URL, and session id.
        """
        if not self.registry.get_node(node_id):
            return {"error": "Offline"}
        tid = self._new_task_id("br")
        sign_base = f"{action.action}:{action.url}:{action.session_id}"
        sig = hmac.new(SECRET_KEY.encode(), sign_base.encode(), hashlib.sha256).hexdigest()
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, browser_action=action, signature=sig))
        return self._dispatch(node_id, tid, req, timeout)

class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer: registers nodes, streams tasks to them, and routes
    their responses back into the TaskJournal."""

    def __init__(self):
        self.registry = MemoryNodeRegistry()
        self.journal = TaskJournal()
        self.pool = GlobalWorkPool()
        self.assistant = TaskAssistant(self.registry, self.journal, self.pool)

    def SyncConfiguration(self, request, context):
        """Unary RPC: register the node's metadata and return a STRICT
        sandbox policy with the command allowlist."""
        # Pre-registration for metadata search
        self.registry.register(request.node_id, queue.Queue(), {"desc": request.node_description, "caps": dict(request.capabilities)})
        return agent_pb2.RegistrationResponse(success=True, 
                                             policy=agent_pb2.SandboxPolicy(mode=agent_pb2.SandboxPolicy.STRICT, 
                                             allowed_commands=["ls", "uname", "echo", "sleep"]))

    def TaskStream(self, request_iterator, context):
        """Bidirectional stream. A daemon reader thread consumes inbound
        claims/results while this generator body pushes work-pool updates
        and queued direct tasks out to the node."""
        node_id = None
        
        def _read():
            # Runs on a daemon thread for the lifetime of the stream;
            # the first task_claim also tells us which node owns this stream.
            nonlocal node_id
            for msg in request_iterator:
                kind = msg.WhichOneof('payload')
                if kind == 'task_claim':
                    node_id = msg.task_claim.node_id
                    success, payload = self.pool.claim(msg.task_claim.task_id, node_id)
                    if success:
                        # Sign the pooled payload so the agent can verify it
                        # against the shared secret before executing.
                        sig = hmac.new(SECRET_KEY.encode(), payload.encode(), hashlib.sha256).hexdigest()
                        self.registry.get_node(node_id)["queue"].put(agent_pb2.ServerTaskMessage(
                            task_request=agent_pb2.TaskRequest(task_id=msg.task_claim.task_id, payload_json=payload, signature=sig)))
                elif kind == 'task_response':
                    res_obj = {"stdout": msg.task_response.stdout, "status": msg.task_response.status}
                    if msg.task_response.HasField("browser_result"):
                        br = msg.task_response.browser_result
                        res_obj["browser"] = {"url": br.url, "title": br.title, "has_snapshot": len(br.snapshot) > 0}
                    # Wakes whichever dispatch_* caller is blocked on this id.
                    self.journal.fulfill(msg.task_response.task_id, res_obj)
        
        threading.Thread(target=_read, daemon=True).start()

        while context.is_active():
            # Broadcast pool
            # NOTE(review): pool.available is read here without the pool lock;
            # benign for a truthiness/keys snapshot in CPython, but confirm if
            # the pool implementation changes.
            if self.pool.available:
                yield agent_pb2.ServerTaskMessage(work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=list(self.pool.available.keys())))
            
            # Send direct tasks
            if node_id and self.registry.get_node(node_id):
                try:
                    # Block up to 2s per poll so the loop still notices
                    # stream disconnects reasonably quickly.
                    msg = self.registry.get_node(node_id)["queue"].get(timeout=2)
                    yield msg
                except queue.Empty: pass
            else: time.sleep(1)

    def ReportHealth(self, request_iterator, context):
        """Bidirectional stream: record each heartbeat's stats and echo the
        server clock in milliseconds."""
        for hb in request_iterator:
            self.registry.update_stats(hb.node_id, {"active_worker_count": hb.active_worker_count, "running": list(hb.running_task_ids)})
            yield agent_pb2.HealthCheckResponse(server_time_ms=int(time.time()*1000))

def _load_credentials():
    """Build mTLS server credentials from the local certs/ directory."""
    def read(path):
        with open(path, 'rb') as f:
            return f.read()
    return grpc.ssl_server_credentials(
        private_key_certificate_chain_pairs=[(read('certs/server.key'), read('certs/server.crt'))],
        root_certificates=read('certs/ca.crt'),
        require_client_auth=True)

def _run_simulation(orch):
    """POC-only: drive the assistant against a known agent after startup."""
    time.sleep(10)  # give the agent time to connect and register
    print("\n[🧠] AI Simulation Start...")
    print(f"    Whoami: {orch.assistant.dispatch_single('agent-node-007', 'whoami')}")

    # NEW: Browser Phase
    print("\n[🧠] AI Phase 4: Navigating Browser (Antigravity Bridge)...")
    nav_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.NAVIGATE,
        url="https://example.com",
        session_id="antigravity-session-1"
    )
    res_nav = orch.assistant.dispatch_browser("agent-node-007", nav_action)
    print(f"    Nav Result: {res_nav}")

    print("\n[🧠] AI Phase 5: Multi-Action Persistence (Screenshot)...")
    snap_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.SCREENSHOT,
        session_id="antigravity-session-1"
    )
    res_snap = orch.assistant.dispatch_browser("agent-node-007", snap_action)
    print(f"    Snap Result: {res_snap.get('browser', {}).get('title')} | Snapshot captured: {res_snap.get('browser', {}).get('has_snapshot')}")

def serve():
    """Start the secure gRPC control plane, then run the demo simulation."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    orch = AgentOrchestrator()
    agent_pb2_grpc.add_AgentOrchestratorServicer_to_server(orch, server)
    bound = server.add_secure_port('[::]:50051', _load_credentials())
    if bound == 0:
        # add_secure_port returns 0 when binding fails; fail loudly instead
        # of hanging forever with no listener.
        raise RuntimeError("Failed to bind secure port 50051")

    print("[🛡️] Boss Plane Refactored & Online.")
    server.start()

    _run_simulation(orch)

    server.wait_for_termination()

# Script entry point: boot the mTLS control plane and its demo loop.
if __name__ == '__main__':
    serve()