import threading
import queue
import time
from protos import agent_pb2, agent_pb2_grpc
from orchestrator.core.registry import MemoryNodeRegistry
from orchestrator.core.journal import TaskJournal
from orchestrator.core.pool import GlobalWorkPool
from orchestrator.services.assistant import TaskAssistant
from orchestrator.utils.crypto import sign_payload
class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer for agent orchestration.

    Owns the node registry, the task journal, the global work pool and the
    task assistant, and implements the ``SyncConfiguration``, ``TaskStream``
    and ``ReportHealth`` RPC handlers.
    """

    def __init__(self):
        """Create the core components and hook work-pool notifications."""
        self.registry = MemoryNodeRegistry()
        self.journal = TaskJournal()
        self.pool = GlobalWorkPool()
        self.assistant = TaskAssistant(self.registry, self.journal, self.pool)
        # Presumably invoked by the pool when work is added (one positional
        # argument, which the handler ignores) -- confirm in GlobalWorkPool.
        self.pool.on_new_work = self._broadcast_work

    @staticmethod
    def _work_pool_update(task_ids):
        """Wrap *task_ids* in a ``ServerTaskMessage`` carrying a ``WorkPoolUpdate``."""
        return agent_pb2.ServerTaskMessage(
            work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=task_ids)
        )

    def _broadcast_work(self, _):
        """Push a work-pool update onto the queue of every registered node.

        The single positional argument is ignored; the list of available
        tasks is always re-read from the pool.
        """
        with self.registry.lock:
            # Read the pool once so every node receives the same snapshot
            # instead of re-querying the pool for each node.
            available = self.pool.list_available()
            for node_id, node in self.registry.nodes.items():
                print(f" [📢] Broadcasting availability to {node_id}")
                node["queue"].put(self._work_pool_update(available))

    def SyncConfiguration(self, request, context):
        """Handshake RPC: register the node and return its sandbox policy.

        The node is registered under ``request.node_id`` with a fresh
        outbound queue plus its description and capabilities. The response
        always reports success with a STRICT policy and a fixed allow-list
        of commands.
        """
        # Pre-registration so the node's metadata is available for search.
        node_metadata = {
            "desc": request.node_description,
            "caps": dict(request.capabilities),
        }
        self.registry.register(request.node_id, queue.Queue(), node_metadata)

        # 12-Factor Sandbox Policy (Standardized Mode)
        return agent_pb2.RegistrationResponse(
            success=True,
            policy=agent_pb2.SandboxPolicy(
                mode=agent_pb2.SandboxPolicy.STRICT,
                allowed_commands=["ls", "uname", "echo", "sleep"],
            ),
        )

    def TaskStream(self, request_iterator, context):
        """Persistent bi-directional stream for command and control.

        The first client message must be an ``announce`` naming a node that
        is already registered; otherwise the stream is closed immediately.
        After that, a daemon thread consumes the remaining client messages
        while this generator yields whatever is placed on the node's queue.
        When the queue has been idle and at least 10 seconds have passed
        since the last keepalive, a work-pool update is yielded if the pool
        reports available work.
        """
        # Bound before the ``try`` so the error handler below can always
        # format it, even when the failure happens before the announce
        # message has been read.
        node_id = "<unannounced>"
        try:
            # 1. Blocking wait for node identity.
            first_msg = next(request_iterator)
            if first_msg.WhichOneof('payload') != 'announce':
                print("[!] Stream rejected: No NodeAnnounce")
                return

            node_id = first_msg.announce.node_id
            node = self.registry.get_node(node_id)
            if not node:
                print(f"[!] Stream rejected: Node {node_id} not registered")
                return

            print(f"[📶] Stream Online for {node_id}")

            # 2. Results listener (read thread).
            threading.Thread(
                target=self._read_results,
                args=(request_iterator, node_id, node),
                daemon=True,
                name=f"Results-{node_id}",
            ).start()

            # 3. Work dispatcher (main stream).
            # Monotonic clock: the keepalive interval must not be affected by
            # wall-clock adjustments. Starting at -inf makes the first idle
            # tick eligible for a keepalive straight away.
            last_keepalive = float("-inf")
            while context.is_active():
                try:
                    # Short timeout so ``context.is_active()`` is re-checked
                    # about once a second while the queue is empty.
                    outbound = node["queue"].get(timeout=1.0)
                except queue.Empty:
                    # Occasional re-broadcast to keep the node's view of the
                    # pool in sync.
                    now = time.monotonic()
                    if (now - last_keepalive) > 10.0:
                        last_keepalive = now
                        if self.pool.available:
                            yield self._work_pool_update(self.pool.list_available())
                    continue
                yield outbound
        except StopIteration:
            # The client closed the stream before sending anything.
            pass
        except Exception as e:
            print(f"[!] TaskStream Error for {node_id}: {e}")

    def _read_results(self, request_iterator, node_id, node):
        """Consume client messages for one stream until the iterator ends.

        Runs on the per-stream reader thread. A failure while handling one
        message is reported and the loop continues, so a single bad message
        does not silence the node for the rest of the stream. A failure of
        the iterator itself is reported and ends the reader.
        """
        try:
            for msg in request_iterator:
                try:
                    self._handle_client_message(msg, node_id, node)
                except Exception as e:
                    print(f"[!] Failed to handle message from {node_id}: {e}")
        except Exception as e:
            print(f"[!] Results reader for {node_id} stopped: {e}")

    def _handle_client_message(self, msg, node_id, node):
        """Dispatch one client message according to its ``payload`` oneof.

        * ``task_claim``: try to claim the task in the pool; on success,
          sign the payload and queue a ``TaskRequest`` for the node.
        * ``task_response``: record stdout/status (plus a browser summary
          when present) in the journal.
        * ``browser_event``: print the console text or the network request.

        Any other payload kind is ignored.
        """
        kind = msg.WhichOneof('payload')

        if kind == 'task_claim':
            task_id = msg.task_claim.task_id
            success, payload = self.pool.claim(task_id, node_id)
            # NOTE(review): a failed claim sends nothing back to the node.
            if success:
                sig = sign_payload(payload)
                node["queue"].put(agent_pb2.ServerTaskMessage(
                    task_request=agent_pb2.TaskRequest(
                        task_id=task_id, payload_json=payload, signature=sig
                    )
                ))

        elif kind == 'task_response':
            response = msg.task_response
            res_obj = {"stdout": response.stdout, "status": response.status}
            if response.HasField("browser_result"):
                br = response.browser_result
                res_obj["browser"] = {
                    "url": br.url,
                    "title": br.title,
                    "has_snapshot": len(br.snapshot) > 0,
                    # Preview only: the first 100 characters. The ellipsis is
                    # appended even when the tree is shorter than that.
                    "a11y": (br.a11y_tree[:100] + "...") if br.a11y_tree else None,
                    "eval": br.eval_result,
                }
            self.journal.fulfill(response.task_id, res_obj)

        elif kind == 'browser_event':
            event = msg.browser_event
            if event.HasField("console_msg"):
                prefix = "[🖥️] Live Console"
                content = event.console_msg.text
            else:
                prefix = "[🌐] Net Inspect"
                content = f"{event.network_req.method} {event.network_req.url}"
            print(f" {prefix}: {content}", flush=True)

    def ReportHealth(self, request_iterator, context):
        """Record each heartbeat in the registry and answer with server time.

        For every heartbeat the node's stats are updated with its active
        worker count and running task ids, then a ``HealthCheckResponse``
        carrying the current server time in milliseconds is yielded.
        """
        for hb in request_iterator:
            self.registry.update_stats(hb.node_id, {
                "active_worker_count": hb.active_worker_count,
                "running": list(hb.running_task_ids),
            })
            yield agent_pb2.HealthCheckResponse(server_time_ms=int(time.time() * 1000))