diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py
index 161df72..efa4a46 100644
--- a/ai-hub/app/api/routes/api.py
+++ b/ai-hub/app/api/routes/api.py
@@ -8,6 +8,7 @@
 from .general import create_general_router
 from .stt import create_stt_router
 from .user import create_users_router
+from .nodes import create_nodes_router
 
 def create_api_router(services: ServiceContainer) -> APIRouter:
     """
@@ -23,5 +24,6 @@
     router.include_router(create_tts_router(services))
     router.include_router(create_stt_router(services))
     router.include_router(create_users_router(services))
+    router.include_router(create_nodes_router(services))
 
     return router
\ No newline at end of file
diff --git a/ai-hub/app/api/routes/nodes.py b/ai-hub/app/api/routes/nodes.py
new file mode 100644
index 0000000..709efaf
--- /dev/null
+++ b/ai-hub/app/api/routes/nodes.py
@@ -0,0 +1,274 @@
+"""
+Agent Node REST + WebSocket API
+Exposes the live node registry and execution event bus to the AI Hub UI.
+
+Endpoints:
+    GET  /nodes                     — List all nodes for the given user_id
+    GET  /nodes/{node_id}           — Full live node status
+    GET  /nodes/{node_id}/status    — Quick online/offline probe
+    POST /nodes/{node_id}/dispatch  — Dispatch a task to a node
+    WS   /nodes/{node_id}/stream    — Live event stream for ONE node
+    WS   /nodes/stream/all?user_id=... — Live event stream for ALL user's nodes
+"""
+import asyncio
+import json
+import queue
+import uuid
+import logging
+from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Depends
+from sqlalchemy.orm import Session
+
+from app.api.dependencies import ServiceContainer, get_db
+from app.api import schemas
+from app.db import models
+
+logger = logging.getLogger(__name__)
+
+HEARTBEAT_INTERVAL_S = 5  # How often to push a periodic heartbeat to WS clients
+
+
+def create_nodes_router(services: ServiceContainer) -> APIRouter:
+    """Build the /nodes router backed by the container's NodeRegistryService."""
+    router = APIRouter(prefix="/nodes", tags=["Agent Nodes"])
+
+    def _registry():
+        # Resolved lazily per call so the container is fully wired before use.
+        return services.node_registry_service
+
+    # ------------------------------------------------------------------ #
+    # GET /nodes — list all nodes for a user                              #
+    # ------------------------------------------------------------------ #
+    @router.get("/", response_model=list[schemas.AgentNodeSummary], summary="List Agent Nodes")
+    def list_nodes(user_id: str, db: Session = Depends(get_db)):
+        """
+        Returns all agent nodes ever registered under a given user.
+        Merges live connection state (in-memory) with the persistent DB record.
+        """
+        registry = _registry()
+        db_nodes = db.query(models.AgentNode).filter(
+            models.AgentNode.user_id == user_id
+        ).all()
+
+        result = []
+        for db_node in db_nodes:
+            live = registry.get_node(db_node.node_id)
+            # A node absent from the live registry is, by definition, offline.
+            status = live._compute_status() if live else "offline"
+            last_seen = live.last_heartbeat_at if live else db_node.last_seen_at
+            result.append(schemas.AgentNodeSummary(
+                node_id=db_node.node_id,
+                user_id=db_node.user_id,
+                description=db_node.description,
+                capabilities=db_node.capabilities or {},
+                status=status,
+                last_seen_at=last_seen,
+                created_at=db_node.created_at,
+            ))
+        return result
+
+    # ------------------------------------------------------------------ #
+    # GET /nodes/{node_id} — full live status                             #
+    # ------------------------------------------------------------------ #
+    @router.get("/{node_id}", response_model=schemas.AgentNodeStatusResponse, summary="Get Node Live Status")
+    def get_node_status(node_id: str, db: Session = Depends(get_db)):
+        """
+        Full status for one node: live registry record when connected,
+        otherwise the persisted DB record marked 'offline'.
+        Raises 404 when the node is unknown to both.
+        """
+        registry = _registry()
+        live = registry.get_node(node_id)
+
+        if live:
+            d = live.to_dict()
+            return schemas.AgentNodeStatusResponse(
+                node_id=d["node_id"], user_id=d["user_id"],
+                description=d["description"], capabilities=d["capabilities"],
+                stats=schemas.AgentNodeStats(**d["stats"]),
+                status=d["status"],
+                connected_at=d["connected_at"],
+                last_heartbeat_at=d["last_heartbeat_at"],
+            )
+
+        db_node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first()
+        if not db_node:
+            raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.")
+        # FIX: previously the offline branch dropped the DB's last_seen_at, so
+        # disconnected nodes always reported last_heartbeat_at=None.
+        return schemas.AgentNodeStatusResponse(
+            node_id=db_node.node_id, user_id=db_node.user_id,
+            description=db_node.description, capabilities=db_node.capabilities or {},
+            stats=schemas.AgentNodeStats(), status="offline",
+            last_heartbeat_at=db_node.last_seen_at.isoformat() if db_node.last_seen_at else None,
+        )
+
+    # ------------------------------------------------------------------ #
+    # GET /nodes/{node_id}/status — quick probe                           #
+    # ------------------------------------------------------------------ #
+    @router.get("/{node_id}/status", summary="Quick Node Online Check")
+    def get_node_online_status(node_id: str):
+        """Fast in-memory probe — no DB access. 'offline' means not connected."""
+        live = _registry().get_node(node_id)
+        if not live:
+            return {"node_id": node_id, "status": "offline"}
+        return {"node_id": node_id, "status": live._compute_status(), "stats": live.stats}
+
+    # ------------------------------------------------------------------ #
+    # POST /nodes/{node_id}/dispatch — send a task                        #
+    # ------------------------------------------------------------------ #
+    @router.post("/{node_id}/dispatch", response_model=schemas.NodeDispatchResponse, summary="Dispatch Task to Node")
+    def dispatch_to_node(node_id: str, request: schemas.NodeDispatchRequest):
+        """
+        Queues a task for an online node via its gRPC outbound queue.
+        Emits task_assigned event immediately for live UI feedback.
+
+        Raises:
+            503 — node is not currently connected.
+            400 — neither 'command' nor 'browser_action' was supplied.
+        """
+        registry = _registry()
+        live = registry.get_node(node_id)
+        if not live:
+            raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.")
+
+        # FIX: with command="" and browser_action=None (the schema defaults),
+        # the payload used to become json.dumps(None) == "null" — an
+        # unexecutable task. Reject empty dispatches up front instead.
+        if not request.command and not request.browser_action:
+            raise HTTPException(status_code=400,
+                                detail="Either 'command' or 'browser_action' is required.")
+
+        task_id = str(uuid.uuid4())
+
+        # Emit to live UI immediately
+        registry.emit(node_id, "task_assigned",
+                      {"command": request.command, "session_id": request.session_id},
+                      task_id=task_id)
+
+        try:
+            # HACK: gRPC agent protos live outside this package. Keep the path
+            # insertion idempotent so repeated dispatches don't grow sys.path.
+            import sys
+            if "/app/poc-grpc-agent" not in sys.path:
+                sys.path.insert(0, "/app/poc-grpc-agent")
+            from protos import agent_pb2
+            from orchestrator.utils.crypto import sign_payload
+
+            payload = request.command or json.dumps(request.browser_action)
+            sig = sign_payload(payload)
+            task_req = agent_pb2.TaskRequest(
+                task_id=task_id,
+                payload_json=payload,
+                signature=sig,
+                timeout_ms=request.timeout_ms,
+                session_id=request.session_id or "",
+            )
+            live.queue.put(agent_pb2.ServerTaskMessage(task_request=task_req))
+            registry.emit(node_id, "task_start", {"command": request.command}, task_id=task_id)
+            logger.info(f"[nodes] Dispatched task {task_id} to {node_id}")
+        except ImportError as e:
+            # Deliberate best-effort: without the agent package the dispatch is
+            # acknowledged but only the UI events are produced.
+            logger.warning(f"[nodes] poc-grpc-agent not installed: {e}. Dispatch is stub only.")
+
+        return schemas.NodeDispatchResponse(task_id=task_id, status="accepted")
+
+    # ------------------------------------------------------------------ #
+    # WS /nodes/{node_id}/stream — single node live event stream          #
+    # ------------------------------------------------------------------ #
+    @router.websocket("/{node_id}/stream")
+    async def node_event_stream(websocket: WebSocket, node_id: str):
+        """
+        WebSocket stream for a single node's execution events.
+        Powers the single-node execution pane in the UI.
+
+        Message format:
+        {
+            "event": "task_stdout",
+            "label": "📤 Output",
+            "node_id": "node-alpha",
+            "task_id": "abc-123",
+            "timestamp": "2026-03-04T06:00:00",
+            "data": { ... }
+        }
+        """
+        from datetime import datetime  # local: module has no top-level datetime import
+
+        await websocket.accept()
+        registry = _registry()
+
+        # Push current snapshot immediately
+        live = registry.get_node(node_id)
+        await websocket.send_json({
+            "event": "snapshot",
+            "node_id": node_id,
+            "timestamp": datetime.utcnow().isoformat(),
+            "data": live.to_dict() if live else {"status": "offline"},
+        })
+
+        q: queue.Queue = queue.Queue()
+        registry.subscribe_node(node_id, q)
+        loop = asyncio.get_running_loop()
+        last_beat = loop.time()
+        try:
+            while True:
+                # FIX: the previous drain-then-sleep(HEARTBEAT_INTERVAL_S) loop
+                # delayed freshly queued events by up to 5 seconds. Poll every
+                # 100ms and gate the heartbeat on elapsed time instead.
+                try:
+                    while True:
+                        await websocket.send_json(q.get_nowait())
+                except queue.Empty:
+                    pass
+
+                await asyncio.sleep(0.1)
+
+                # Periodic heartbeat
+                if loop.time() - last_beat >= HEARTBEAT_INTERVAL_S:
+                    last_beat = loop.time()
+                    live = registry.get_node(node_id)
+                    await websocket.send_json({
+                        "event": "heartbeat",
+                        "node_id": node_id,
+                        "timestamp": datetime.utcnow().isoformat(),
+                        "data": {"status": live._compute_status() if live else "offline",
+                                 "stats": live.stats if live else {}},
+                    })
+        except WebSocketDisconnect:
+            pass
+        except Exception as e:
+            logger.error(f"[nodes/stream] {node_id}: {e}")
+        finally:
+            registry.unsubscribe_node(node_id, q)
+
+    # ------------------------------------------------------------------ #
+    # WS /nodes/stream/all — multi-node global execution bus              #
+    # ------------------------------------------------------------------ #
+    @router.websocket("/stream/all")
+    async def all_nodes_event_stream(websocket: WebSocket, user_id: str):
+        """
+        WebSocket stream for ALL of a user's node execution events combined.
+        Powers the multi-pane split-window execution UI.
+
+        The client receives events from every node the user owns,
+        disambiguated by the 'node_id' field in each event.
+
+        Use this to render:
+        - Per-node status columns (split by node_id)
+        - A unified chronological execution log
+        - Error/retry surface for all nodes simultaneously
+        """
+        from datetime import datetime  # local: module has no top-level datetime import
+
+        await websocket.accept()
+        registry = _registry()
+
+        # Send initial snapshot of all user's live nodes
+        all_nodes = registry.list_nodes(user_id=user_id)
+        await websocket.send_json({
+            "event": "initial_snapshot",
+            "user_id": user_id,
+            "timestamp": datetime.utcnow().isoformat(),
+            "data": {
+                "nodes": [n.to_dict() for n in all_nodes],
+                "count": len(all_nodes),
+            },
+        })
+
+        q: queue.Queue = queue.Queue()
+        registry.subscribe_user(user_id, q)
+        loop = asyncio.get_running_loop()
+        last_beat = loop.time()
+        try:
+            while True:
+                # FIX: events used to wait up to HEARTBEAT_INTERVAL_S between
+                # drains; poll at 100ms for low-latency forwarding and gate the
+                # mesh heartbeat on elapsed time.
+                try:
+                    while True:
+                        await websocket.send_json(q.get_nowait())
+                except queue.Empty:
+                    pass
+
+                await asyncio.sleep(0.1)
+
+                if loop.time() - last_beat >= HEARTBEAT_INTERVAL_S:
+                    last_beat = loop.time()
+                    # Push a periodic mesh health summary
+                    live_nodes = registry.list_nodes(user_id=user_id)
+                    await websocket.send_json({
+                        "event": "mesh_heartbeat",
+                        "user_id": user_id,
+                        "timestamp": datetime.utcnow().isoformat(),
+                        "data": {
+                            "nodes": [
+                                {"node_id": n.node_id, "status": n._compute_status(), "stats": n.stats}
+                                for n in live_nodes
+                            ]
+                        },
+                    })
+        except WebSocketDisconnect:
+            pass
+        except Exception as e:
+            logger.error(f"[nodes/stream/all] user={user_id}: {e}")
+        finally:
+            registry.unsubscribe_user(user_id, q)
+
+    return router
diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py
index ba307c1..7ab79a0 100644
--- a/ai-hub/app/api/schemas.py
+++ b/ai-hub/app/api/schemas.py
@@ -203,3 +203,49 @@
     model_name: str
     max_tokens: Optional[int] = None
     max_input_tokens: Optional[int] = None
