# (page-navigation residue from the source viewer: "Newer" / "Older")
# cortex-hub / ai-hub / app / core / services / node_registry.py
"""
NodeRegistry Service — AI Hub Integration Layer
Manages live Agent Node registrations, their in-memory gRPC queues,
and a rich event bus for real-time UI streaming (split-pane terminal UX).
"""
import threading
import queue
from datetime import datetime
from typing import Dict, Optional, List, Any


# All event types emitted across the system — rendered in the live UI
# Maps each event type key to the human-readable label attached to emitted
# events (looked up via EVENT_TYPES.get(...) when an event is built).
# Labels are UTF-8 emoji + text; several had been corrupted by a wrong-codec
# round trip and are restored here.
EVENT_TYPES = {
    "node_online":    "🟢 Node Online",
    "node_offline":   "⚫ Node Offline",
    "node_stale":     "🟡 Node Stale",
    "heartbeat":      "💓 Heartbeat",
    "task_assigned":  "📦 Task Assigned",
    "task_start":     "🚀 Task Starting",
    "task_stdout":    "📤 Output",
    "task_done":      "✅ Task Done",
    "task_error":     "❌ Task Error",
    "task_cancelled": "🛑 Task Cancelled",
    "task_retry":     "🔄 Retry",
    "sync_manifest":  "📋 Manifest Sync",
    "sync_push":      "📁 File Push",
    "sync_drift":     "⚠️ Drift Detected",
    "sync_recovery":  "🏥 Recovery",
    "sync_locked":    "🔒 Workspace Locked",
    "sync_unlocked":  "🔓 Workspace Unlocked",
    "info":           "ℹ️ Info",
}


class LiveNodeRecord:
    """Represents a single connected Agent Node and its associated state.

    Holds the node's identity and owner, its metadata, an outbound message
    queue, the latest reported stats, and connection/heartbeat timestamps
    (naive UTC, as produced by ``datetime.utcnow()``).
    """

    # Seconds without a heartbeat after which the node is reported as "stale".
    STALE_AFTER_SECONDS = 30

    def __init__(self, node_id: str, user_id: str, metadata: dict):
        """Create a record whose heartbeat clock starts at construction time.

        Args:
            node_id: Identifier of the node.
            user_id: Owner — maps node to a Hub user.
            metadata: Node metadata; ``to_dict`` reads the ``"desc"`` and
                ``"caps"`` keys from it.
        """
        self.node_id = node_id
        self.user_id = user_id           # Owner — maps node to a Hub user
        self.metadata = metadata         # desc, caps (capabilities dict)
        self.queue: queue.Queue = queue.Queue()   # gRPC outbound message queue
        self.stats: dict = {
            "active_worker_count": 0,
            "cpu_usage_percent": 0.0,
            "memory_usage_percent": 0.0,
            "running": [],
        }
        # Take a single timestamp so both fields agree exactly at creation.
        now = datetime.utcnow()
        self.connected_at: datetime = now
        self.last_heartbeat_at: datetime = now

    def update_stats(self, stats: dict) -> None:
        """Merge ``stats`` into the stored stats and refresh the heartbeat time.

        A ``None``/empty ``stats`` still counts as a heartbeat; it just leaves
        the stored values untouched.
        """
        if stats:
            self.stats.update(stats)
        self.last_heartbeat_at = datetime.utcnow()

    def to_dict(self) -> dict:
        """Return a serialisable snapshot of this record.

        The ``"stats"`` value is a copy, so a snapshot that has already been
        handed out (e.g. placed on an event queue) is not rewritten by a later
        ``update_stats`` call.
        """
        meta = self.metadata or {}
        return {
            "node_id": self.node_id,
            "user_id": self.user_id,
            "description": meta.get("desc", ""),
            "capabilities": meta.get("caps", {}),
            "stats": dict(self.stats),
            "connected_at": self.connected_at.isoformat(),
            "last_heartbeat_at": self.last_heartbeat_at.isoformat(),
            "status": self._compute_status(),
        }

    def _compute_status(self) -> str:
        """Return ``"stale"`` if the last heartbeat is too old, else ``"online"``."""
        silent_for = (datetime.utcnow() - self.last_heartbeat_at).total_seconds()
        if silent_for > self.STALE_AFTER_SECONDS:
            return "stale"
        return "online"


