# Source path: cortex-hub / ai-hub / app / core / services / node_registry.py
"""
NodeRegistry Service — AI Hub Integration Layer
Manages live Agent Node registrations, their in-memory gRPC queues,
and a rich event bus for real-time UI streaming (split-pane terminal UX).

Persistence strategy (M2):
- The in-memory dict is the live connection cache (fast, gRPC-queue holder).
- The DB (via get_db_session) is the source of truth for node identity across reboots.
- On every register/deregister/heartbeat, we update the existing AgentNode DB record
  (nodes are pre-created by an admin; an unknown node is only logged, never inserted).
- On Hub startup all DB nodes are "offline"; they go "online" when they reconnect.
"""
import threading
import queue
from datetime import datetime
from typing import Dict, Optional, List, Any


# All event types emitted across the system, mapped to the human-readable
# label rendered in the live UI. Keys are the wire-level `event` values;
# an event type missing from this table is shown with its raw key as label.
# NOTE: the emoji below must be stored as real UTF-8 characters — several were
# previously mis-decoded into mojibake (e.g. "đŸŸĸ") and have been restored.
EVENT_TYPES = {
    "node_online":    "🟢 Node Online",
    "node_offline":   "⚫ Node Offline",
    "node_stale":     "🟡 Node Stale",
    "heartbeat":      "💓 Heartbeat",
    "task_assigned":  "📦 Task Assigned",
    "task_start":     "🚀 Task Starting",
    "task_stdout":    "📤 Output",
    "task_done":      "✅ Task Done",
    "task_error":     "❌ Task Error",
    "task_cancelled": "🛑 Task Cancelled",
    "task_retry":     "🔄 Retry",
    "sync_manifest":  "📋 Manifest Sync",
    "sync_push":      "📁 File Push",
    "sync_drift":     "⚠️ Drift Detected",
    "sync_recovery":  "🏥 Recovery",
    "sync_locked":    "🔒 Workspace Locked",
    "sync_unlocked":  "🔓 Workspace Unlocked",
    "info":           "ℹ️ Info",
}


class LiveNodeRecord:
    """
    In-memory state for one connected Agent Node.

    Holds the node's identity, the metadata it announced, an outbound message
    queue, the latest stats it reported, and connection/heartbeat timestamps
    (naive UTC, as produced by ``datetime.utcnow()``).
    """

    # Seconds without a heartbeat after which the node is reported as "stale".
    STALE_AFTER_SECONDS = 30

    def __init__(self, node_id: str, user_id: str, metadata: dict):
        """
        Args:
            node_id:  Identifier of the node.
            user_id:  Owner — maps the node to a Hub user.
            metadata: Dict read for the keys "desc" (description) and "caps"
                      (capabilities dict). ``None`` is treated as empty.
        """
        self.node_id = node_id
        self.user_id = user_id
        # Tolerate a missing metadata payload so to_dict() cannot fail on .get().
        self.metadata = metadata if metadata is not None else {}
        self.queue: queue.Queue = queue.Queue()   # gRPC outbound message queue
        self.stats: dict = {
            "active_worker_count": 0,
            "cpu_usage_percent": 0.0,
            "memory_usage_percent": 0.0,
            "running": [],
        }
        # One clock read so both timestamps agree at construction time.
        now = datetime.utcnow()
        self.connected_at: datetime = now
        self.last_heartbeat_at: datetime = now

    def update_stats(self, stats: dict):
        """Merge ``stats`` into the stored stats and record a heartbeat now."""
        self.stats.update(stats)
        self.last_heartbeat_at = datetime.utcnow()

    def to_dict(self) -> dict:
        """
        Return a snapshot of this record as a plain dict.

        ``stats`` is a shallow copy: update_stats() mutates ``self.stats`` in
        place, so handing out the live dict would let a later heartbeat change
        (or resize) a snapshot that a consumer is still reading.
        """
        return {
            "node_id": self.node_id,
            "user_id": self.user_id,
            "description": self.metadata.get("desc", ""),
            "capabilities": self.metadata.get("caps", {}),
            "stats": dict(self.stats),
            "connected_at": self.connected_at.isoformat(),
            "last_heartbeat_at": self.last_heartbeat_at.isoformat(),
            "status": self._compute_status(),
        }

    def _compute_status(self) -> str:
        """Return "stale" once the last heartbeat is older than the threshold, else "online"."""
        silent_for = (datetime.utcnow() - self.last_heartbeat_at).total_seconds()
        if silent_for > self.STALE_AFTER_SECONDS:
            return "stale"
        return "online"