+
+
+# --- Agent Node Schemas ---
+
+class AgentNodeStats(BaseModel):
+    """Live performance stats reported via heartbeat."""
+    active_worker_count: int = 0
+    cpu_usage_percent: float = 0.0
+    memory_usage_percent: float = 0.0
+    # NOTE(review): presumably identifiers of tasks currently running on the
+    # node — confirm against the node agent's heartbeat payload.
+    running: List[str] = []
+
+class AgentNodeStatusResponse(BaseModel):
+    """Full live status of an agent node."""
+    node_id: str
+    user_id: str
+    description: Optional[str] = None
+    capabilities: dict = {}
+    stats: AgentNodeStats = AgentNodeStats()
+    status: str  # 'online' | 'offline' | 'stale'
+    # ISO-8601 strings — LiveNodeRecord.to_dict() serializes datetimes via isoformat().
+    connected_at: Optional[str] = None
+    last_heartbeat_at: Optional[str] = None
+    model_config = ConfigDict(from_attributes=True)
+
+class AgentNodeSummary(BaseModel):
+    """Lightweight node record for list views (merged DB + live state)."""
+    node_id: str
+    user_id: str
+    description: Optional[str] = None
+    capabilities: dict = {}
+    status: str  # 'online' | 'offline' | 'stale'
+    last_seen_at: Optional[datetime] = None
+    created_at: Optional[datetime] = None
+    model_config = ConfigDict(from_attributes=True)
+
+class NodeDispatchRequest(BaseModel):
+    """Dispatch a shell or browser action to a specific node."""
+    command: str = ""  # Shell command (mutually exclusive with browser_action)
+    browser_action: Optional[dict] = None  # BrowserAction payload
+    session_id: Optional[str] = None  # Workspace session context
+    timeout_ms: int = 30000
+
+class NodeDispatchResponse(BaseModel):
+    """Ack for a dispatch request; 'reason' is populated on rejection."""
+    task_id: str
+    status: str  # 'accepted' | 'rejected'
+    reason: Optional[str] = None
+
diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py
index 99c775b..3d2c7eb 100644
--- a/ai-hub/app/app.py
+++ b/ai-hub/app/app.py
@@ -24,6 +24,7 @@
 from app.core.services.user import UserService
 from app.core.services.prompt import PromptService
 from app.core.services.tool import ToolService
