import threading
import queue
import time
import os
try:
import requests as _requests # optional; only needed for M4 token validation
except ImportError:
_requests = None
from protos import agent_pb2, agent_pb2_grpc
from orchestrator.core.registry import MemoryNodeRegistry
from orchestrator.core.journal import TaskJournal
from orchestrator.core.pool import GlobalWorkPool
from orchestrator.core.mirror import GhostMirrorManager
from orchestrator.services.assistant import TaskAssistant
from orchestrator.utils.crypto import sign_payload
# M4: Hub HTTP API for invite-token validation
# Calls POST /nodes/validate-token before accepting any SyncConfiguration.
# Set HUB_API_URL=http://localhost:8000 (or 0 to skip validation in dev mode).
HUB_API_URL = os.getenv("HUB_API_URL", "")  # empty = skip validation (dev)
if HUB_API_URL == "0":
    # The documented dev-mode sentinel "0" must behave like "unset"; the
    # string "0" is truthy, so callers doing `if HUB_API_URL:` would
    # otherwise treat it as a real URL. Normalize it here.
    HUB_API_URL = ""
HUB_API_PATH = "/nodes/validate-token"
class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """Refactored gRPC Servicer for Agent Orchestration.

    Wires together the in-memory node registry, task journal, global work
    pool and ghost-mirror manager, and serves the SyncConfiguration,
    TaskStream and ReportHealth RPCs that agent nodes connect to.
    """
def __init__(self):
    """Wire together the core services and start the mesh-monitor thread."""
    self.registry = MemoryNodeRegistry()   # per-node queue/metadata/stats store
    self.journal = TaskJournal()           # task result fulfilment (see _handle_client_message)
    self.pool = GlobalWorkPool()           # claimable global task pool
    self.mirror = GhostMirrorManager()     # server-side file-sync mirror
    # Assistant coordinates reconciliation and file pushes across the services above.
    self.assistant = TaskAssistant(self.registry, self.journal, self.pool, self.mirror)
    # New work landing in the pool is immediately broadcast to all nodes.
    self.pool.on_new_work = self._broadcast_work
    # 4. Mesh Observation (Aggregated Health Dashboard)
    threading.Thread(target=self._monitor_mesh, daemon=True, name="MeshMonitor").start()
def _monitor_mesh(self):
    """Periodically prints status of all nodes in the mesh.

    Runs forever on a daemon thread (started by __init__), waking every
    10 seconds to print an aggregated dashboard to stdout.
    """
    while True:
        time.sleep(10)
        active_nodes = self.registry.list_nodes()
        print("\n" + "=" * 50)
        print(f"📡 CORTEX MESH DASHBOARD | {len(active_nodes)} Nodes Online")
        print("-" * 50)
        if not active_nodes:
            print("   No nodes currently connected.")
        for nid in active_nodes:
            node = self.registry.get_node(nid)
            if not node:
                # Node deregistered between list_nodes() and get_node();
                # skip it rather than crash the monitor thread.
                continue
            stats = node.get("stats", {})
            tasks = stats.get("running", [])
            capability = node.get("metadata", {}).get("caps", {})
            print(f"  🟢 {nid:15} | Workers: {stats.get('active_worker_count', 0)} | Running: {len(tasks)} tasks")
            print(f"     Capabilities: {capability}")
        print("=" * 50 + "\n", flush=True)
def _broadcast_work(self, _):
    """Push a work-pool availability update to every registered node."""
    registry = self.registry
    with registry.lock:
        for peer_id, peer in registry.nodes.items():
            print(f"  [📢] Broadcasting availability to {peer_id}")
            # list_available() is re-read per node so each message carries
            # the freshest snapshot of the pool.
            update = agent_pb2.WorkPoolUpdate(
                available_task_ids=self.pool.list_available()
            )
            peer["queue"].put(agent_pb2.ServerTaskMessage(work_pool_update=update))
