"""
MCP (Model Context Protocol) Server Route — Streamable HTTP + Legacy SSE Transport
Supports:
MCP spec 2025-11-25 — Streamable HTTP (primary, recommended)
MCP spec 2024-11-05 — HTTP+SSE (legacy, backwards-compat)
Endpoints (router mounted at /api/v1/mcp; paths below are shown relative to /api/v1):
POST /mcp/sse — Streamable HTTP: JSON-RPC in, JSON response out
POST /mcp/ — Same, aliased for clients using the base path
GET /mcp/sse — Legacy SSE stream (sends endpoint event)
POST /mcp/messages — Legacy SSE message handler
Discovery:
GET /.well-known/mcp/manifest.json (mounted in app.py)
"""
import asyncio
import json
import uuid
import logging
from typing import Optional, AsyncIterator
from fastapi import APIRouter, Request, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse
from app.api.dependencies import ServiceContainer
from app.config import settings
logger = logging.getLogger(__name__)
MCP_VERSION = "2025-11-25" # Latest MCP specification version
# ─── In-process SSE session registry ─────────────────────────────────────────
# Maps session_id → asyncio.Queue of JSON-serializable dicts
_sse_sessions: dict[str, asyncio.Queue] = {}
def create_mcp_router(services: ServiceContainer) -> APIRouter:
    """Build the MCP router serving Streamable HTTP and legacy SSE transports.

    Routes (relative to wherever the router is mounted):
      * ``GET  /sse``            — legacy SSE stream (MCP 2024-11-05)
      * ``POST /sse``, ``POST /`` — Streamable HTTP (MCP 2025-11-25)
      * ``POST /messages``       — legacy SSE message sink

    Args:
        services: Service container handed to every JSON-RPC handler.

    Returns:
        A configured ``APIRouter`` tagged ``"MCP"``. The legacy ``endpoint``
        event hard-codes the ``/api/v1/mcp`` prefix, so the router is expected
        to be mounted there.
    """
    # Function-scope imports keep this block self-contained.
    from urllib.parse import quote

    from fastapi.responses import Response

    router = APIRouter(tags=["MCP"])

    # The event loop keeps only weak references to tasks, so fire-and-forget
    # dispatches are held here until they finish (see asyncio.create_task docs).
    background_tasks: set[asyncio.Task] = set()

    # Browser origins always allowed on the Streamable HTTP endpoint. Matched
    # exactly: a prefix test would also admit look-alikes such as
    # "https://ai.jerxie.com.attacker.example".
    static_allowed_origins = frozenset({
        "https://ai.jerxie.com",
        "http://localhost:3000",
        "http://localhost:8080",
    })

    def _origin_allowed(request: Request, origin: str) -> bool:
        """Return True if *origin* is a static allowed origin or the server's own."""
        if origin in static_allowed_origins:
            return True
        server_host = request.headers.get("host", "")
        # Without a Host header there is no "own origin" to compare against.
        if not server_host:
            return False
        return origin in (f"https://{server_host}", f"http://{server_host}")

    def _accepted() -> Response:
        """202 Accepted with an empty body (the reply for notification-only input)."""
        return Response(status_code=202, headers={"Access-Control-Allow-Origin": "*"})

    # ─── SSE Transport — Client Connection ────────────────────────────────────
    @router.get("/sse")
    async def mcp_sse(
        request: Request,
        token: Optional[str] = Query(None, description="Optional user token (X-User-ID)"),
    ):
        """
        Legacy SSE transport (MCP 2024-11-05).
        Opens a persistent SSE stream; first event is `endpoint` telling the
        client where to POST messages.
        """
        session_id = str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue()
        _sse_sessions[session_id] = queue

        base = str(request.base_url).rstrip("/")
        messages_url = f"{base}/api/v1/mcp/messages?session_id={session_id}"
        if token:
            # Percent-encode so '&', '#', spaces, etc. in the token cannot
            # corrupt the query string the client will POST to.
            messages_url += f"&token={quote(token, safe='')}"

        logger.info(f"[MCP] New SSE session: {session_id}")

        async def event_generator() -> AsyncIterator[str]:
            try:
                # Inside the try so the session is unregistered even if the
                # client disconnects while the first event is being sent.
                yield f"event: endpoint\ndata: {messages_url}\n\n"
                while True:
                    if await request.is_disconnected():
                        break
                    try:
                        msg = await asyncio.wait_for(queue.get(), timeout=25.0)
                        yield f"event: message\ndata: {json.dumps(msg)}\n\n"
                    except asyncio.TimeoutError:
                        # SSE comment line sent when the queue is idle for 25s.
                        yield ": keepalive\n\n"
            finally:
                # Later POSTs to /messages for this session will get a 404.
                _sse_sessions.pop(session_id, None)

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
                "Access-Control-Allow-Origin": "*",
            },
        )

    # ─── Streamable HTTP Transport (MCP 2025-11-25) ───────────────────────────
    @router.post("/sse")
    @router.post("/")
    async def mcp_streamable_http(
        request: Request,
        token: Optional[str] = Query(None),
    ):
        """
        Streamable HTTP transport (MCP 2025-11-25 / 2025-03-26).
        Client POSTs JSON-RPC and receives the response synchronously.
        """
        # Origin validation — MUST per MCP 2025-11-25 security spec.
        origin = request.headers.get("origin")
        if origin and not _origin_allowed(request, origin):
            logger.warning(f"[MCP] Rejected request from disallowed origin: {origin}")
            return JSONResponse(
                {"jsonrpc": "2.0", "error": {"code": -32000, "message": "Forbidden origin"}},
                status_code=403,
            )

        try:
            body = await request.json()
        except ValueError:  # JSONDecodeError and UnicodeDecodeError subclass ValueError
            raise HTTPException(status_code=400, detail="Invalid JSON body.")

        headers = {"Access-Control-Allow-Origin": "*", "MCP-Protocol-Version": MCP_VERSION}

        # Batch requests (JSON array)
        if isinstance(body, list):
            if not body:
                # JSON-RPC 2.0: an empty batch is itself an Invalid Request.
                return JSONResponse(
                    {
                        "jsonrpc": "2.0",
                        "id": None,
                        "error": {"code": -32600, "message": "Invalid Request"},
                    },
                    status_code=400,
                    headers=headers,
                )
            results = [await _handle_single(item, token, services) for item in body]
            replies = [r for r in results if r is not None]
            if not replies:
                # Only notifications in the batch: nothing to answer.
                return _accepted()
            return JSONResponse(replies, headers=headers)

        # Single message
        response = await _handle_single(body, token, services)
        if response is None:  # notification — no id
            return _accepted()

        # Attach a session ID to a successful initialize result (MAY per spec).
        if isinstance(body, dict) and body.get("method") == "initialize" and "result" in response:
            headers["Mcp-Session-Id"] = str(uuid.uuid4())
        return JSONResponse(response, headers=headers)

    # ─── SSE Transport — Message Handler ─────────────────────────────────────
    @router.post("/messages")
    async def mcp_messages(
        request: Request,
        session_id: str = Query(...),
        token: Optional[str] = Query(None),
    ):
        """
        Legacy SSE message handler — receives JSON-RPC 2.0 from a client that
        first established a GET /sse stream, then pushes results over that stream.
        """
        queue = _sse_sessions.get(session_id)
        if queue is None:
            raise HTTPException(status_code=404, detail="MCP session not found or expired.")

        try:
            body = await request.json()
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid JSON body.")
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Expected a single JSON-RPC object.")

        rpc_id = body.get("id")
        method = body.get("method", "")
        params = body.get("params") or {}
        logger.info(f"[MCP] [{session_id[:8]}] → {method}")

        # The result is delivered over the SSE stream, not this HTTP response.
        task = asyncio.create_task(_dispatch(queue, rpc_id, method, params, token, services))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

        return JSONResponse(
            {"status": "accepted"},
            status_code=202,
            headers={"Access-Control-Allow-Origin": "*"},
        )

    return router