+from app.core.services.node_registry import NodeRegistryService
 # Note: The llm_clients import and initialization are removed as they
 # are not used in RAGService's constructor based on your services.py
 # from app.core.llm_clients import DeepSeekClient, GeminiClient
@@ -127,6 +128,7 @@
     services.with_service("session_service", service=SessionService())
     services.with_service("user_service", service=UserService())
     services.with_service("tool_service", service=ToolService())
+    services.with_service("node_registry_service", service=NodeRegistryService())
 
     # Create and include the API router, injecting the service
     api_router = create_api_router(services=services)
diff --git a/ai-hub/app/core/services/node_registry.py b/ai-hub/app/core/services/node_registry.py
new file mode 100644
index 0000000..191776c
--- /dev/null
+++ b/ai-hub/app/core/services/node_registry.py
@@ -0,0 +1,206 @@
+"""
+NodeRegistry Service — AI Hub Integration Layer
+Manages live Agent Node registrations, their in-memory gRPC queues,
+and a rich event bus for real-time UI streaming (split-pane terminal UX).
+"""
+import threading
+import queue
+from datetime import datetime
+from typing import Dict, Optional, List, Any
+
+
+# All event types emitted across the system — rendered in the live UI.
+# Maps event name → human-readable label shown in the execution pane;
+# emit() falls back to the raw event name for unknown types.
+EVENT_TYPES = {
+    "node_online": "🟢 Node Online",
+    "node_offline": "⚫ Node Offline",
+    "node_stale": "🟡 Node Stale",
+    "heartbeat": "💓 Heartbeat",
+    "task_assigned": "📦 Task Assigned",
+    "task_start": "🚀 Task Starting",
+    "task_stdout": "📤 Output",
+    "task_done": "✅ Task Done",
+    "task_error": "❌ Task Error",
+    "task_cancelled": "🛑 Task Cancelled",
+    "task_retry": "🔄 Retry",
+    "sync_manifest": "📋 Manifest Sync",
+    "sync_push": "📁 File Push",
+    "sync_drift": "⚠️ Drift Detected",
+    "sync_recovery": "🏥 Recovery",
+    "sync_locked": "🔒 Workspace Locked",
+    "sync_unlocked": "🔓 Workspace Unlocked",
+    "info": "ℹ️ Info",
+}
+
+
+class LiveNodeRecord:
+    """Represents a single connected Agent Node and its associated state."""
+    def __init__(self, node_id: str, user_id: str, metadata: dict):
+        self.node_id = node_id
+        self.user_id = user_id  # Owner — maps node to a Hub user
+        self.metadata = metadata  # desc, caps (capabilities dict)
+        self.queue: queue.Queue = queue.Queue()  # gRPC outbound message queue
+        # Latest heartbeat stats; keys mirror schemas.AgentNodeStats.
+        self.stats: dict = {
+            "active_worker_count": 0,
+            "cpu_usage_percent": 0.0,
+            "memory_usage_percent": 0.0,
+            "running": [],
+        }
+        self.connected_at: datetime = datetime.utcnow()
+        self.last_heartbeat_at: datetime = datetime.utcnow()
+
+    def update_stats(self, stats: dict):
+        # Merge (not replace) so keys missing from this heartbeat keep
+        # their previous values; every update refreshes the liveness clock.
+        self.stats.update(stats)
+        self.last_heartbeat_at = datetime.utcnow()
+
+    def to_dict(self) -> dict:
+        """JSON-serializable snapshot (datetimes rendered as ISO-8601 strings)."""
+        return {
+            "node_id": self.node_id,
+            "user_id": self.user_id,
+            "description": self.metadata.get("desc", ""),
+            "capabilities": self.metadata.get("caps", {}),
+            "stats": self.stats,
+            "connected_at": self.connected_at.isoformat(),
+            "last_heartbeat_at": self.last_heartbeat_at.isoformat(),
+            "status": self._compute_status(),
+        }
+
+    def _compute_status(self) -> str:
+        # 'offline' is represented by absence from the registry — a record
+        # that still exists is either 'online' or (no heartbeat > 30s) 'stale'.
+        delta = (datetime.utcnow() - self.last_heartbeat_at).total_seconds()
+        if delta > 30:
+            return "stale"
+        return "online"
+
+
+class NodeRegistryService:
+    """
+    In-memory registry of live Agent Nodes, integrated into the FastAPI
+    ServiceContainer.
+
+    Provides:
+    - Live node registration / deregistration
+    - gRPC outbound queue per node
+    - Rich event bus for real-time UI streaming:
+        * Per-node stream → single node execution pane
+        * Per-user stream → all-nodes global execution bus (multi-pane view)
+    """
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._nodes: Dict[str, LiveNodeRecord] = {}
+        # Per-node WS subscribers: node_id -> [queue, ...]
+        self._node_listeners: Dict[str, List[queue.Queue]] = {}
+        # Per-user WS subscribers: user_id -> [queue, ...]
+        # (ALL nodes for that user)
+        self._user_listeners: Dict[str, List[queue.Queue]] = {}
+
+    # ------------------------------------------------------------------ #
+    # Registration                                                        #
+    # ------------------------------------------------------------------ #
+
+    def register(self, node_id: str, user_id: str, metadata: dict) -> LiveNodeRecord:
+        """Register or re-register a node (called from gRPC SyncConfiguration)."""
+        with self._lock:
+            record = LiveNodeRecord(node_id=node_id, user_id=user_id, metadata=metadata)
+            self._nodes[node_id] = record
+        # FIX: emit() acquires self._lock itself and threading.Lock is NOT
+        # reentrant — calling it while holding the lock deadlocks the gRPC
+        # registration path. Emit strictly after releasing the lock.
+        print(f"[📋] NodeRegistry: Registered {node_id} (owner: {user_id})")
+        self.emit(node_id, "node_online", record.to_dict())
+        return record
+
+    def deregister(self, node_id: str):
+        """Remove a node when its gRPC stream closes (called from TaskStream finally)."""
+        with self._lock:
+            node = self._nodes.pop(node_id, None)
+        user_id = node.user_id if node else None
+        # Emit outside the lock (emit() takes self._lock internally — see register()).
+        self.emit(node_id, "node_offline", {"node_id": node_id, "user_id": user_id})
+        print(f"[📋] NodeRegistry: Deregistered {node_id}")
+
+    # ------------------------------------------------------------------ #
+    # Queries                                                             #
+    # ------------------------------------------------------------------ #
+
+    def get_node(self, node_id: str) -> Optional[LiveNodeRecord]:
+        """Return the live record for node_id, or None if not connected."""
+        with self._lock:
+            return self._nodes.get(node_id)
+
+    def list_nodes(self, user_id: Optional[str] = None) -> List[LiveNodeRecord]:
+        """All live nodes, optionally filtered to a single owner."""
+        with self._lock:
+            if user_id:
+                return [n for n in self._nodes.values() if n.user_id == user_id]
+            return list(self._nodes.values())
+
+    def get_best(self, user_id: Optional[str] = None) -> Optional[str]:
+        """Pick the least-loaded node for a given owner."""
+        nodes = self.list_nodes(user_id=user_id)
+        if not nodes:
+            return None
+        # min() instead of sorting the whole list just to take the head.
+        return min(nodes, key=lambda n: n.stats.get("active_worker_count", 999)).node_id
+
+    def update_stats(self, node_id: str, stats: dict):
+        """Merge heartbeat stats into the node record and notify the UI."""
+        with self._lock:
+            node = self._nodes.get(node_id)
+            if node:
+                node.update_stats(stats)
+        if node:
+            # Also emit heartbeat event to UI — outside the lock (see register()).
+            self.emit(node_id, "heartbeat", stats)
+
+    # ------------------------------------------------------------------ #
+    # Event Bus                                                           #
+    # ------------------------------------------------------------------ #
+
+    def emit(self, node_id: str, event_type: str, data: Any = None, task_id: str = ""):
+        """
+        Emit a rich structured execution event.
+
+        Delivered to:
+        - Per-node WS subscribers → powers the single-node execution pane
+        - Per-user WS subscribers → powers the global multi-node execution bus
+        """
+        with self._lock:
+            node = self._nodes.get(node_id)
+            user_id = node.user_id if node else ""
+            # Snapshot listener lists under the lock; deliver outside it so a
+            # slow/full subscriber queue can never block the registry.
+            node_qs = list(self._node_listeners.get(node_id, []))
+            user_qs = list(self._user_listeners.get(user_id, [])) if user_id else []
+
+        event = {
+            "event": event_type,
+            "label": EVENT_TYPES.get(event_type, event_type),
+            "node_id": node_id,
+            "user_id": user_id,
+            "task_id": task_id,
+            "timestamp": datetime.utcnow().isoformat(),
+            "data": data,
+        }
+        # Deliver — avoid duplicates if same queue is in both lists
+        seen = set()
+        for q in node_qs + user_qs:
+            if id(q) not in seen:
+                seen.add(id(q))
+                try:
+                    q.put_nowait(event)
+                except Exception:
+                    # Best-effort delivery; a broken subscriber must not
+                    # poison the bus for everyone else.
+                    pass
+
+    # ------------------------------------------------------------------ #
+    # WS Subscriptions                                                    #
+    # ------------------------------------------------------------------ #
+
+    def subscribe_node(self, node_id: str, q: queue.Queue):
+        """Subscribe to execution events for a specific node (single-pane view)."""
+        with self._lock:
+            self._node_listeners.setdefault(node_id, []).append(q)
+
+    def unsubscribe_node(self, node_id: str, q: queue.Queue):
+        """Remove a per-node subscriber; no-op if it was never registered."""
+        with self._lock:
+            lst = self._node_listeners.get(node_id, [])
+            if q in lst:
+                lst.remove(q)
+
+    def subscribe_user(self, user_id: str, q: queue.Queue):
+        """Subscribe to ALL node events for a user (multi-pane global bus)."""
+        with self._lock:
+            self._user_listeners.setdefault(user_id, []).append(q)
+
+    def unsubscribe_user(self, user_id: str, q: queue.Queue):
+        """Remove a per-user subscriber; no-op if it was never registered."""
+        with self._lock:
+            lst = self._user_listeners.get(user_id, [])
+            if q in lst:
+                lst.remove(q)
diff --git a/ai-hub/app/db/migrate.py b/ai-hub/app/db/migrate.py
index 557c8ac..2ae52de 100644
--- a/ai-hub/app/db/migrate.py
+++ b/ai-hub/app/db/migrate.py
@@ -57,7 +57,27 @@
             else:
                 logger.info(f"Column '{col_name}' already exists in 'sessions'.")
 
