"""
Agent Node REST + WebSocket API
Exposes the live node registry and execution event bus to the AI Hub UI.
Endpoints:
GET /nodes?user_id=... — List all nodes ever registered to the given user
GET /nodes/{node_id} — Full node status (live if connected, else the persisted DB record)
GET /nodes/{node_id}/status — Quick online/offline probe
POST /nodes/{node_id}/dispatch — Dispatch a task to a connected node
WS /nodes/{node_id}/stream — Live event stream for ONE node
WS /nodes/stream/all?user_id=... — Live event stream for ALL of the user's nodes
"""
import asyncio
import json
import logging
import queue
import sys
import uuid
from datetime import datetime, timezone

from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Depends
from sqlalchemy.orm import Session

from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.db import models
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
# Seconds between periodic heartbeat frames pushed to WebSocket clients.
HEARTBEAT_INTERVAL_S = 5
def create_nodes_router(
    services: ServiceContainer,
    *,
    event_poll_interval_s: float = 0.1,
) -> APIRouter:
    """Build the ``/nodes`` router exposing the node registry and event bus.

    Args:
        services: Container whose ``node_registry_service`` is looked up on
            every request (not captured once at router-build time).
        event_poll_interval_s: How often, in seconds, the WebSocket handlers
            check their subscription queue for new events. Events reach the
            client within roughly this delay; heartbeat frames are still sent
            only every ``HEARTBEAT_INTERVAL_S`` seconds.

    Returns:
        An ``APIRouter`` with prefix ``/nodes``.

    Raises:
        ValueError: if ``event_poll_interval_s`` is not positive (a zero
            interval would turn the WebSocket loops into busy-waits).
    """
    if event_poll_interval_s <= 0:
        raise ValueError("event_poll_interval_s must be positive")

    router = APIRouter(prefix="/nodes", tags=["Agent Nodes"])

    # Location of the optional gRPC agent package (protos + signing helpers).
    grpc_agent_path = "/app/poc-grpc-agent"

    def _registry():
        """Return the node registry service from the container."""
        return services.node_registry_service

    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same format as the deprecated
        ``datetime.utcnow().isoformat()`` (no ``+00:00`` suffix), so the
        timestamps seen by WebSocket clients are unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _load_grpc_modules():
        """Import the gRPC agent protos and signer, extending sys.path once.

        Returns:
            ``(agent_pb2, sign_payload)``.

        Raises:
            ImportError: if the poc-grpc-agent package is not available.
        """
        # Only insert once; inserting on every dispatch grew sys.path unboundedly.
        if grpc_agent_path not in sys.path:
            sys.path.insert(0, grpc_agent_path)
        from protos import agent_pb2
        from orchestrator.utils.crypto import sign_payload
        return agent_pb2, sign_payload

    async def _pump_events(websocket, q, build_heartbeat):
        """Forward queued events to ``websocket`` and send periodic heartbeats.

        Drains ``q`` every ``event_poll_interval_s`` seconds, so events are
        delivered promptly rather than once per heartbeat. Sends the frame
        returned by ``build_heartbeat()`` every ``HEARTBEAT_INTERVAL_S``
        seconds. Loops until a send raises (e.g. ``WebSocketDisconnect``);
        the caller handles that exception and unsubscribes ``q``.
        """
        loop = asyncio.get_running_loop()
        next_heartbeat_at = loop.time() + HEARTBEAT_INTERVAL_S
        while True:
            # Drain everything queued so far; keep the try around the get only.
            while True:
                try:
                    event = q.get_nowait()
                except queue.Empty:
                    break
                await websocket.send_json(event)
            if loop.time() >= next_heartbeat_at:
                await websocket.send_json(build_heartbeat())
                next_heartbeat_at = loop.time() + HEARTBEAT_INTERVAL_S
            await asyncio.sleep(event_poll_interval_s)

    # ------------------------------------------------------------------ #
    # GET /nodes — list all nodes for a user                              #
    # ------------------------------------------------------------------ #
    @router.get("/", response_model=list[schemas.AgentNodeSummary], summary="List Agent Nodes")
    def list_nodes(user_id: str, db: Session = Depends(get_db)):
        """
        Returns all agent nodes ever registered under a given user.
        Merges live connection state (in-memory) with the persistent DB record:
        nodes not in the registry are reported "offline" with the DB's
        last_seen_at.
        """
        registry = _registry()
        db_nodes = db.query(models.AgentNode).filter(
            models.AgentNode.user_id == user_id
        ).all()
        result = []
        for db_node in db_nodes:
            live = registry.get_node(db_node.node_id)
            status = live._compute_status() if live else "offline"
            last_seen = live.last_heartbeat_at if live else db_node.last_seen_at
            result.append(schemas.AgentNodeSummary(
                node_id=db_node.node_id,
                user_id=db_node.user_id,
                description=db_node.description,
                capabilities=db_node.capabilities or {},
                status=status,
                last_seen_at=last_seen,
                created_at=db_node.created_at,
            ))
        return result

    # ------------------------------------------------------------------ #
    # GET /nodes/{node_id} — full live status                             #
    # ------------------------------------------------------------------ #
    @router.get("/{node_id}", response_model=schemas.AgentNodeStatusResponse, summary="Get Node Live Status")
    def get_node_status(node_id: str, db: Session = Depends(get_db)):
        """Return the live status of a node, or its DB record marked offline.

        Raises:
            HTTPException(404): if the node is neither connected nor in the DB.
        """
        registry = _registry()
        live = registry.get_node(node_id)
        if live:
            d = live.to_dict()
            return schemas.AgentNodeStatusResponse(
                node_id=d["node_id"], user_id=d["user_id"],
                description=d["description"], capabilities=d["capabilities"],
                stats=schemas.AgentNodeStats(**d["stats"]),
                status=d["status"],
                connected_at=d["connected_at"],
                last_heartbeat_at=d["last_heartbeat_at"],
            )
        db_node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first()
        if not db_node:
            raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.")
        return schemas.AgentNodeStatusResponse(
            node_id=db_node.node_id, user_id=db_node.user_id,
            description=db_node.description, capabilities=db_node.capabilities or {},
            stats=schemas.AgentNodeStats(), status="offline",
        )

    # ------------------------------------------------------------------ #
    # GET /nodes/{node_id}/status — quick probe                           #
    # ------------------------------------------------------------------ #
    @router.get("/{node_id}/status", summary="Quick Node Online Check")
    def get_node_online_status(node_id: str):
        """Return ``{"node_id", "status"}`` plus ``stats`` when the node is connected."""
        live = _registry().get_node(node_id)
        if not live:
            return {"node_id": node_id, "status": "offline"}
        return {"node_id": node_id, "status": live._compute_status(), "stats": live.stats}

    # ------------------------------------------------------------------ #
    # POST /nodes/{node_id}/dispatch — send a task                        #
    # ------------------------------------------------------------------ #
    @router.post("/{node_id}/dispatch", response_model=schemas.NodeDispatchResponse, summary="Dispatch Task to Node")
    def dispatch_to_node(node_id: str, request: schemas.NodeDispatchRequest):
        """
        Queues a task for an online node via its gRPC outbound queue.
        Emits task_assigned event immediately for live UI feedback, then
        task_start once the signed request has been queued.

        If the poc-grpc-agent modules cannot be imported, nothing is queued
        (stub mode) but the task is still reported as "accepted".

        Raises:
            HTTPException(503): if the node is not connected.
            HTTPException(400): if neither ``command`` nor ``browser_action``
                is provided (previously this dispatched the payload "null").
        """
        registry = _registry()
        live = registry.get_node(node_id)
        if not live:
            raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.")
        if not request.command and request.browser_action is None:
            raise HTTPException(
                status_code=400,
                detail="Either 'command' or 'browser_action' must be provided.",
            )
        task_id = str(uuid.uuid4())
        # Emit to live UI immediately
        registry.emit(node_id, "task_assigned",
                      {"command": request.command, "session_id": request.session_id},
                      task_id=task_id)
        # Only the optional-module import is guarded; errors while signing or
        # enqueueing are real failures and must not be reported as "not installed".
        try:
            agent_pb2, sign_payload = _load_grpc_modules()
        except ImportError as e:
            logger.warning("[nodes] poc-grpc-agent not installed: %s. Dispatch is stub only.", e)
            return schemas.NodeDispatchResponse(task_id=task_id, status="accepted")

        payload = request.command or json.dumps(request.browser_action)
        task_req = agent_pb2.TaskRequest(
            task_id=task_id,
            payload_json=payload,
            signature=sign_payload(payload),
            timeout_ms=request.timeout_ms,
            session_id=request.session_id or "",
        )
        live.queue.put(agent_pb2.ServerTaskMessage(task_request=task_req))
        registry.emit(node_id, "task_start", {"command": request.command}, task_id=task_id)
        logger.info("[nodes] Dispatched task %s to %s", task_id, node_id)
        return schemas.NodeDispatchResponse(task_id=task_id, status="accepted")

    # ------------------------------------------------------------------ #
    # WS /nodes/{node_id}/stream — single node live event stream          #
    # ------------------------------------------------------------------ #
    @router.websocket("/{node_id}/stream")
    async def node_event_stream(websocket: WebSocket, node_id: str):
        """
        WebSocket stream for a single node's execution events.
        Powers the single-node execution pane in the UI.
        Sends a "snapshot" frame first, then forwards events as they arrive
        and a "heartbeat" frame every HEARTBEAT_INTERVAL_S seconds.
        Message format:
        {
            "event": "task_stdout",
            "label": "📤 Output",
            "node_id": "node-alpha",
            "task_id": "abc-123",
            "timestamp": "2026-03-04T06:00:00",
            "data": { ... }
        }
        """
        await websocket.accept()
        registry = _registry()
        # Subscribe BEFORE taking the snapshot so events emitted in between are
        # not lost (at worst they are delivered after, and reflected in, the snapshot).
        q: queue.Queue = queue.Queue()
        registry.subscribe_node(node_id, q)

        def build_heartbeat() -> dict:
            live_now = registry.get_node(node_id)
            return {
                "event": "heartbeat",
                "node_id": node_id,
                "timestamp": _utc_timestamp(),
                "data": {"status": live_now._compute_status() if live_now else "offline",
                         "stats": live_now.stats if live_now else {}},
            }

        try:
            live = registry.get_node(node_id)
            await websocket.send_json({
                "event": "snapshot",
                "node_id": node_id,
                "timestamp": _utc_timestamp(),
                "data": live.to_dict() if live else {"status": "offline"},
            })
            await _pump_events(websocket, q, build_heartbeat)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            logger.exception("[nodes/stream] %s: %s", node_id, e)
        finally:
            registry.unsubscribe_node(node_id, q)

    # ------------------------------------------------------------------ #
    # WS /nodes/stream/all — multi-node global execution bus              #
    # ------------------------------------------------------------------ #
    @router.websocket("/stream/all")
    async def all_nodes_event_stream(websocket: WebSocket, user_id: str):
        """
        WebSocket stream for ALL of a user's node execution events combined.
        Powers the multi-pane split-window execution UI.
        The client receives events from every node the user owns,
        disambiguated by the 'node_id' field in each event.
        Sends an "initial_snapshot" frame first, then forwards events as they
        arrive and a "mesh_heartbeat" frame every HEARTBEAT_INTERVAL_S seconds.
        Use this to render:
        - Per-node status columns (split by node_id)
        - A unified chronological execution log
        - Error/retry surface for all nodes simultaneously
        """
        await websocket.accept()
        registry = _registry()
        # Subscribe before the snapshot so no event falls into the gap.
        q: queue.Queue = queue.Queue()
        registry.subscribe_user(user_id, q)

        def build_heartbeat() -> dict:
            # Periodic mesh health summary of the user's currently live nodes.
            return {
                "event": "mesh_heartbeat",
                "user_id": user_id,
                "timestamp": _utc_timestamp(),
                "data": {
                    "nodes": [
                        {"node_id": n.node_id, "status": n._compute_status(), "stats": n.stats}
                        for n in registry.list_nodes(user_id=user_id)
                    ]
                },
            }

        try:
            all_nodes = registry.list_nodes(user_id=user_id)
            await websocket.send_json({
                "event": "initial_snapshot",
                "user_id": user_id,
                "timestamp": _utc_timestamp(),
                "data": {
                    "nodes": [n.to_dict() for n in all_nodes],
                    "count": len(all_nodes),
                },
            })
            await _pump_events(websocket, q, build_heartbeat)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            logger.exception("[nodes/stream/all] user=%s: %s", user_id, e)
        finally:
            registry.unsubscribe_user(user_id, q)

    return router