# ─── Single-request handler (used by Streamable HTTP) ────────────────────────
async def _handle_single(body: dict, token: Optional[str], services: ServiceContainer):
    """Process one JSON-RPC message for the Streamable HTTP transport.

    Args:
        body: One decoded JSON-RPC message (a whole request body or a batch
            element). Anything that is not a JSON object is answered with a
            -32600 Invalid Request error.
        token: Optional caller token, forwarded to ``_execute``.
        services: Service container, forwarded to ``_execute``.

    Returns:
        A JSON-RPC response dict, or ``None`` for notifications (messages
        without an ``id``), which must never be answered.
    """
    if not isinstance(body, dict):
        # JSON-RPC 2.0: a malformed request object gets Invalid Request with a null id.
        return {
            "jsonrpc": "2.0",
            "id": None,
            "error": {"code": -32600, "message": "Invalid Request"},
        }

    rpc_id = body.get("id")  # None means it's a notification
    method = body.get("method", "")
    params = body.get("params") or {}
    logger.info("[MCP-HTTP] → %s", method)

    try:
        result = await _execute(method, params, token, services)
    except Exception as exc:
        # Deliberate catch-all: any failure becomes a JSON-RPC error, not an HTTP 500.
        logger.exception("[MCP-HTTP] Error for '%s': %s", method, exc)
        if rpc_id is None:
            return None
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "error": {"code": -32000, "message": str(exc)},
        }

    if rpc_id is None:
        return None  # notification — no response
    return {"jsonrpc": "2.0", "id": rpc_id, "result": result}