+        # Agent Nodes table — create if missing (handled by create_all, but add
+        # any new columns that might be added post-creation)
+        if inspector.has_table("agent_nodes"):
+            node_columns = [c["name"] for c in inspector.get_columns("agent_nodes")]
+            node_required_columns = [
+                ("invite_token", "TEXT"),
+                ("last_status", "TEXT"),
+                ("last_seen_at", "DATETIME"),
+                ("capabilities", "TEXT"),
+            ]
+            for col_name, col_type in node_required_columns:
+                if col_name not in node_columns:
+                    logger.info(f"Adding column '{col_name}' to 'agent_nodes' table...")
+                    try:
+                        conn.execute(text(f"ALTER TABLE agent_nodes ADD COLUMN {col_name} {col_type}"))
+                        conn.commit()
+                    except Exception as e:
+                        logger.error(f"Failed to add column '{col_name}': {e}")
+
     logger.info("Database migrations complete.")
 
 if __name__ == "__main__":
     run_migrations()
diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py
index 9f5b439..27f3eaa 100644
--- a/ai-hub/app/db/models.py
+++ b/ai-hub/app/db/models.py
@@ -6,6 +6,7 @@
 # We will import it from there to ensure all models use the same base.
 from .database import Base
 
+
 # --- SQLAlchemy Models ---
 # These classes define the structure of the database tables and how they relate.