def SyncConfiguration(self, request, context):
    """M4 Authenticated Handshake: Validate invite_token, then send policy.

    Validates the node's invite token against the Hub API (when configured),
    registers the node in the in-memory registry, and returns a
    RegistrationResponse carrying the sandbox policy — or success=False with
    the Hub's rejection reason.
    """
    node_id = request.node_id
    invite_token = request.auth_token  # field in RegistrationRequest proto

    # --- M4: Token validation via Hub API ---
    # "0" is the documented dev sentinel for "skip validation"; a plain
    # truthiness test would treat the string "0" as a real URL.
    validation_enabled = HUB_API_URL not in ("", "0")
    skill_cfg = {}
    if validation_enabled and _requests is not None:
        try:
            resp = _requests.post(
                f"{HUB_API_URL}{HUB_API_PATH}",
                params={"node_id": node_id, "token": invite_token},
                timeout=5,
            )
            payload = resp.json()
            if not payload.get("valid"):
                reason = payload.get("reason", "Token rejected")
                print(f"[🔒] SyncConfiguration REJECTED {node_id}: {reason}")
                return agent_pb2.RegistrationResponse(
                    success=False,
                    message=reason,
                )
            skill_cfg = payload.get("skill_config", {})
            print(f"[🔑] Token validated for {node_id} (display: {payload.get('display_name')})")
        except Exception as e:
            # If Hub is unreachable in dev, fall through with a warning
            print(f"[⚠️] Hub token validation unavailable ({e}); proceeding without auth.")
            skill_cfg = {}
    elif validation_enabled:
        # HUB_API_URL is set but the optional `requests` dependency is not
        # installed — warn accurately instead of claiming the URL is unset.
        print(f"[⚠️] `requests` not installed — skipping invite_token validation for {node_id}")
    else:
        # Dev mode: skip validation
        print(f"[⚠️] HUB_API_URL not set — skipping invite_token validation for {node_id}")

    # Build allowed_commands from skill_config (shell skill)
    shell_cfg = skill_cfg.get("shell", {})
    if shell_cfg.get("enabled", True):
        allowed_commands = ["ls", "cat", "echo", "pwd", "uname", "curl", "python3", "git"]
    else:
        allowed_commands = []  # Shell disabled by admin

    # Register the node in the local in-memory registry
    self.registry.register(request.node_id, queue.Queue(), {
        "desc": request.node_description,
        "caps": dict(request.capabilities),
    })
    return agent_pb2.RegistrationResponse(
        success=True,
        policy=agent_pb2.SandboxPolicy(
            mode=agent_pb2.SandboxPolicy.STRICT,
            allowed_commands=allowed_commands,
        )
    )
def TaskStream(self, request_iterator, context):
    """Persistent Bi-directional Stream for Command & Control.

    Protocol: the first client message must be a NodeAnnounce naming a node
    already registered via SyncConfiguration. A daemon thread then drains
    client->server messages while this generator yields queued
    server->client work, re-broadcasting pool availability at most every
    10 seconds while the queue is idle.
    """
    try:
        # 1. Blocking wait for Node Identity
        first_msg = next(request_iterator)
        if first_msg.WhichOneof('payload') != 'announce':
            print("[!] Stream rejected: No NodeAnnounce")
            return
        node_id = first_msg.announce.node_id
        node = self.registry.get_node(node_id)
        if not node:
            print(f"[!] Stream rejected: Node {node_id} not registered")
            return
        print(f"[📶] Stream Online for {node_id}")
        # Phase 5: Automatic Reconciliation on Reconnect
        self.assistant.reconcile_node(node_id)
        # 2. Results Listener (Read Thread)
        def _read_results():
            # Drain incoming client messages for the lifetime of the stream.
            for msg in request_iterator:
                self._handle_client_message(msg, node_id, node)
        threading.Thread(target=_read_results, daemon=True, name=f"Results-{node_id}").start()
        # 3. Work Dispatcher (Main Stream)
        last_keepalive = 0  # epoch seconds of the last idle re-broadcast
        while context.is_active():
            try:
                # Non-blocking wait to check context periodically
                msg = node["queue"].get(timeout=1.0)
                yield msg
            except queue.Empty:
                # Occasional broadcast to nodes to ensure pool sync
                now = time.time()
                if (now - last_keepalive) > 10.0:
                    last_keepalive = now
                    if self.pool.available:
                        yield agent_pb2.ServerTaskMessage(
                            work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=self.pool.list_available())
                        )
                continue
    # Empty request_iterator: next() above raises StopIteration; swallow it
    # here so PEP 479 does not convert it into a RuntimeError at the
    # generator boundary.
    except StopIteration: pass
    except Exception as e:
        # NOTE(review): if the failure occurs before node_id is bound (e.g.
        # next() raising something other than StopIteration), this f-string
        # would itself raise NameError — confirm intended.
        print(f"[!] TaskStream Error for {node_id}: {e}")
