diff --git a/ai-hub/app/api/routes/nodes.py b/ai-hub/app/api/routes/nodes.py index 709efaf..6fe2319 100644 --- a/ai-hub/app/api/routes/nodes.py +++ b/ai-hub/app/api/routes/nodes.py @@ -1,19 +1,30 @@ """ Agent Node REST + WebSocket API -Exposes the live node registry and execution event bus to the AI Hub UI. +Admin-managed nodes, group access control, and user-facing live streaming. -Endpoints: - GET /nodes — List all nodes for the given user_id - GET /nodes/{node_id} — Full live node status - GET /nodes/{node_id}/status — Quick online/offline probe - POST /nodes/{node_id}/dispatch — Dispatch a task to a node - WS /nodes/{node_id}/stream — Live event stream for ONE node - WS /nodes/stream/all?user_id=... — Live event stream for ALL user's nodes +Admin endpoints (require role=admin): + POST /nodes/ — Create node registration + generate invite_token + GET /nodes/admin — List all nodes (admin view, full detail) + GET /nodes/admin/{node_id} — Full admin detail including invite_token + PATCH /nodes/admin/{node_id} — Update node config (description, skill_config, is_active) + POST /nodes/admin/{node_id}/access — Grant group access to a node + DELETE /nodes/admin/{node_id}/access/{group_id} — Revoke group access + +User endpoints (scoped to caller's group): + GET /nodes/ — List accessible nodes (user view, no sensitive data) + GET /nodes/{node_id}/status — Quick online/offline probe + POST /nodes/{node_id}/dispatch — Dispatch a task to a node + PATCH /nodes/preferences — Update user's default_nodes + data_source prefs + +WebSocket (real-time streaming): + WS /nodes/{node_id}/stream — Single-node live execution stream + WS /nodes/stream/all?user_id=... — All-nodes global bus (multi-pane UI) """ import asyncio import json import queue import uuid +import secrets import logging from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Depends from sqlalchemy.orm import Session @@ -24,7 +35,7 @@ logger = logging.getLogger(__name__) -HEARTBEAT_INTERVAL_S = 5 # How often to push a periodic heartbeat to WS clients +HEARTBEAT_INTERVAL_S = 5 def create_nodes_router(services: ServiceContainer) -> APIRouter: @@ -33,82 +44,185 @@ def _registry(): return services.node_registry_service - # ------------------------------------------------------------------ # - # GET /nodes — list all nodes for a user # - # ------------------------------------------------------------------ # - @router.get("/", response_model=list[schemas.AgentNodeSummary], summary="List Agent Nodes") - def list_nodes(user_id: str, db: Session = Depends(get_db)): + def _require_admin(user_id: str, db: Session): + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Admin access required.") + return user + + # ================================================================== + # ADMIN ENDPOINTS + # ================================================================== + + @router.post("/admin", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Register New Node") + def admin_create_node( + request: schemas.AgentNodeCreate, + admin_id: str, + db: Session = Depends(get_db) + ): """ - Returns all agent nodes ever registered under a given user. - Merges live connection state (in-memory) with the persistent DB record. + Admin registers a new Agent Node. + Returns the node record including a generated invite_token that must be + placed in the node's config YAML before deployment. """ - registry = _registry() - db_nodes = db.query(models.AgentNode).filter( - models.AgentNode.user_id == user_id - ).all() + _require_admin(admin_id, db) - result = [] - for db_node in db_nodes: - live = registry.get_node(db_node.node_id) - status = live._compute_status() if live else "offline" - last_seen = live.last_heartbeat_at if live else db_node.last_seen_at - result.append(schemas.AgentNodeSummary( - node_id=db_node.node_id, - user_id=db_node.user_id, - description=db_node.description, - capabilities=db_node.capabilities or {}, - status=status, - last_seen_at=last_seen, - created_at=db_node.created_at, - )) - return result + existing = db.query(models.AgentNode).filter( + models.AgentNode.node_id == request.node_id + ).first() + if existing: + raise HTTPException(status_code=409, detail=f"Node '{request.node_id}' already exists.") - # ------------------------------------------------------------------ # - # GET /nodes/{node_id} — full live status # - # ------------------------------------------------------------------ # - @router.get("/{node_id}", response_model=schemas.AgentNodeStatusResponse, summary="Get Node Live Status") - def get_node_status(node_id: str, db: Session = Depends(get_db)): - registry = _registry() - live = registry.get_node(node_id) + # Generate a cryptographically secure invite token + invite_token = secrets.token_urlsafe(32) - if live: - d = live.to_dict() - return schemas.AgentNodeStatusResponse( - node_id=d["node_id"], user_id=d["user_id"], - description=d["description"], capabilities=d["capabilities"], - stats=schemas.AgentNodeStats(**d["stats"]), - status=d["status"], - connected_at=d["connected_at"], - last_heartbeat_at=d["last_heartbeat_at"], - ) - - db_node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first() - if not db_node: - raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.") - return schemas.AgentNodeStatusResponse( - node_id=db_node.node_id, user_id=db_node.user_id, - description=db_node.description, capabilities=db_node.capabilities or {}, - stats=schemas.AgentNodeStats(), status="offline", + node = models.AgentNode( + node_id=request.node_id, + display_name=request.display_name, + description=request.description, + registered_by=admin_id, + skill_config=request.skill_config.model_dump(), + invite_token=invite_token, + last_status="offline", ) + db.add(node) + db.commit() + db.refresh(node) - # ------------------------------------------------------------------ # - # GET /nodes/{node_id}/status — quick probe # - # ------------------------------------------------------------------ # + logger.info(f"[admin] Created node '{request.node_id}' by admin {admin_id}") + return _node_to_admin_detail(node, _registry()) + + @router.get("/admin", response_model=list[schemas.AgentNodeAdminDetail], summary="[Admin] List All Nodes") + def admin_list_nodes(admin_id: str, db: Session = Depends(get_db)): + """Full node list for admin dashboard, including invite_token and skill config.""" + _require_admin(admin_id, db) + nodes = db.query(models.AgentNode).all() + return [_node_to_admin_detail(n, _registry()) for n in nodes] + + @router.get("/admin/{node_id}", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Get Node Detail") + def admin_get_node(node_id: str, admin_id: str, db: Session = Depends(get_db)): + _require_admin(admin_id, db) + node = _get_node_or_404(node_id, db) + return _node_to_admin_detail(node, _registry()) + + @router.patch("/admin/{node_id}", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Update Node Config") + def admin_update_node( + node_id: str, + update: schemas.AgentNodeUpdate, + admin_id: str, + db: Session = Depends(get_db) + ): + """Update display_name, description, skill_config toggles, or is_active.""" + _require_admin(admin_id, db) + node = _get_node_or_404(node_id, db) + + if update.display_name is not None: + node.display_name = update.display_name + if update.description is not None: + node.description = update.description + if update.skill_config is not None: + node.skill_config = update.skill_config.model_dump() + if update.is_active is not None: + node.is_active = update.is_active + + db.commit() + db.refresh(node) + return _node_to_admin_detail(node, _registry()) + + @router.post("/admin/{node_id}/access", response_model=schemas.NodeAccessResponse, summary="[Admin] Grant Group Access") + def admin_grant_access( + node_id: str, + grant: schemas.NodeAccessGrant, + admin_id: str, + db: Session = Depends(get_db) + ): + """Grant a group access to use this node in sessions.""" + _require_admin(admin_id, db) + node = _get_node_or_404(node_id, db) + + existing = db.query(models.NodeGroupAccess).filter( + models.NodeGroupAccess.node_id == node_id, + models.NodeGroupAccess.group_id == grant.group_id + ).first() + if existing: + existing.access_level = grant.access_level + existing.granted_by = admin_id + db.commit() + db.refresh(existing) + return existing + + access = models.NodeGroupAccess( + node_id=node_id, + group_id=grant.group_id, + access_level=grant.access_level, + granted_by=admin_id, + ) + db.add(access) + db.commit() + db.refresh(access) + return access + + @router.delete("/admin/{node_id}/access/{group_id}", summary="[Admin] Revoke Group Access") + def admin_revoke_access( + node_id: str, + group_id: str, + admin_id: str, + db: Session = Depends(get_db) + ): + _require_admin(admin_id, db) + access = db.query(models.NodeGroupAccess).filter( + models.NodeGroupAccess.node_id == node_id, + models.NodeGroupAccess.group_id == group_id + ).first() + if not access: + raise HTTPException(status_code=404, detail="Access grant not found.") + db.delete(access) + db.commit() + return {"message": f"Access revoked for group '{group_id}' on node '{node_id}'."} + + # ================================================================== + # USER-FACING ENDPOINTS + # ================================================================== + + @router.get("/", response_model=list[schemas.AgentNodeUserView], summary="List Accessible Nodes") + def list_accessible_nodes(user_id: str, db: Session = Depends(get_db)): + """ + Returns nodes the calling user's group has access to. + Merges live connection state from the in-memory registry. + """ + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found.") + + # Admin sees everything; users see only group-granted nodes + if user.role == "admin": + nodes = db.query(models.AgentNode).filter(models.AgentNode.is_active == True).all() + else: + # Nodes accessible via user's group + accesses = db.query(models.NodeGroupAccess).filter( + models.NodeGroupAccess.group_id == user.group_id + ).all() + node_ids = [a.node_id for a in accesses] + nodes = db.query(models.AgentNode).filter( + models.AgentNode.node_id.in_(node_ids), + models.AgentNode.is_active == True + ).all() + + registry = _registry() + return [_node_to_user_view(n, registry) for n in nodes] + @router.get("/{node_id}/status", summary="Quick Node Online Check") - def get_node_online_status(node_id: str): + def get_node_status(node_id: str): live = _registry().get_node(node_id) if not live: return {"node_id": node_id, "status": "offline"} return {"node_id": node_id, "status": live._compute_status(), "stats": live.stats} - # ------------------------------------------------------------------ # - # POST /nodes/{node_id}/dispatch — send a task # - # ------------------------------------------------------------------ # @router.post("/{node_id}/dispatch", response_model=schemas.NodeDispatchResponse, summary="Dispatch Task to Node") def dispatch_to_node(node_id: str, request: schemas.NodeDispatchRequest): """ - Queues a task for an online node via its gRPC outbound queue. - Emits task_assigned event immediately for live UI feedback. + Queue a shell or browser task to an online node. + Emits task_assigned immediately so the live UI shows it. """ registry = _registry() live = registry.get_node(node_id) @@ -116,63 +230,78 @@ raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.") task_id = str(uuid.uuid4()) - - # Emit to live UI immediately registry.emit(node_id, "task_assigned", {"command": request.command, "session_id": request.session_id}, task_id=task_id) try: - import sys, os + import sys sys.path.insert(0, "/app/poc-grpc-agent") from protos import agent_pb2 from orchestrator.utils.crypto import sign_payload - payload = request.command or json.dumps(request.browser_action) - sig = sign_payload(payload) task_req = agent_pb2.TaskRequest( task_id=task_id, payload_json=payload, - signature=sig, + signature=sign_payload(payload), timeout_ms=request.timeout_ms, session_id=request.session_id or "", ) live.queue.put(agent_pb2.ServerTaskMessage(task_request=task_req)) registry.emit(node_id, "task_start", {"command": request.command}, task_id=task_id) - logger.info(f"[nodes] Dispatched task {task_id} to {node_id}") - except ImportError as e: - logger.warning(f"[nodes] poc-grpc-agent not installed: {e}. Dispatch is stub only.") + except ImportError: + logger.warning("[nodes] poc-grpc-agent not on path; dispatch is stub only.") return schemas.NodeDispatchResponse(task_id=task_id, status="accepted") - # ------------------------------------------------------------------ # - # WS /nodes/{node_id}/stream — single node live event stream # - # ------------------------------------------------------------------ # + @router.patch("/preferences", summary="Update User Node Preferences") + def update_node_preferences( + user_id: str, + prefs: schemas.UserNodePreferences, + db: Session = Depends(get_db) + ): + """ + Save the user's default_node_ids and data_source config into their preferences. + The UI reads this to auto-attach nodes when a new session starts. + """ + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found.") + existing_prefs = user.preferences or {} + existing_prefs["nodes"] = prefs.model_dump() + user.preferences = existing_prefs + db.commit() + return {"message": "Node preferences saved.", "nodes": prefs.model_dump()} + + @router.get("/preferences", response_model=schemas.UserNodePreferences, summary="Get User Node Preferences") + def get_node_preferences(user_id: str, db: Session = Depends(get_db)): + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found.") + node_prefs = (user.preferences or {}).get("nodes", {}) + return schemas.UserNodePreferences(**node_prefs) if node_prefs else schemas.UserNodePreferences() + + # ================================================================== + # WEBSOCKET — Single-node live event stream + # ================================================================== + @router.websocket("/{node_id}/stream") async def node_event_stream(websocket: WebSocket, node_id: str): """ - WebSocket stream for a single node's execution events. - Powers the single-node execution pane in the UI. + Single-node live event stream. + Powers the per-node execution pane in the split-window UI. - Message format: - { - "event": "task_stdout", - "label": "📤 Output", - "node_id": "node-alpha", - "task_id": "abc-123", - "timestamp": "2026-03-04T06:00:00", - "data": { ... } - } + Event format: + { event, label, node_id, task_id, timestamp, data } """ await websocket.accept() registry = _registry() - # Push current snapshot immediately live = registry.get_node(node_id) await websocket.send_json({ "event": "snapshot", "node_id": node_id, - "timestamp": __import__("datetime").datetime.utcnow().isoformat(), + "timestamp": _now(), "data": live.to_dict() if live else {"status": "offline"}, }) @@ -180,23 +309,11 @@ registry.subscribe_node(node_id, q) try: while True: - # Drain all pending events first - drained = False - while True: - try: - event = q.get_nowait() - await websocket.send_json(event) - drained = True - except queue.Empty: - break - - # Periodic heartbeat + _drain(q, websocket) await asyncio.sleep(HEARTBEAT_INTERVAL_S) live = registry.get_node(node_id) await websocket.send_json({ - "event": "heartbeat", - "node_id": node_id, - "timestamp": __import__("datetime").datetime.utcnow().isoformat(), + "event": "heartbeat", "node_id": node_id, "timestamp": _now(), "data": {"status": live._compute_status() if live else "offline", "stats": live.stats if live else {}}, }) @@ -207,61 +324,45 @@ finally: registry.unsubscribe_node(node_id, q) - # ------------------------------------------------------------------ # - # WS /nodes/stream/all — multi-node global execution bus # - # ------------------------------------------------------------------ # + # ================================================================== + # WEBSOCKET — Multi-node global execution bus + # ================================================================== + @router.websocket("/stream/all") async def all_nodes_event_stream(websocket: WebSocket, user_id: str): """ - WebSocket stream for ALL of a user's node execution events combined. - Powers the multi-pane split-window execution UI. + Multi-node global event bus for a user. + Powers the split-window multi-pane execution UI. - The client receives events from every node the user owns, - disambiguated by the 'node_id' field in each event. - - Use this to render: - - Per-node status columns (split by node_id) - - A unified chronological execution log - - Error/retry surface for all nodes simultaneously + On connect: sends initial_snapshot with all live nodes. + Ongoing: streams all events from all user's nodes (disambiguated by node_id). + Every 5s: sends a mesh_heartbeat summary across all nodes. """ await websocket.accept() registry = _registry() - # Send initial snapshot of all user's live nodes - all_nodes = registry.list_nodes(user_id=user_id) + all_live = registry.list_nodes(user_id=user_id) await websocket.send_json({ "event": "initial_snapshot", "user_id": user_id, - "timestamp": __import__("datetime").datetime.utcnow().isoformat(), - "data": { - "nodes": [n.to_dict() for n in all_nodes], - "count": len(all_nodes), - }, + "timestamp": _now(), + "data": {"nodes": [n.to_dict() for n in all_live], "count": len(all_live)}, }) q: queue.Queue = queue.Queue() registry.subscribe_user(user_id, q) try: while True: - while True: - try: - event = q.get_nowait() - await websocket.send_json(event) - except queue.Empty: - break - + _drain(q, websocket) await asyncio.sleep(HEARTBEAT_INTERVAL_S) - # Push a periodic mesh health summary live_nodes = registry.list_nodes(user_id=user_id) await websocket.send_json({ "event": "mesh_heartbeat", "user_id": user_id, - "timestamp": __import__("datetime").datetime.utcnow().isoformat(), + "timestamp": _now(), "data": { - "nodes": [ - {"node_id": n.node_id, "status": n._compute_status(), "stats": n.stats} - for n in live_nodes - ] + "nodes": [{"node_id": n.node_id, "status": n._compute_status(), "stats": n.stats} + for n in live_nodes] }, }) except WebSocketDisconnect: @@ -272,3 +373,71 @@ registry.unsubscribe_user(user_id, q) return router + + +# =========================================================================== +# Helpers +# =========================================================================== + +def _get_node_or_404(node_id: str, db: Session) -> models.AgentNode: + node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first() + if not node: + raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.") + return node + + +def _node_to_admin_detail(node: models.AgentNode, registry) -> schemas.AgentNodeAdminDetail: + live = registry.get_node(node.node_id) + status = live._compute_status() if live else node.last_status or "offline" + stats = schemas.AgentNodeStats(**live.stats) if live else schemas.AgentNodeStats() + return schemas.AgentNodeAdminDetail( + node_id=node.node_id, + display_name=node.display_name, + description=node.description, + skill_config=node.skill_config or {}, + capabilities=node.capabilities or {}, + invite_token=node.invite_token, + is_active=node.is_active, + last_status=status, + last_seen_at=node.last_seen_at, + created_at=node.created_at, + registered_by=node.registered_by, + group_access=[ + schemas.NodeAccessResponse( + id=a.id, node_id=a.node_id, group_id=a.group_id, + access_level=a.access_level, granted_at=a.granted_at + ) for a in (node.group_access or []) + ], + stats=stats, + ) + + +def _node_to_user_view(node: models.AgentNode, registry) -> schemas.AgentNodeUserView: + live = registry.get_node(node.node_id) + status = live._compute_status() if live else node.last_status or "offline" + skill_cfg = node.skill_config or {} + available = [skill for skill, cfg in skill_cfg.items() if cfg.get("enabled", True)] + return schemas.AgentNodeUserView( + node_id=node.node_id, + display_name=node.display_name, + description=node.description, + capabilities=node.capabilities or {}, + available_skills=available, + last_status=status, + last_seen_at=node.last_seen_at, + ) + + +def _now() -> str: + from datetime import datetime + return datetime.utcnow().isoformat() + + +async def _drain(q: queue.Queue, websocket: WebSocket): + """Drain all pending queue items and send to websocket (non-blocking).""" + while True: + try: + event = q.get_nowait() + await websocket.send_json(event) + except queue.Empty: + break diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 7ab79a0..beee122 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -205,7 +205,56 @@ max_input_tokens: Optional[int] = None -# --- Agent Node Schemas --- +# --------------------------------------------------------------------------- +# Agent Node Schemas +# --------------------------------------------------------------------------- + +# --- Skill Toggles (admin-configured) --- + +class SkillConfig(BaseModel): + """Per-skill enable/disable with optional config.""" + enabled: bool = True + cwd_jail: Optional[str] = None # shell only: restrict working directory + max_file_size_mb: Optional[int] = None # sync only: file size cap + +class NodeSkillConfig(BaseModel): + """Admin-controlled skill configuration for a node.""" + shell: SkillConfig = SkillConfig(enabled=True) + browser: SkillConfig = SkillConfig(enabled=True) + sync: SkillConfig = SkillConfig(enabled=True) + +# --- Admin Create / Update --- + +class AgentNodeCreate(BaseModel): + """Payload for admin creating a new node registration.""" + node_id: str = Field(..., description="Stable identifier used in the node's config YAML, e.g. 'dev-macbook-m3'") + display_name: str = Field(..., description="Human-readable name shown in the UI") + description: Optional[str] = None + skill_config: NodeSkillConfig = NodeSkillConfig() + +class AgentNodeUpdate(BaseModel): + """Payload for admin updating node configuration.""" + display_name: Optional[str] = None + description: Optional[str] = None + skill_config: Optional[NodeSkillConfig] = None + is_active: Optional[bool] = None + +# --- Group Access --- + +class NodeAccessGrant(BaseModel): + """Admin grants a group access to a node.""" + group_id: str + access_level: str = Field("use", description="'view', 'use', or 'admin'") + +class NodeAccessResponse(BaseModel): + id: int + node_id: str + group_id: str + access_level: str + granted_at: datetime + model_config = ConfigDict(from_attributes=True) + +# --- Live Stats --- class AgentNodeStats(BaseModel): """Live performance stats reported via heartbeat.""" @@ -214,38 +263,75 @@ memory_usage_percent: float = 0.0 running: List[str] = [] -class AgentNodeStatusResponse(BaseModel): - """Full live status of an agent node.""" +# --- Node Responses --- + +class AgentNodeAdminDetail(BaseModel): + """Full node detail for admin view — includes invite_token and skill config.""" node_id: str - user_id: str + display_name: str description: Optional[str] = None + skill_config: dict = {} capabilities: dict = {} + invite_token: Optional[str] = None + is_active: bool = True + last_status: str + last_seen_at: Optional[datetime] = None + created_at: datetime + registered_by: str + group_access: List[NodeAccessResponse] = [] stats: AgentNodeStats = AgentNodeStats() - status: str # 'online' | 'offline' | 'stale' - connected_at: Optional[str] = None - last_heartbeat_at: Optional[str] = None model_config = ConfigDict(from_attributes=True) -class AgentNodeSummary(BaseModel): - """Lightweight node record for list views.""" +class AgentNodeUserView(BaseModel): + """Node as seen by a user — no invite_token, no admin config details.""" node_id: str - user_id: str + display_name: str description: Optional[str] = None capabilities: dict = {} - status: str + # Which skills are available to this user (derived from skill_config.enabled) + available_skills: List[str] = [] + last_status: str # 'online' | 'offline' | 'stale' last_seen_at: Optional[datetime] = None - created_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) +class AgentNodeStatusResponse(BaseModel): + """Full live status of an agent node (used internally).""" + node_id: str + display_name: Optional[str] = None + stats: AgentNodeStats = AgentNodeStats() + status: str + connected_at: Optional[str] = None + last_heartbeat_at: Optional[str] = None + +# --- User Node Preferences --- + +class NodeDataSourceConfig(BaseModel): + """How a node should seed its workspace for a session.""" + source: str = Field("empty", description="'empty' | 'server' | 'node_local'") + path: Optional[str] = None # root path on node when source='node_local' + +class UserNodePreferences(BaseModel): + """Stored in User.preferences['nodes'].""" + default_node_ids: List[str] = Field( + default_factory=list, + description="Node IDs auto-attached when starting a new session" + ) + data_source: NodeDataSourceConfig = NodeDataSourceConfig() + +# --- Task Dispatch --- + class NodeDispatchRequest(BaseModel): """Dispatch a shell or browser action to a specific node.""" - command: str = "" # Shell command (mutually exclusive with browser_action) - browser_action: Optional[dict] = None # BrowserAction payload - session_id: Optional[str] = None # Workspace session context + command: str = "" + browser_action: Optional[dict] = None + session_id: Optional[str] = None timeout_ms: int = 30000 class NodeDispatchResponse(BaseModel): task_id: str - status: str # 'accepted' | 'rejected' + status: str # 'accepted' | 'rejected' reason: Optional[str] = None +# Keep backward-compat alias +AgentNodeSummary = AgentNodeUserView + diff --git a/ai-hub/app/core/services/node_registry.py b/ai-hub/app/core/services/node_registry.py index 191776c..2ba9aa3 100644 --- a/ai-hub/app/core/services/node_registry.py +++ b/ai-hub/app/core/services/node_registry.py @@ -2,6 +2,12 @@ NodeRegistry Service — AI Hub Integration Layer Manages live Agent Node registrations, their in-memory gRPC queues, and a rich event bus for real-time UI streaming (split-pane terminal UX). + +Persistence strategy (M2): +- The in-memory dict is the live connection cache (fast, gRPC-queue holder). +- The DB (via get_db_session) is the source of truth for node identity across reboots. +- On every register/deregister/heartbeat, we upsert the AgentNode DB record. +- On Hub startup all DB nodes are "offline"; they go "online" when they reconnect. """ import threading import queue @@ -33,7 +39,7 @@ class LiveNodeRecord: - """Represents a single connected Agent Node and its associated state.""" + """Represents a single connected Agent Node and its associated live state.""" def __init__(self, node_id: str, user_id: str, metadata: dict): self.node_id = node_id self.user_id = user_id # Owner — maps node to a Hub user @@ -73,15 +79,17 @@ class NodeRegistryService: """ - In-memory registry of live Agent Nodes, integrated into the FastAPI - ServiceContainer. + Persistent + in-memory registry of live Agent Nodes. - Provides: - - Live node registration / deregistration - - gRPC outbound queue per node - - Rich event bus for real-time UI streaming: - * Per-node stream → single node execution pane - * Per-user stream → all-nodes global execution bus (multi-pane view) + Two-tier storage: + Tier 1 — In-memory (_nodes dict): live connections, gRPC queues, real-time stats. + Tier 2 — Database (AgentNode model): node identity, capabilities, invite_token, + last_status, last_seen_at survive Hub restarts. + + When a node reconnects after a Hub restart it calls SyncConfiguration again, + which calls register() again → the DB record is updated to 'online'. + Nodes that haven't reconnected stay in the DB with 'offline' status so the UI + can still show them as known (but disconnected) nodes. """ def __init__(self): @@ -89,27 +97,98 @@ self._nodes: Dict[str, LiveNodeRecord] = {} # Per-node WS subscribers: node_id -> [queue, ...] self._node_listeners: Dict[str, List[queue.Queue]] = {} - # Per-user WS subscribers: user_id -> [queue, ...] (ALL nodes for that user) + # Per-user WS subscribers: user_id -> [queue, ...] (ALL nodes for that user) self._user_listeners: Dict[str, List[queue.Queue]] = {} # ------------------------------------------------------------------ # + # DB Helpers # + # ------------------------------------------------------------------ # + + def _db_upsert_node(self, node_id: str, user_id: str, metadata: dict): + """Create or update the AgentNode DB record (on connect).""" + from app.db.models import AgentNode + from app.db.session import get_db_session + try: + with get_db_session() as db: + record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first() + if record: + record.user_id = user_id + record.description = metadata.get("desc", "") + record.capabilities = metadata.get("caps", {}) + record.last_status = "online" + record.last_seen_at = datetime.utcnow() + else: + record = AgentNode( + node_id=node_id, + user_id=user_id, + description=metadata.get("desc", ""), + capabilities=metadata.get("caps", {}), + last_status="online", + last_seen_at=datetime.utcnow(), + ) + db.add(record) + except Exception as e: + print(f"[NodeRegistry] DB upsert failed for {node_id}: {e}") + + def _db_mark_offline(self, node_id: str): + """Update last_seen_at and mark last_status = 'offline' on disconnect.""" + from app.db.models import AgentNode + from app.db.session import get_db_session + try: + with get_db_session() as db: + record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first() + if record: + record.last_status = "offline" + record.last_seen_at = datetime.utcnow() + except Exception as e: + print(f"[NodeRegistry] DB mark-offline failed for {node_id}: {e}") + + def _db_update_heartbeat(self, node_id: str): + """Bump last_seen_at on each heartbeat so we know when the node was last active.""" + from app.db.models import AgentNode + from app.db.session import get_db_session + try: + with get_db_session() as db: + record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first() + if record: + record.last_seen_at = datetime.utcnow() + record.last_status = "online" + except Exception as e: + print(f"[NodeRegistry] DB heartbeat update failed for {node_id}: {e}") + + + # ------------------------------------------------------------------ # # Registration # # ------------------------------------------------------------------ # def register(self, node_id: str, user_id: str, metadata: dict) -> LiveNodeRecord: - """Register or re-register a node (called from gRPC SyncConfiguration).""" + """ + Register or re-register a node. + Called from gRPC SyncConfiguration on every node connect/reconnect. + Persists to DB so the node survives Hub restarts as a known entity. + """ with self._lock: record = LiveNodeRecord(node_id=node_id, user_id=user_id, metadata=metadata) self._nodes[node_id] = record + + # Persist to DB (background-safe — session is scoped) + self._db_upsert_node(node_id, user_id, metadata) + print(f"[📋] NodeRegistry: Registered {node_id} (owner: {user_id})") self.emit(node_id, "node_online", record.to_dict()) return record def deregister(self, node_id: str): - """Remove a node when its gRPC stream closes (called from TaskStream finally).""" + """ + Remove a node from live_registry when its gRPC stream closes. + The DB record is kept with last_status='offline' so the user can + still see the node in their list (as disconnected). + """ with self._lock: node = self._nodes.pop(node_id, None) user_id = node.user_id if node else None + + self._db_mark_offline(node_id) self.emit(node_id, "node_offline", {"node_id": node_id, "user_id": user_id}) print(f"[📋] NodeRegistry: Deregistered {node_id}") @@ -118,28 +197,33 @@ # ------------------------------------------------------------------ # def get_node(self, node_id: str) -> Optional[LiveNodeRecord]: + """Returns a live record only if the node is currently connected.""" with self._lock: return self._nodes.get(node_id) def list_nodes(self, user_id: Optional[str] = None) -> List[LiveNodeRecord]: + """Returns only currently LIVE nodes (use the DB for the full list).""" with self._lock: if user_id: return [n for n in self._nodes.values() if n.user_id == user_id] return list(self._nodes.values()) def get_best(self, user_id: Optional[str] = None) -> Optional[str]: - """Pick the least-loaded node for a given owner.""" + """Pick the least-loaded live node for a given owner.""" nodes = self.list_nodes(user_id=user_id) if not nodes: return None return sorted(nodes, key=lambda n: n.stats.get("active_worker_count", 999))[0].node_id def update_stats(self, node_id: str, stats: dict): + """Called every heartbeat interval. Updates in-memory stats and bumps DB last_seen.""" with self._lock: node = self._nodes.get(node_id) if node: node.update_stats(stats) - # Also emit heartbeat event to UI + # Persist heartbeat timestamp to DB (throttle: already ~10s cadence from node) + self._db_update_heartbeat(node_id) + # Emit heartbeat event to live UI self.emit(node_id, "heartbeat", stats) # ------------------------------------------------------------------ # @@ -149,7 +233,6 @@ def emit(self, node_id: str, event_type: str, data: Any = None, task_id: str = ""): """ Emit a rich structured execution event. - Delivered to: - Per-node WS subscribers → powers the single-node execution pane - Per-user WS subscribers → powers the global multi-node execution bus @@ -169,7 +252,6 @@ "timestamp": datetime.utcnow().isoformat(), "data": data, } - # Deliver — avoid duplicates if same queue is in both lists seen = set() for q in node_qs + user_qs: if id(q) not in seen: diff --git a/ai-hub/app/db/migrate.py b/ai-hub/app/db/migrate.py index 2ae52de..696524a 100644 --- a/ai-hub/app/db/migrate.py +++ b/ai-hub/app/db/migrate.py @@ -57,15 +57,18 @@ else: logger.info(f"Column '{col_name}' already exists in 'sessions'.") - # Agent Nodes table — create if missing (handled by create_all, but add - # any new columns that might be added post-creation) + # Agent Nodes table migrations if inspector.has_table("agent_nodes"): node_columns = [c["name"] for c in inspector.get_columns("agent_nodes")] node_required_columns = [ - ("invite_token", "TEXT"), - ("last_status", "TEXT"), - ("last_seen_at", "DATETIME"), - ("capabilities", "TEXT"), + ("display_name", "TEXT"), + ("registered_by", "TEXT"), + ("skill_config", "TEXT"), + ("invite_token", "TEXT"), + ("is_active", "INTEGER"), + ("last_status", "TEXT"), + ("last_seen_at", "DATETIME"), + ("capabilities", "TEXT"), ] for col_name, col_type in node_required_columns: if col_name not in node_columns: @@ -79,5 +82,7 @@ logger.info("Database migrations complete.") + + if __name__ == "__main__": run_migrations() diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py index 27f3eaa..9626f8e 100644 --- a/ai-hub/app/db/models.py +++ b/ai-hub/app/db/models.py @@ -332,31 +332,80 @@ class AgentNode(Base): """ - Persistent record of an Agent Node registered by a user. - Stores the node's identity, capabilities, and invite token. - The 'live' connection state is managed separately in NodeRegistryService (in-memory). + Admin-configured Agent Node. + Only admins register and configure nodes. Groups are then granted access. + Users see nodes available to their group and can attach them to sessions. + + Lifecycle: + 1. Admin creates the node record here (description, skill_config, invite_token). + 2. Admin deploys the client-side node software with the generated config YAML. + 3. Node connects → last_status flips to 'online'. + 4. Admin grants access to one or more groups (NodeGroupAccess). + 5. Users in those groups see the node in preferences / session setup. """ __tablename__ = 'agent_nodes' id = Column(Integer, primary_key=True, index=True) - # Human-readable node identifier set in the node's YAML config + # Stable identifier used in the node's YAML config (e.g. "dev-macbook-m3") node_id = Column(String, unique=True, index=True, nullable=False) - # Owner of this node - user_id = Column(String, ForeignKey('users.id'), nullable=False) - # Human description of the node (e.g., "MacBook Pro M3 - Dev Machine") + # Human-readable name shown in the UI + display_name = Column(String, nullable=False) + # Rich description — like a skill description; tells users what this node is for description = Column(String, nullable=True) - # JSON of capabilities: {"shell": "v1", "browser": "playwright-sync-bridge"} + # Admin user who registered this node + registered_by = Column(String, ForeignKey('users.id'), nullable=False) + # Skill enablement toggles + per-skill config + # Example: + # { + # "shell": {"enabled": true, "cwd_jail": "/home/user/projects"}, + # "browser": {"enabled": false}, + # "sync": {"enabled": true, "max_file_size_mb": 50} + # } + skill_config = Column(JSON, default={ + "shell": {"enabled": True}, + "browser": {"enabled": True}, + "sync": {"enabled": True}, + }, nullable=False) + # Actual capabilities reported by the node on connect (read-only, set by node) capabilities = Column(JSON, default={}, nullable=True) - # Pre-signed invite token (generated at "Download Your Node" step) + # Pre-signed invite token generated at node creation (used in downloaded config YAML) invite_token = Column(String, unique=True, nullable=True, index=True) - # Last known status: 'online', 'offline', 'stale' + # Whether this node is administratively active (can be disabled without deleting) + is_active = Column(Boolean, default=True, nullable=False) + # Live status updated by NodeRegistryService: 'online' | 'offline' | 'stale' last_status = Column(String, default="offline", nullable=False) # Last heartbeat timestamp last_seen_at = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - owner = relationship("User") + registered_by_user = relationship("User", foreign_keys=[registered_by]) + # Groups that have been granted access to this node + group_access = relationship("NodeGroupAccess", back_populates="node", cascade="all, delete-orphan") def __repr__(self): - return f"" + return f"" + + +class NodeGroupAccess(Base): + """ + Grants a group access to a specific agent node. + Admin sets this; users in the group can then see and use the node. + """ + __tablename__ = 'node_group_access' + + id = Column(Integer, primary_key=True, index=True) + node_id = Column(String, ForeignKey('agent_nodes.node_id'), nullable=False, index=True) + group_id = Column(String, ForeignKey('groups.id'), nullable=False, index=True) + # access_level: 'view' (see but not use), 'use' (can attach to session), 'admin' (can config) + access_level = Column(String, default="use", nullable=False) + granted_by = Column(String, ForeignKey('users.id'), nullable=False) + granted_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + node = relationship("AgentNode", back_populates="group_access") + group = relationship("Group") + granted_by_user = relationship("User", foreign_keys=[granted_by]) + + def __repr__(self): + return f"" + diff --git a/ai-hub/app/db/session.py b/ai-hub/app/db/session.py index 417de48..6f46c4d 100644 --- a/ai-hub/app/db/session.py +++ b/ai-hub/app/db/session.py @@ -28,6 +28,24 @@ # It's the standard way to interact with the database in SQLAlchemy. SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +from contextlib import contextmanager + +@contextmanager +def get_db_session(): + """ + Context-manager database session for use outside of FastAPI request scope. + Used by background services (e.g. NodeRegistryService) that need their own session. + """ + db = SessionLocal() + try: + yield db + db.commit() + except Exception: + db.rollback() + raise + finally: + db.close() + def create_db_and_tables(): """ Creates all database tables defined by models inheriting from Base.