# cortex-hub / ai-hub / app / api / routes / nodes.py
"""
Agent Node REST + WebSocket API
Admin-managed nodes, group access control, and user-facing live streaming.

Admin endpoints (require role=admin):
  POST  /nodes/admin                    — Create node registration + generate invite_token
  GET   /nodes/admin                    — List all nodes (admin view, full detail)
  GET   /nodes/admin/{node_id}          — Full admin detail including invite_token
  PATCH /nodes/admin/{node_id}          — Update node config (display_name, description, skill_config, is_active)
  POST  /nodes/admin/{node_id}/access   — Grant group access to a node
  DELETE /nodes/admin/{node_id}/access/{group_id} — Revoke group access

User endpoints (scoped to caller's group):
  GET  /nodes/                          — List accessible nodes (user view, no sensitive data)
  GET  /nodes/{node_id}/status          — Quick online/offline probe
  POST /nodes/{node_id}/dispatch        — Dispatch a task to a node
  PATCH /nodes/preferences              — Update user's default_nodes + data_source prefs
  GET  /nodes/preferences               — Read user's saved node preferences

WebSocket (real-time streaming):
  WS   /nodes/{node_id}/stream          — Single-node live execution stream
  WS   /nodes/stream/all?user_id=...    — All-nodes global bus (multi-pane UI)
"""
import asyncio
import json
import queue
import uuid
import secrets
import logging
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Depends
from sqlalchemy.orm import Session

from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.db import models

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Seconds each WebSocket stream loop sleeps before sending its next heartbeat.
HEARTBEAT_INTERVAL_S = 5