def _handle_client_message(self, msg, node_id, node):
    """Dispatch one client->server stream message by its payload kind."""
    payload_kind = msg.WhichOneof('payload')

    if payload_kind == 'task_claim':
        claim = msg.task_claim
        granted, payload = self.pool.claim(claim.task_id, node_id)
        # Send status response back to the node first
        node["queue"].put(agent_pb2.ServerTaskMessage(
            claim_status=agent_pb2.TaskClaimResponse(
                task_id=claim.task_id,
                granted=granted,
                reason="Task successfully claimed" if granted else "Task already claimed by another node"
            )
        ))
        # M6: Notify UI that a node is claiming a global task
        self.registry.emit(node_id, "task_claim", {"task_id": claim.task_id, "granted": granted})
        if granted:
            node["queue"].put(agent_pb2.ServerTaskMessage(
                task_request=agent_pb2.TaskRequest(
                    task_id=claim.task_id,
                    payload_json=payload,
                    signature=sign_payload(payload)
                )
            ))

    elif payload_kind == 'task_response':
        response = msg.task_response
        result = {"stdout": response.stdout, "status": response.status}
        if response.HasField("browser_result"):
            browser = response.browser_result
            result["browser"] = {
                "url": browser.url,
                "title": browser.title,
                "has_snapshot": len(browser.snapshot) > 0,
                "eval": browser.eval_result,
            }
        self.journal.fulfill(response.task_id, result)
        # M6: Emit to EventBus for UI streaming
        if response.status == agent_pb2.TaskResponse.SUCCESS:
            event_type = "task_complete"
        else:
            event_type = "task_error"
        self.registry.emit(node_id, event_type, result, task_id=response.task_id)

    elif payload_kind == 'browser_event':
        evt = msg.browser_event
        details = {}
        if evt.HasField("console_msg"):
            details = {"type": "console", "text": evt.console_msg.text, "level": evt.console_msg.level}
        elif evt.HasField("network_req"):
            details = {"type": "network", "method": evt.network_req.method, "url": evt.network_req.url}
        # M6: Stream live browser logs to UI
        self.registry.emit(node_id, "browser_event", details)

    elif payload_kind == 'file_sync':
        sync = msg.file_sync
        if sync.HasField("file_data"):
            chunk = sync.file_data
            self.mirror.write_file_chunk(sync.session_id, chunk)
            self.assistant.broadcast_file_chunk(sync.session_id, node_id, chunk)
            # M6: Emit sync progress (rarely to avoid flood, but good for large pushes)
            if chunk.chunk_index % 10 == 0:
                self.registry.emit(node_id, "sync_progress", {"path": chunk.path, "chunk": chunk.chunk_index})
        elif sync.HasField("status"):
            status = sync.status
            print(f"  [📁] Sync Status from {node_id}: {status.message}")
            self.registry.emit(node_id, "sync_status", {"message": status.message, "code": status.code})
            if status.code == agent_pb2.SyncStatus.RECONCILE_REQUIRED:
                for path in status.reconcile_paths:
                    self.assistant.push_file(node_id, sync.session_id, path)
def ReportHealth(self, request_iterator, context):
    """Collect health metrics per heartbeat and answer with server time."""
    for heartbeat in request_iterator:
        stats = {
            "active_worker_count": heartbeat.active_worker_count,
            "running": list(heartbeat.running_task_ids),
        }
        self.registry.update_stats(heartbeat.node_id, stats)
        now_ms = int(time.time() * 1000)
        yield agent_pb2.HealthCheckResponse(server_time_ms=now_ms)