class NodeRegistryService:
    """
    Persistent + in-memory registry of live Agent Nodes.

    Two-tier storage:
      Tier 1 — In-memory (_nodes dict):  live connections, gRPC queues, real-time stats.
      Tier 2 — Database (AgentNode model): node identity, capabilities, invite_token,
                                           last_status, last_seen_at survive Hub restarts.

    When a node reconnects after a Hub restart it calls SyncConfiguration again,
    which calls register() again → the DB record is updated to 'online'.
    Nodes that haven't reconnected stay in the DB with 'offline' status so the UI
    can still show them as known (but disconnected) nodes.

    Locking: ``_lock`` is a plain (non-reentrant) ``threading.Lock`` guarding the
    three dicts below. It is never held while touching the DB or while pushing
    events into subscriber queues.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._nodes: Dict[str, "LiveNodeRecord"] = {}
        # Per-node WS subscribers:  node_id  -> [queue, ...]
        self._node_listeners: Dict[str, List[queue.Queue]] = {}
        # Per-user WS subscribers:  user_id  -> [queue, ...] (ALL nodes for that user)
        self._user_listeners: Dict[str, List[queue.Queue]] = {}

    # ------------------------------------------------------------------ #
    #  DB Helpers                                                          #
    # ------------------------------------------------------------------ #

    def _db_apply(self, node_id: str, failure_label: str, mutate, on_missing=None):
        """
        Look up the AgentNode row for ``node_id`` and apply ``mutate(record)`` to it.

        Best-effort: any failure — including a failure to import the DB layer —
        is printed as "DB <failure_label> failed" and swallowed, so persistence
        problems never break the live connection path.

        Args:
            node_id:       Node whose row is looked up by ``AgentNode.node_id``.
            failure_label: Short action name used in the failure log line.
            mutate:        Callable receiving the found record; sets its fields.
            on_missing:    Optional zero-arg callable invoked when no row exists.

        NOTE(review): this assumes get_db_session() commits on context exit —
        not visible from this file; confirm against app.db.session.
        """
        try:
            # Imported lazily (as before) and inside the try so an import
            # problem is logged like any other DB failure.
            from app.db.models import AgentNode
            from app.db.session import get_db_session

            with get_db_session() as db:
                record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first()
                if record:
                    mutate(record)
                elif on_missing is not None:
                    on_missing()
        except Exception as e:
            print(f"[NodeRegistry] DB {failure_label} failed for {node_id}: {e}")

    def _db_upsert_node(self, node_id: str, user_id: str, metadata: dict):
        """
        Refresh the AgentNode DB record on connect (capabilities, 'online', last_seen_at).

        Despite the name this never inserts: the admin must pre-create the node.
        An unknown node only produces a warning. ``user_id`` is accepted for
        signature compatibility and is not written.
        """
        def mark_online(record):
            record.capabilities = metadata.get("caps", {})
            record.last_status = "online"
            record.last_seen_at = datetime.utcnow()

        def warn_unregistered():
            # Node not pre-registered by admin — log warning but don't crash
            print(f"[NodeRegistry] WARNING: Node '{node_id}' connected but has no DB record. Admin must register it first.")

        self._db_apply(node_id, "upsert", mark_online, on_missing=warn_unregistered)

    def _db_mark_offline(self, node_id: str):
        """Update last_seen_at and mark last_status = 'offline' on disconnect."""
        def mark_offline(record):
            record.last_status = "offline"
            record.last_seen_at = datetime.utcnow()

        self._db_apply(node_id, "mark-offline", mark_offline)

    def _db_update_heartbeat(self, node_id: str):
        """Bump last_seen_at on each heartbeat so we know when the node was last active."""
        def touch(record):
            record.last_seen_at = datetime.utcnow()
            record.last_status = "online"

        self._db_apply(node_id, "heartbeat update", touch)

    # ------------------------------------------------------------------ #
    #  Registration                                                        #
    # ------------------------------------------------------------------ #

    def register(self, node_id: str, user_id: str, metadata: dict) -> "LiveNodeRecord":
        """
        Register or re-register a node.

        Replaces any existing live record for ``node_id`` with a fresh one
        (new queue, reset stats), updates the DB row, and emits "node_online".

        Returns:
            The newly created LiveNodeRecord.
        """
        record = LiveNodeRecord(node_id=node_id, user_id=user_id, metadata=metadata)
        with self._lock:
            self._nodes[node_id] = record

        # Persist outside the lock so a slow DB never blocks other registry calls.
        self._db_upsert_node(node_id, user_id, metadata)

        print(f"[📋] NodeRegistry: Registered {node_id} (owner: {user_id})")
        self.emit(node_id, "node_online", record.to_dict())
        return record

    def deregister(self, node_id: str):
        """
        Remove a node from the live registry and announce "node_offline".

        The DB record is kept with last_status='offline' so the user can
        still see the node in their list (as disconnected). Safe to call for
        a node that is not registered (the event then carries no owner).
        """
        with self._lock:
            node = self._nodes.pop(node_id, None)
            user_id = node.user_id if node else None

        self._db_mark_offline(node_id)
        # The node is already gone from _nodes, so the owner cannot be looked
        # up any more — pass the one captured above, otherwise per-user
        # subscribers would never hear that their node went offline.
        self._dispatch(
            node_id,
            "node_offline",
            {"node_id": node_id, "user_id": user_id},
            "",
            owner_id=user_id or "",
        )
        print(f"[📋] NodeRegistry: Deregistered {node_id}")

    # ------------------------------------------------------------------ #
    #  Queries                                                             #
    # ------------------------------------------------------------------ #

    def get_node(self, node_id: str) -> Optional["LiveNodeRecord"]:
        """Returns a live record only if the node is currently connected."""
        with self._lock:
            return self._nodes.get(node_id)

    def list_nodes(self, user_id: Optional[str] = None) -> List["LiveNodeRecord"]:
        """
        Returns only currently LIVE nodes (use the DB for the full list).

        A falsy ``user_id`` (None or "") means "all owners".
        """
        with self._lock:
            if user_id:
                return [n for n in self._nodes.values() if n.user_id == user_id]
            return list(self._nodes.values())

    def get_best(self, user_id: Optional[str] = None) -> Optional[str]:
        """
        Pick the least-loaded live node for a given owner.

        Load is ``stats["active_worker_count"]`` (999 when missing, so nodes
        without stats lose). Ties go to the earliest-registered node.

        Returns:
            The chosen node_id, or None when no live node matches.
        """
        nodes = self.list_nodes(user_id=user_id)
        if not nodes:
            return None
        return min(nodes, key=lambda n: n.stats.get("active_worker_count", 999)).node_id

    def update_stats(self, node_id: str, stats: dict):
        """
        Called every heartbeat interval. Updates in-memory stats and bumps DB last_seen.

        The DB bump and the "heartbeat" event happen even when ``node_id`` has
        no live record.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if node:
                node.update_stats(stats)
        # Persist heartbeat timestamp to DB (throttle: already ~10s cadence from node)
        self._db_update_heartbeat(node_id)
        # Emit heartbeat event to live UI
        self.emit(node_id, "heartbeat", stats)

    # ------------------------------------------------------------------ #
    #  Event Bus                                                          #
    # ------------------------------------------------------------------ #

    def emit(self, node_id: str, event_type: str, data: Any = None, task_id: str = ""):
        """
        Emit a rich structured execution event.

        Delivered to:
          - Per-node WS subscribers → powers the single-node execution pane
          - Per-user WS subscribers → powers the global multi-node execution bus

        The owner is resolved from the live registry; for a node that is not
        currently registered only per-node subscribers are reached.
        """
        self._dispatch(node_id, event_type, data, task_id)

    def _dispatch(self, node_id: str, event_type: str, data: Any, task_id: str,
                  owner_id: Optional[str] = None):
        """
        Build one event dict and push it to every interested subscriber queue.

        Args:
            owner_id: User whose global bus should receive the event. ``None``
                      means "look it up from the live record"; "" means no owner.
        """
        with self._lock:
            if owner_id is None:
                node = self._nodes.get(node_id)
                owner_id = node.user_id if node else ""
            # Copy the lists so delivery happens outside the lock.
            node_qs = list(self._node_listeners.get(node_id, []))
            user_qs = list(self._user_listeners.get(owner_id, [])) if owner_id else []

        event = {
            "event": event_type,
            "label": EVENT_TYPES.get(event_type, event_type),
            "node_id": node_id,
            "user_id": owner_id,
            "task_id": task_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": data,
        }
        # A queue subscribed both per-node and per-user gets the event once.
        delivered = set()
        for q in node_qs + user_qs:
            if id(q) in delivered:
                continue
            delivered.add(id(q))
            try:
                q.put_nowait(event)
            except queue.Full:
                # Slow consumer with a bounded queue: drop rather than block the emitter.
                pass
            except Exception as exc:
                # Delivery stays best-effort, but unexpected failures are no longer invisible.
                print(f"[NodeRegistry] Failed to deliver '{event_type}' event for {node_id}: {exc}")

    # ------------------------------------------------------------------ #
    #  WS Subscriptions                                                   #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _drop_listener(table: Dict[str, List[queue.Queue]], key: str, q: queue.Queue):
        """Remove ``q`` from ``table[key]`` and delete the key once its list is empty.

        Caller must hold ``_lock``.
        """
        listeners = table.get(key)
        if listeners is None:
            return
        if q in listeners:
            listeners.remove(q)
        if not listeners:
            # Avoid accumulating one empty list per node/user ever subscribed to.
            del table[key]

    def subscribe_node(self, node_id: str, q: queue.Queue):
        """Subscribe to execution events for a specific node (single-pane view)."""
        with self._lock:
            self._node_listeners.setdefault(node_id, []).append(q)

    def unsubscribe_node(self, node_id: str, q: queue.Queue):
        """Stop delivering ``node_id`` events to ``q``; a no-op if it was not subscribed."""
        with self._lock:
            self._drop_listener(self._node_listeners, node_id, q)

    def subscribe_user(self, user_id: str, q: queue.Queue):
        """Subscribe to ALL node events for a user (multi-pane global bus)."""
        with self._lock:
            self._user_listeners.setdefault(user_id, []).append(q)

    def unsubscribe_user(self, user_id: str, q: queue.Queue):
        """Stop delivering ``user_id``'s events to ``q``; a no-op if it was not subscribed."""
        with self._lock:
            self._drop_listener(self._user_listeners, user_id, q)