def create_nodes_router(services: ServiceContainer) -> APIRouter:
    router = APIRouter(prefix="/nodes", tags=["Agent Nodes"])

    def _registry():
        """Return the node registry service held by the injected service container.

        The attribute is read on every call rather than captured once when the
        router is built.
        """
        registry_service = services.node_registry_service
        return registry_service

    def _require_admin(user_id: str, db: Session):
        """
        Look up the user with the given id and insist that they are an admin.

        Returns the matching user row. Raises HTTP 403 when no such user exists
        or when the user's role is anything other than "admin".
        """
        account = db.query(models.User).filter(models.User.id == user_id).first()
        if account and account.role == "admin":
            return account
        raise HTTPException(status_code=403, detail="Admin access required.")

    # ==================================================================
    #  ADMIN ENDPOINTS
    # ==================================================================

    @router.post("/admin", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Register New Node")
    def admin_create_node(
        request: schemas.AgentNodeCreate,
        admin_id: str,
        db: Session = Depends(get_db)
    ):
        """
        Register a new Agent Node on behalf of an admin.

        Responds with 403 unless admin_id belongs to an admin, and with 409 when
        a node with the requested node_id is already stored. On success the
        stored node is returned in admin-detail form, including the generated
        invite_token that must be placed in the node's config YAML before
        deployment.
        """
        _require_admin(admin_id, db)

        node_id_taken = (
            db.query(models.AgentNode)
            .filter(models.AgentNode.node_id == request.node_id)
            .first()
        )
        if node_id_taken:
            raise HTTPException(status_code=409, detail=f"Node '{request.node_id}' already exists.")

        new_node = models.AgentNode(
            node_id=request.node_id,
            display_name=request.display_name,
            description=request.description,
            registered_by=admin_id,
            skill_config=request.skill_config.model_dump(),
            # Invite token comes from the `secrets` module (URL-safe, 32 random bytes).
            invite_token=secrets.token_urlsafe(32),
            # A freshly registered node has not connected yet.
            last_status="offline",
        )
        db.add(new_node)
        db.commit()
        db.refresh(new_node)

        logger.info(f"[admin] Created node '{request.node_id}' by admin {admin_id}")
        return _node_to_admin_detail(new_node, _registry())

    @router.get("/admin", response_model=list[schemas.AgentNodeAdminDetail], summary="[Admin] List All Nodes")
    def admin_list_nodes(admin_id: str, db: Session = Depends(get_db)):
        """
        Full node list for the admin dashboard, including invite_token and skill config.

        Responds with 403 unless admin_id belongs to an admin.
        """
        _require_admin(admin_id, db)
        stored_nodes = db.query(models.AgentNode).all()
        details = []
        for record in stored_nodes:
            details.append(_node_to_admin_detail(record, _registry()))
        return details

    @router.get("/admin/{node_id}", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Get Node Detail")
    def admin_get_node(node_id: str, admin_id: str, db: Session = Depends(get_db)):
        """
        Full admin detail for a single node.

        Responds with 403 unless admin_id belongs to an admin, and with 404 when
        the node is not stored.
        """
        _require_admin(admin_id, db)
        return _node_to_admin_detail(_get_node_or_404(node_id, db), _registry())

    @router.patch("/admin/{node_id}", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Update Node Config")
    def admin_update_node(
        node_id: str,
        update: schemas.AgentNodeUpdate,
        admin_id: str,
        db: Session = Depends(get_db)
    ):
        """
        Update display_name, description, skill_config toggles, or is_active.

        Only fields that are not None in the request are written; the rest keep
        their stored values. Responds with 403 for non-admins and 404 when the
        node is not stored.
        """
        _require_admin(admin_id, db)
        target = _get_node_or_404(node_id, db)

        # Scalar fields are copied over unchanged when the caller supplied them.
        for field_name in ("display_name", "description", "is_active"):
            supplied = getattr(update, field_name)
            if supplied is not None:
                setattr(target, field_name, supplied)

        # skill_config arrives as a schema object and is stored as a plain dict.
        if update.skill_config is not None:
            target.skill_config = update.skill_config.model_dump()

        db.commit()
        db.refresh(target)
        return _node_to_admin_detail(target, _registry())

    @router.post("/admin/{node_id}/access", response_model=schemas.NodeAccessResponse, summary="[Admin] Grant Group Access")
    def admin_grant_access(
        node_id: str,
        grant: schemas.NodeAccessGrant,
        admin_id: str,
        db: Session = Depends(get_db)
    ):
        """
        Grant a group access to use this node in sessions.

        If the group already has a grant on this node, its access_level and
        granted_by are overwritten; otherwise a new grant is created. Responds
        with 403 for non-admins and 404 when the node is not stored.
        """
        _require_admin(admin_id, db)
        # Called only for its 404 check; the node row itself is not needed here.
        _get_node_or_404(node_id, db)

        grant_row = (
            db.query(models.NodeGroupAccess)
            .filter(
                models.NodeGroupAccess.node_id == node_id,
                models.NodeGroupAccess.group_id == grant.group_id,
            )
            .first()
        )
        if grant_row:
            grant_row.access_level = grant.access_level
            grant_row.granted_by = admin_id
        else:
            grant_row = models.NodeGroupAccess(
                node_id=node_id,
                group_id=grant.group_id,
                access_level=grant.access_level,
                granted_by=admin_id,
            )
            db.add(grant_row)

        db.commit()
        db.refresh(grant_row)
        return grant_row

    @router.delete("/admin/{node_id}/access/{group_id}", summary="[Admin] Revoke Group Access")
    def admin_revoke_access(
        node_id: str,
        group_id: str,
        admin_id: str,
        db: Session = Depends(get_db)
    ):
        """
        Remove a group's access grant on a node.

        Responds with 403 for non-admins and 404 when no grant exists for the
        given node/group pair.
        """
        _require_admin(admin_id, db)
        grant_row = (
            db.query(models.NodeGroupAccess)
            .filter(
                models.NodeGroupAccess.node_id == node_id,
                models.NodeGroupAccess.group_id == group_id,
            )
            .first()
        )
        if grant_row:
            db.delete(grant_row)
            db.commit()
            return {"message": f"Access revoked for group '{group_id}' on node '{node_id}'."}
        raise HTTPException(status_code=404, detail="Access grant not found.")

    # ==================================================================
    #  USER-FACING ENDPOINTS
    # ==================================================================

    @router.get("/", response_model=list[schemas.AgentNodeUserView], summary="List Accessible Nodes")
    def list_accessible_nodes(user_id: str, db: Session = Depends(get_db)):
        """
        Returns the active nodes the calling user may use.

        Admins get every active node; other users get the active nodes granted
        to their group. Each entry is merged with live connection state from the
        in-memory registry. Responds with 404 when the user does not exist.
        """
        caller = db.query(models.User).filter(models.User.id == user_id).first()
        if not caller:
            raise HTTPException(status_code=404, detail="User not found.")

        # `== True` is intentional: it builds a SQL expression, not a Python bool.
        criteria = [models.AgentNode.is_active == True]  # noqa: E712
        if caller.role != "admin":
            # Non-admins are limited to nodes granted to their group.
            group_grants = db.query(models.NodeGroupAccess).filter(
                models.NodeGroupAccess.group_id == caller.group_id
            ).all()
            granted_ids = [grant.node_id for grant in group_grants]
            criteria.insert(0, models.AgentNode.node_id.in_(granted_ids))

        visible_nodes = db.query(models.AgentNode).filter(*criteria).all()

        live_registry = _registry()
        return [_node_to_user_view(record, live_registry) for record in visible_nodes]

    @router.get("/{node_id}/status", summary="Quick Node Online Check")
    def get_node_status(node_id: str):
        """
        Report whether a node is currently in the live registry.

        A node the registry does not know is reported as "offline" without
        stats; otherwise the computed status and the live stats are returned.
        """
        connected = _registry().get_node(node_id)
        if connected:
            return {"node_id": node_id, "status": connected._compute_status(), "stats": connected.stats}
        return {"node_id": node_id, "status": "offline"}

    @router.post("/{node_id}/dispatch", response_model=schemas.NodeDispatchResponse, summary="Dispatch Task to Node")
    def dispatch_to_node(node_id: str, request: schemas.NodeDispatchRequest):
        """
        Queue a shell or browser task to an online node.
        Emits task_assigned immediately so the live UI shows it.

        Responds with 503 when the node is not in the live registry. When the
        gRPC agent modules cannot be imported, the task is not queued (a warning
        is logged) but the response is still "accepted" with the generated
        task_id.
        """
        registry = _registry()
        live = registry.get_node(node_id)
        if not live:
            raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.")

        task_id = str(uuid.uuid4())
        registry.emit(node_id, "task_assigned",
                       {"command": request.command, "session_id": request.session_id},
                       task_id=task_id)

        try:
            import sys
            # Insert the agent package directory only once. Inserting it on
            # every request would make sys.path grow without bound.
            agent_path = "/app/poc-grpc-agent"
            if agent_path not in sys.path:
                sys.path.insert(0, agent_path)
            from protos import agent_pb2
            from orchestrator.utils.crypto import sign_payload
            # A shell command takes precedence; otherwise the browser action is
            # serialized to JSON and sent as the payload.
            payload = request.command or json.dumps(request.browser_action)
            task_req = agent_pb2.TaskRequest(
                task_id=task_id,
                payload_json=payload,
                signature=sign_payload(payload),
                timeout_ms=request.timeout_ms,
                session_id=request.session_id or "",
            )
            live.queue.put(agent_pb2.ServerTaskMessage(task_request=task_req))
            registry.emit(node_id, "task_start", {"command": request.command}, task_id=task_id)
        except ImportError:
            logger.warning("[nodes] poc-grpc-agent not on path; dispatch is stub only.")

        return schemas.NodeDispatchResponse(task_id=task_id, status="accepted")

    @router.patch("/preferences", summary="Update User Node Preferences")
    def update_node_preferences(
        user_id: str,
        prefs: schemas.UserNodePreferences,
        db: Session = Depends(get_db)
    ):
        """
        Save the user's default_node_ids and data_source config into their preferences.
        The UI reads this to auto-attach nodes when a new session starts.

        Only the "nodes" key of the stored preferences is replaced; other keys
        are kept. Responds with 404 when the user does not exist.
        """
        user = db.query(models.User).filter(models.User.id == user_id).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found.")

        node_prefs = prefs.model_dump()
        # Assign a *new* dict instead of mutating the stored one in place.
        # Re-assigning the same mutated object compares equal to the previous
        # value, so a plain JSON column would not be flagged as changed and the
        # update could be skipped at commit time.
        updated_prefs = dict(user.preferences or {})
        updated_prefs["nodes"] = node_prefs
        user.preferences = updated_prefs
        db.commit()
        return {"message": "Node preferences saved.", "nodes": node_prefs}

    @router.get("/preferences", response_model=schemas.UserNodePreferences, summary="Get User Node Preferences")
    def get_node_preferences(user_id: str, db: Session = Depends(get_db)):
        """
        Return the node preferences saved for the user.

        Falls back to a default UserNodePreferences when nothing is stored under
        the "nodes" key. Responds with 404 when the user does not exist.
        """
        user = db.query(models.User).filter(models.User.id == user_id).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found.")

        stored = (user.preferences or {}).get("nodes", {})
        if not stored:
            return schemas.UserNodePreferences()
        return schemas.UserNodePreferences(**stored)

    # ==================================================================
    #  WEBSOCKET — Single-node live event stream
    # ==================================================================

    @router.websocket("/{node_id}/stream")
    async def node_event_stream(websocket: WebSocket, node_id: str):
        """
        Single-node live event stream.
        Powers the per-node execution pane in the split-window UI.

        On connect: sends a snapshot of the node, or {"status": "offline"} when
                    the node is not in the live registry.
        Ongoing:    forwards the events queued for this node, sleeps
                    HEARTBEAT_INTERVAL_S seconds, then sends a heartbeat with
                    the node's current status and stats.

        Event format:
          { event, label, node_id, task_id, timestamp, data }
        """
        await websocket.accept()
        registry = _registry()

        live = registry.get_node(node_id)
        await websocket.send_json({
            "event": "snapshot",
            "node_id": node_id,
            "timestamp": _now(),
            "data": live.to_dict() if live else {"status": "offline"},
        })

        q: queue.Queue = queue.Queue()
        registry.subscribe_node(node_id, q)
        try:
            while True:
                # _drain is a coroutine function: it must be awaited, otherwise
                # the call only creates a coroutine object and no queued event
                # ever reaches the client.
                await _drain(q, websocket)
                await asyncio.sleep(HEARTBEAT_INTERVAL_S)
                # Re-read the node each cycle so the heartbeat reflects the
                # registry's current view of it.
                live = registry.get_node(node_id)
                await websocket.send_json({
                    "event": "heartbeat", "node_id": node_id, "timestamp": _now(),
                    "data": {"status": live._compute_status() if live else "offline",
                             "stats": live.stats if live else {}},
                })
        except WebSocketDisconnect:
            # Normal client disconnect; nothing to report.
            pass
        except Exception as e:
            logger.error(f"[nodes/stream] {node_id}: {e}")
        finally:
            # Always drop the subscription so the registry stops feeding this queue.
            registry.unsubscribe_node(node_id, q)

    # ==================================================================
    #  WEBSOCKET — Multi-node global execution bus
    # ==================================================================

    @router.websocket("/stream/all")
    async def all_nodes_event_stream(websocket: WebSocket, user_id: str):
        """
        Multi-node global event bus for a user.
        Powers the split-window multi-pane execution UI.

        On connect: sends initial_snapshot with all live nodes.
        Ongoing:    streams all events from all user's nodes (disambiguated by node_id).
        Every 5s:   sends a mesh_heartbeat summary across all nodes
                    (the period is HEARTBEAT_INTERVAL_S).
        """
        await websocket.accept()
        registry = _registry()

        all_live = registry.list_nodes(user_id=user_id)
        await websocket.send_json({
            "event": "initial_snapshot",
            "user_id": user_id,
            "timestamp": _now(),
            "data": {"nodes": [n.to_dict() for n in all_live], "count": len(all_live)},
        })

        q: queue.Queue = queue.Queue()
        registry.subscribe_user(user_id, q)
        try:
            while True:
                # _drain is a coroutine function: it must be awaited, otherwise
                # the call only creates a coroutine object and no queued event
                # ever reaches the client.
                await _drain(q, websocket)
                await asyncio.sleep(HEARTBEAT_INTERVAL_S)
                # Re-list the nodes each cycle so the summary reflects the
                # registry's current view.
                live_nodes = registry.list_nodes(user_id=user_id)
                await websocket.send_json({
                    "event": "mesh_heartbeat",
                    "user_id": user_id,
                    "timestamp": _now(),
                    "data": {
                        "nodes": [{"node_id": n.node_id, "status": n._compute_status(), "stats": n.stats}
                                  for n in live_nodes]
                    },
                })
        except WebSocketDisconnect:
            # Normal client disconnect; nothing to report.
            pass
        except Exception as e:
            logger.error(f"[nodes/stream/all] user={user_id}: {e}")
        finally:
            # Always drop the subscription so the registry stops feeding this queue.
            registry.unsubscribe_user(user_id, q)

    return router


# ===========================================================================
#  Helpers
# ===========================================================================

def _get_node_or_404(node_id: str, db: Session) -> models.AgentNode:
    """Fetch the AgentNode row with the given node_id, or raise HTTP 404 if none exists."""
    match = (
        db.query(models.AgentNode)
        .filter(models.AgentNode.node_id == node_id)
        .first()
    )
    if match:
        return match
    raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.")


def _node_to_admin_detail(node: models.AgentNode, registry) -> schemas.AgentNodeAdminDetail:
    """
    Build the admin-facing view of a stored node, merged with live registry state.

    When the registry has a live entry for the node, status and stats come from
    it. Otherwise the stored last_status is used (falling back to "offline")
    together with default stats.
    """
    live = registry.get_node(node.node_id)
    if live:
        current_status = live._compute_status()
        current_stats = schemas.AgentNodeStats(**live.stats)
    else:
        current_status = node.last_status or "offline"
        current_stats = schemas.AgentNodeStats()

    grants = []
    for grant in (node.group_access or []):
        grants.append(
            schemas.NodeAccessResponse(
                id=grant.id,
                node_id=grant.node_id,
                group_id=grant.group_id,
                access_level=grant.access_level,
                granted_at=grant.granted_at,
            )
        )

    return schemas.AgentNodeAdminDetail(
        node_id=node.node_id,
        display_name=node.display_name,
        description=node.description,
        skill_config=node.skill_config or {},
        capabilities=node.capabilities or {},
        invite_token=node.invite_token,
        is_active=node.is_active,
        last_status=current_status,
        last_seen_at=node.last_seen_at,
        created_at=node.created_at,
        registered_by=node.registered_by,
        group_access=grants,
        stats=current_stats,
    )


def _node_to_user_view(node: models.AgentNode, registry) -> schemas.AgentNodeUserView:
    """
    Build the user-facing view of a stored node, merged with live registry state.

    The view carries no invite_token or raw skill_config, only the names of the
    skills whose config does not set "enabled" to a falsy value (a missing
    "enabled" key counts as enabled).
    """
    live = registry.get_node(node.node_id)
    if live:
        current_status = live._compute_status()
    else:
        current_status = node.last_status or "offline"

    enabled_skills = []
    for skill_name, skill_settings in (node.skill_config or {}).items():
        if skill_settings.get("enabled", True):
            enabled_skills.append(skill_name)

    return schemas.AgentNodeUserView(
        node_id=node.node_id,
        display_name=node.display_name,
        description=node.description,
        capabilities=node.capabilities or {},
        available_skills=enabled_skills,
        last_status=current_status,
        last_seen_at=node.last_seen_at,
    )


def _now() -> str:
    """
    Return the current UTC time as a naive ISO-8601 string (no UTC offset).

    Uses an aware "now" and strips the tzinfo instead of calling the deprecated
    datetime.utcnow(); the resulting string has the same format as before.
    """
    from datetime import datetime, timezone
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


async def _drain(q: queue.Queue, websocket: WebSocket):
    """
    Send every event currently waiting in the queue to the websocket.

    Items are taken with get_nowait(), so the queue read never blocks; the
    coroutine returns as soon as the queue reports empty.
    """
    while True:
        try:
            pending = q.get_nowait()
        except queue.Empty:
            return
        await websocket.send_json(pending)