# ─── Dispatcher ───────────────────────────────────────────────────────────────
async def _dispatch(
    queue: asyncio.Queue,
    rpc_id,
    method: str,
    params: dict,
    token: Optional[str],
    services: ServiceContainer,
):
    """Run one legacy-SSE JSON-RPC message and push its response onto the SSE queue.

    Args:
        queue: Outbound queue of the client's SSE stream.
        rpc_id: JSON-RPC ``id``; ``None`` marks a notification.
        method: JSON-RPC method name.
        params: JSON-RPC ``params``.
        token: Optional caller token, forwarded to ``_execute``.
        services: Service container, forwarded to ``_execute``.

    Notifications are still executed (and failures logged), but per JSON-RPC
    2.0 no response is queued for them.
    """
    try:
        result = await _execute(method, params, token, services)
    except Exception as exc:
        # Deliberate catch-all: errors are reported to the client as JSON-RPC errors.
        logger.exception("[MCP] Tool error for '%s': %s", method, exc)
        response = {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "error": {"code": -32000, "message": str(exc)},
        }
    else:
        response = {"jsonrpc": "2.0", "id": rpc_id, "result": result}

    if rpc_id is None:
        return  # notification — the client expects no reply
    await queue.put(response)
async def _execute(method: str, params: dict, token: Optional[str], services: ServiceContainer):
    """Route a JSON-RPC method to its implementation and return its ``result``.

    Args:
        method: JSON-RPC method name.
        params: JSON-RPC ``params``. Anything that is not a dict (missing,
            ``null``, positional array) is treated as ``{}``.
        token: Optional caller token, forwarded to tool calls.
        services: Service container, forwarded to tool calls.

    Returns:
        The JSON-serializable ``result`` payload for the method.

    Raises:
        ValueError: for an unknown method; tool errors propagate from ``_call_tool``.
    """
    if not isinstance(params, dict):
        params = {}

    # ── MCP Handshake ─────────────────────────────────────────────────────────
    if method == "initialize":
        # Versions this module documents supporting. Per the MCP lifecycle
        # spec, echo the client's requested version when we support it;
        # otherwise answer with our latest and let the client decide.
        supported_versions = (MCP_VERSION, "2025-03-26", "2024-11-05")
        requested = params.get("protocolVersion")
        return {
            "protocolVersion": requested if requested in supported_versions else MCP_VERSION,
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "Cortex Hub", "version": "1.0.0"},
        }
    if method == "ping":
        return {}

    # Client notifications (e.g. "notifications/initialized", sent after every
    # handshake) need no work; acknowledging them keeps them out of the error log.
    if isinstance(method, str) and method.startswith("notifications/"):
        return {}

    # ── Tool Discovery ────────────────────────────────────────────────────────
    if method == "tools/list":
        return {
            "tools": [
                _tool_def("list_nodes",
                          "List all agent nodes in the Cortex swarm mesh and their status.",
                          {}),
                _tool_def("get_app_info",
                          "Get metadata about this Cortex Hub instance.",
                          {}),
                _tool_def("get_node_details",
                          "Get full details for a specific agent node.",
                          {"node_id": {"type": "string", "description": "Unique node ID"}},
                          required=["node_id"]),
                _tool_def("list_agents",
                          "List all autonomous agents configured in the system.",
                          {}),
                _tool_def("list_skills",
                          "List all skill folders (tool libraries) registered in the system.",
                          {}),
            ]
        }

    # ── Tool Execution ────────────────────────────────────────────────────────
    if method == "tools/call":
        name = params.get("name", "")
        args = params.get("arguments") or {}
        return await _call_tool(name, args, token, services)

    raise ValueError(f"Unknown method: '{method}'")