@@ -327,3 +328,35 @@
     def __repr__(self):
         return f""
+
+
+class AgentNode(Base):
+    """
+    Persistent record of an Agent Node registered by a user.
+    Stores the node's identity, capabilities, and invite token.
+    The 'live' connection state is managed separately in NodeRegistryService (in-memory).
+    """
+    __tablename__ = 'agent_nodes'
+
+    id = Column(Integer, primary_key=True, index=True)
+    # Human-readable node identifier set in the node's YAML config
+    node_id = Column(String, unique=True, index=True, nullable=False)
+    # Owner of this node
+    user_id = Column(String, ForeignKey('users.id'), nullable=False)
+    # Human description of the node (e.g., "MacBook Pro M3 - Dev Machine")
+    description = Column(String, nullable=True)
+    # JSON of capabilities: {"shell": "v1", "browser": "playwright-sync-bridge"}
+    # FIX: default=dict (a callable), not default={} — a literal {} is one
+    # shared mutable object reused across every row insert.
+    capabilities = Column(JSON, default=dict, nullable=True)
+    # Pre-signed invite token (generated at "Download Your Node" step)
+    invite_token = Column(String, unique=True, nullable=True, index=True)
+    # Last known status: 'online', 'offline', 'stale'
+    last_status = Column(String, default="offline", nullable=False)
+    # Last heartbeat timestamp
+    last_seen_at = Column(DateTime, nullable=True)
+    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
+
+    owner = relationship("User")
+
+    def __repr__(self):
+        # FIX: the repr body was an empty f-string (angle-bracket content lost);
+        # restore a meaningful debug representation.
+        return f"<AgentNode node_id={self.node_id!r} user_id={self.user_id!r} status={self.last_status!r}>"