class NodeRegistryService:
    """
    In-memory registry of live Agent Nodes, integrated into the FastAPI
    ServiceContainer.

    Provides:
    - Live node registration / deregistration
    - gRPC outbound queue per node
    - Rich event bus for real-time UI streaming:
        * Per-node stream  → single node execution pane
        * Per-user stream  → all-nodes global execution bus (multi-pane view)

    The node table and both subscriber tables are guarded by one lock.
    Events are delivered outside the lock, from snapshots of the subscriber
    lists taken while the lock was held.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._nodes: Dict[str, LiveNodeRecord] = {}
        # Per-node WS subscribers:  node_id  -> [queue, ...]
        self._node_listeners: Dict[str, List[queue.Queue]] = {}
        # Per-user WS subscribers:  user_id  -> [queue, ...]  (ALL nodes for that user)
        self._user_listeners: Dict[str, List[queue.Queue]] = {}

    # ------------------------------------------------------------------ #
    #  Registration                                                        #
    # ------------------------------------------------------------------ #

    def register(self, node_id: str, user_id: str, metadata: dict) -> LiveNodeRecord:
        """Register or re-register a node (called from gRPC SyncConfiguration).

        Any existing record under ``node_id`` is replaced by a fresh one.
        Emits a ``node_online`` event carrying the record snapshot.

        Returns:
            The newly created record.
        """
        with self._lock:
            record = LiveNodeRecord(node_id=node_id, user_id=user_id, metadata=metadata)
            self._nodes[node_id] = record
        print(f"[📋] NodeRegistry: Registered {node_id} (owner: {user_id})")
        self.emit(node_id, "node_online", record.to_dict())
        return record

    def deregister(self, node_id: str):
        """Remove a node when its gRPC stream closes (called from TaskStream finally).

        Emits a ``node_offline`` event to the node's subscribers and to the
        owner's per-user subscribers. The event is emitted even when
        ``node_id`` was not registered (the owner is then unknown).
        """
        with self._lock:
            node = self._nodes.pop(node_id, None)
            user_id = node.user_id if node else None
            # Resolve subscribers in the same critical section as the removal:
            # once the record is gone, emit() cannot look up the owner any
            # more and per-user subscribers would never see "node_offline".
            targets = self._snapshot_listeners(node_id, user_id)
        self._deliver(
            targets,
            node_id,
            user_id or "",
            "node_offline",
            {"node_id": node_id, "user_id": user_id},
        )
        print(f"[📋] NodeRegistry: Deregistered {node_id}")

    # ------------------------------------------------------------------ #
    #  Queries                                                             #
    # ------------------------------------------------------------------ #

    def get_node(self, node_id: str) -> Optional[LiveNodeRecord]:
        """Return the live record for ``node_id``, or ``None`` if not registered."""
        with self._lock:
            return self._nodes.get(node_id)

    def list_nodes(self, user_id: Optional[str] = None) -> List[LiveNodeRecord]:
        """Return registered nodes, restricted to ``user_id``'s nodes when given.

        A falsy ``user_id`` (``None`` or ``""``) returns every node.
        """
        with self._lock:
            if user_id:
                return [n for n in self._nodes.values() if n.user_id == user_id]
            return list(self._nodes.values())

    def get_best(self, user_id: Optional[str] = None) -> Optional[str]:
        """Pick the least-loaded node for a given owner.

        Load is the node's reported ``active_worker_count``; nodes that have
        not reported one rank last. Ties go to the earliest-registered node.

        Returns:
            The chosen node's id, or ``None`` when no node is available.
        """
        nodes = self.list_nodes(user_id=user_id)
        if not nodes:
            return None
        # min() returns the first minimal element, matching a stable sort's [0].
        best = min(nodes, key=lambda n: n.stats.get("active_worker_count", 999))
        return best.node_id

    def update_stats(self, node_id: str, stats: dict):
        """Record a heartbeat's stats for ``node_id`` and broadcast them.

        The ``heartbeat`` event is emitted whether or not the node is known.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if node:
                node.update_stats(stats)
        # Also emit heartbeat event to UI
        self.emit(node_id, "heartbeat", stats)

    # ------------------------------------------------------------------ #
    #  Event Bus                                                          #
    # ------------------------------------------------------------------ #

    def emit(self, node_id: str, event_type: str, data: Any = None, task_id: str = ""):
        """
        Emit a rich structured execution event.

        Delivered to:
          - Per-node WS subscribers → powers the single-node execution pane
          - Per-user WS subscribers → powers the global multi-node execution bus

        The owning user is resolved from the registry; if ``node_id`` is not
        registered, only per-node subscribers receive the event and its
        ``user_id`` field is ``""``.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            user_id = node.user_id if node else ""
            targets = self._snapshot_listeners(node_id, user_id)
        self._deliver(targets, node_id, user_id, event_type, data, task_id)

    def _snapshot_listeners(self, node_id: str, user_id: Optional[str]) -> List[queue.Queue]:
        """Copy the subscriber queues for a node and its owner. Caller holds the lock."""
        targets = list(self._node_listeners.get(node_id, ()))
        if user_id:
            targets.extend(self._user_listeners.get(user_id, ()))
        return targets

    def _deliver(
        self,
        targets: List[queue.Queue],
        node_id: str,
        user_id: str,
        event_type: str,
        data: Any = None,
        task_id: str = "",
    ) -> None:
        """Build the event dict and push it to each distinct queue in ``targets``."""
        event = {
            "event": event_type,
            "label": EVENT_TYPES.get(event_type, event_type),
            "node_id": node_id,
            "user_id": user_id,
            "task_id": task_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": data,
        }
        # Deliver — avoid duplicates if same queue is in both lists
        delivered = set()
        for q in targets:
            if id(q) in delivered:
                continue
            delivered.add(id(q))
            try:
                q.put_nowait(event)
            except Exception:
                # Deliberate best-effort: a full or broken subscriber queue
                # must not prevent delivery to the remaining subscribers.
                pass

    # ------------------------------------------------------------------ #
    #  WS Subscriptions                                                   #
    # ------------------------------------------------------------------ #

    def subscribe_node(self, node_id: str, q: queue.Queue):
        """Subscribe to execution events for a specific node (single-pane view)."""
        with self._lock:
            self._node_listeners.setdefault(node_id, []).append(q)

    def unsubscribe_node(self, node_id: str, q: queue.Queue):
        """Remove ``q`` from ``node_id``'s subscribers; no-op if not subscribed."""
        with self._lock:
            self._discard_listener(self._node_listeners, node_id, q)

    def subscribe_user(self, user_id: str, q: queue.Queue):
        """Subscribe to ALL node events for a user (multi-pane global bus)."""
        with self._lock:
            self._user_listeners.setdefault(user_id, []).append(q)

    def unsubscribe_user(self, user_id: str, q: queue.Queue):
        """Remove ``q`` from ``user_id``'s subscribers; no-op if not subscribed."""
        with self._lock:
            self._discard_listener(self._user_listeners, user_id, q)

    @staticmethod
    def _discard_listener(table: Dict[str, List[queue.Queue]], key: str, q: queue.Queue) -> None:
        """Remove one occurrence of ``q`` under ``key``. Caller holds the lock."""
        listeners = table.get(key)
        if not listeners:
            return
        if q in listeners:
            listeners.remove(q)
        # Drop the key once nobody is listening so the table does not
        # accumulate empty lists for every id ever subscribed to.
        if not listeners:
            del table[key]