def _tool_def(name: str, description: str, properties: dict, required: list = None) -> dict:
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return {"name": name, "description": description, "inputSchema": schema}
# ─── Tool Implementations ─────────────────────────────────────────────────────
async def _call_tool(name: str, args: dict, token: Optional[str], services: ServiceContainer) -> dict:
    """Execute a named tool and return a standard MCP content block.

    Every tool is a blocking database query, so it runs in a worker thread via
    ``asyncio.to_thread`` to keep the event loop responsive.

    Args:
        name: Tool name as advertised by ``tools/list``.
        args: Tool arguments from ``tools/call``. Anything that is not a dict
            (e.g. JSON ``null``) is treated as no arguments.
        token: Optional caller token; not used by any tool at present.
        services: Service container; not used by any tool at present.

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``: the tool result as
        pretty-printed JSON text (string results are passed through verbatim).

    Raises:
        ValueError: unknown tool, missing or non-string ``node_id``, or no node
            with the requested ID.
    """
    if not isinstance(args, dict):
        args = {}

    def _ok(data) -> dict:
        # default=str stringifies values json cannot encode natively (e.g. datetimes).
        text = data if isinstance(data, str) else json.dumps(data, indent=2, default=str)
        return {"content": [{"type": "text", "text": text}]}

    def _in_session(query, *extra):
        """Run ``query(db, models, *extra)`` inside one DB session (blocking)."""
        from app.db.session import get_db_session
        from app.db import models

        with get_db_session() as db:
            # Results are flattened to plain dicts while the session is open;
            # ORM attribute access (e.g. a.template) presumably needs it — TODO confirm.
            return query(db, models, *extra)

    def _list_nodes(db, models) -> dict:
        rows = db.query(models.AgentNode).all()
        return {
            "nodes": [
                {
                    "id": n.node_id,
                    "name": n.display_name,
                    "status": n.last_status,
                    "os": (n.capabilities or {}).get("os"),
                    "is_active": n.is_active,
                }
                for n in rows
            ]
        }

    def _app_info(db, models) -> dict:
        total = db.query(models.AgentNode).count()
        online = db.query(models.AgentNode).filter(models.AgentNode.last_status == "online").count()
        return {
            "name": "Cortex Hub",
            "version": "1.0.0",
            "capabilities": ["swarms", "webmcp", "mcp-sse", "voice-chat", "rag"],
            "nodes": {"total": total, "online": online},
            "mcp_transport": "sse",
            "sse_endpoint": f"{settings.HUB_PUBLIC_URL}/api/v1/mcp/sse",
        }

    def _node_details(db, models, node_id: str) -> Optional[dict]:
        n = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first()
        if n is None:
            return None
        return {
            "node_id": n.node_id,
            "display_name": n.display_name,
            "description": n.description,
            "status": n.last_status,
            "is_active": n.is_active,
            "capabilities": n.capabilities,
            "skill_config": n.skill_config,
            "registered_by": n.registered_by,
            "last_seen_at": str(n.last_seen_at) if n.last_seen_at else None,
        }

    def _list_agents(db, models) -> dict:
        rows = db.query(models.AgentInstance).all()
        return {
            "agents": [
                {
                    "id": str(a.id),
                    "name": a.template.name if a.template else None,
                    "status": a.status,
                    "node": a.mesh_node_id,
                    "last_heartbeat": str(a.last_heartbeat) if a.last_heartbeat else None,
                    "total_runs": a.total_runs,
                    "quality_score": a.latest_quality_score,
                }
                for a in rows
            ]
        }

    def _list_skills(db, models) -> dict:
        # `== True` is intentional: the ORM column overloads == to build the SQL predicate.
        rows = db.query(models.Skill).filter(models.Skill.is_enabled == True).all()  # noqa: E712
        return {
            "skills": [
                {
                    "id": s.id,
                    "name": s.name,
                    "description": s.description,
                    "type": s.skill_type,
                }
                for s in rows
            ]
        }

    # The only tool with arguments: validate before touching the database.
    if name == "get_node_details":
        node_id = args.get("node_id")
        if not node_id:
            raise ValueError("node_id is required.")
        if not isinstance(node_id, str):
            raise ValueError("node_id must be a string.")
        details = await asyncio.to_thread(_in_session, _node_details, node_id)
        if details is None:
            raise ValueError(f"Node '{node_id}' not found.")
        return _ok(details)

    no_arg_tools = {
        "list_nodes": _list_nodes,
        "get_app_info": _app_info,
        "list_agents": _list_agents,
        "list_skills": _list_skills,
    }
    # isinstance guard: a non-string (possibly unhashable) JSON value is just an unknown tool.
    query = no_arg_tools.get(name) if isinstance(name, str) else None
    if query is None:
        raise ValueError(f"Unknown tool: '{name}'")
    return _ok(await asyncio.to_thread(_in_session, query))