"""
MCP (Model Context Protocol) Server Route — Anthropic / SSE Transport

Endpoints registered under /api/v1/mcp/*:
  GET  /mcp/sse       - Establish persistent SSE stream (per MCP spec)
  POST /mcp/messages  - Receive JSON-RPC 2.0 messages from client

Discovery:
  GET  /.well-known/mcp/manifest.json  (registered on root app in app.py)

Protocol flow:
  1. Client GETs /mcp/sse — server streams an `endpoint` event with the /messages URL
  2. Client POSTs JSON-RPC requests to /mcp/messages?session_id=<id>
  3. Server dispatches the tool, pushes JSON-RPC response back over the SSE stream
"""

import asyncio
import json
import uuid
import logging
from typing import Optional, AsyncIterator

from fastapi import APIRouter, Request, HTTPException, Query
from fastapi.responses import JSONResponse, StreamingResponse

from app.api.dependencies import ServiceContainer
from app.config import settings

logger = logging.getLogger(__name__)

# ─── In-process SSE session registry ─────────────────────────────────────────
# session_id (a uuid4 string) → asyncio.Queue of JSON-serializable dicts that
# the session's SSE stream drains and emits as `message` events.
# NOTE: this dict lives in the current process only — a POST to /messages that
# reaches a different worker process than the one holding the stream will not
# find its session.
_sse_sessions: dict[str, asyncio.Queue] = dict()


def create_mcp_router(services: ServiceContainer) -> APIRouter:
    """Build the router implementing the MCP SSE transport.

    Args:
        services: Service container forwarded to every dispatched JSON-RPC call.

    Returns:
        An ``APIRouter`` exposing ``GET /sse`` and ``POST /messages``. The
        messages URL advertised to clients hard-codes the ``/api/v1/mcp``
        prefix, so the router must be mounted there.
    """
    from urllib.parse import urlencode

    router = APIRouter(tags=["MCP"])

    # asyncio only keeps weak references to tasks, so a fire-and-forget
    # dispatch task could be garbage-collected before it finishes. Hold a
    # strong reference until each task completes; the task removes itself.
    background_tasks: set[asyncio.Task] = set()

    # ─── SSE Transport — Client Connection ────────────────────────────────────
    @router.get("/sse")
    async def mcp_sse(
        request: Request,
        token: Optional[str] = Query(None, description="Optional user token (X-User-ID)"),
    ):
        """
        Establishes a Server-Sent Events (SSE) stream for an MCP client.

        Per the MCP spec the first event MUST be `endpoint`, whose data is the
        URL the client should POST messages to. Responses queued for this
        session are then streamed as `message` events, with a `: keepalive`
        comment after every 25 s without traffic.
        """
        session_id = str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue()
        _sse_sessions[session_id] = queue

        # Build the absolute messages URL from the request's base URL. The
        # query string is URL-encoded so a token containing '&', '=', '#' or a
        # newline can neither corrupt the URL nor inject extra SSE lines.
        query = {"session_id": session_id}
        if token:
            query["token"] = token
        base = str(request.base_url).rstrip("/")
        messages_url = f"{base}/api/v1/mcp/messages?{urlencode(query)}"

        logger.info("[MCP] New SSE session: %s", session_id)

        async def event_generator() -> AsyncIterator[str]:
            try:
                # Required first event — tells the client where to POST
                # messages. It sits inside the try so the session is cleaned
                # up even if the stream is closed right at this yield.
                yield f"event: endpoint\ndata: {messages_url}\n\n"
                while True:
                    if await request.is_disconnected():
                        logger.info("[MCP] Client disconnected: %s", session_id)
                        break
                    try:
                        msg = await asyncio.wait_for(queue.get(), timeout=25.0)
                    except asyncio.TimeoutError:
                        yield ": keepalive\n\n"
                    else:
                        yield f"event: message\ndata: {json.dumps(msg)}\n\n"
            finally:
                _sse_sessions.pop(session_id, None)
                logger.info("[MCP] Session cleaned up: %s", session_id)

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",   # Disable Nginx buffering for SSE
                "Access-Control-Allow-Origin": "*",
            },
        )

    # ─── SSE Transport — Message Handler ─────────────────────────────────────
    @router.post("/messages")
    async def mcp_messages(
        request: Request,
        session_id: str = Query(...),
        token: Optional[str] = Query(None),
    ):
        """
        Receives a JSON-RPC 2.0 message from an MCP client.
        The response is pushed asynchronously back over the SSE stream.
        Returns 202 Accepted immediately so the client doesn't time out.

        Raises 404 for an unknown/expired session and 400 for a body that is
        not valid JSON or not a single JSON-RPC object.
        """
        queue = _sse_sessions.get(session_id)
        if queue is None:
            raise HTTPException(status_code=404, detail="MCP session not found or expired.")

        try:
            body = await request.json()
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid JSON body.") from None

        # Only single JSON-RPC objects are supported. A batch (JSON array) or
        # bare scalar would otherwise crash on `.get` below and surface as 500.
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Expected a single JSON-RPC request object.")

        rpc_id = body.get("id")
        method = body.get("method", "")
        params = body.get("params", {})

        logger.info("[MCP] [%s] → %s", session_id[:8], method)

        # Dispatch asynchronously — don't block the HTTP response.
        task = asyncio.create_task(_dispatch(queue, rpc_id, method, params, token, services))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

        return JSONResponse(
            {"status": "accepted"},
            status_code=202,
            headers={"Access-Control-Allow-Origin": "*"},
        )

    return router


# ─── Dispatcher ───────────────────────────────────────────────────────────────

async def _dispatch(
    queue: asyncio.Queue,
    rpc_id,
    method: str,
    params: dict,
    token: Optional[str],
    services: ServiceContainer,
):
    """Run one JSON-RPC message and push its response onto the session's SSE queue.

    Args:
        queue: The session's outbound queue; the SSE stream serializes whatever
            is put here and sends it as a `message` event.
        rpc_id: The message's ``id``. ``None`` (absent or null) marks a JSON-RPC
            notification (the MCP spec forbids null request ids). Notifications
            must not receive any response, successful or not.
        method: JSON-RPC method name, forwarded to ``_execute``.
        params: JSON-RPC params, forwarded to ``_execute``.
        token: Optional caller token, forwarded to ``_execute``.
        services: Service container, forwarded to ``_execute``.

    Any exception raised by the method is reported to the client as a JSON-RPC
    error object with code -32000 and the exception text as its message.
    """
    is_notification = rpc_id is None
    try:
        result = await _execute(method, params, token, services)
    except Exception as exc:
        if is_notification:
            # e.g. `notifications/initialized`, which clients send after
            # `initialize`: nothing handles it and there is nobody to reply to.
            logger.debug("[MCP] Notification '%s' not handled: %s", method, exc)
            return
        logger.exception("[MCP] Tool error for '%s': %s", method, exc)
        response = {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "error": {"code": -32000, "message": str(exc)},
        }
    else:
        if is_notification:
            return
        response = {"jsonrpc": "2.0", "id": rpc_id, "result": result}
    await queue.put(response)


async def _execute(method: str, params: dict, token: Optional[str], services: ServiceContainer):
    """Route a JSON-RPC method to its implementation and return its ``result``.

    Supported methods: ``initialize``, ``ping``, ``tools/list``, ``tools/call``.

    Args:
        method: JSON-RPC method name.
        params: The message's ``params``; ``None`` is treated as ``{}``.
        token: Optional caller token, forwarded to tool calls.
        services: Service container, forwarded to tool calls.

    Raises:
        ValueError: for an unknown method, or when ``tools/call`` params or
            ``arguments`` are not JSON objects. Errors raised by the tool
            itself propagate from ``_call_tool``.
    """
    # JSON-RPC lets `params` be omitted; treat an explicit null the same way.
    if params is None:
        params = {}

    # ── MCP Handshake ─────────────────────────────────────────────────────────
    if method == "initialize":
        return {
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "Cortex Hub", "version": "1.0.0"},
        }

    if method == "ping":
        return {}

    # ── Tool Discovery ────────────────────────────────────────────────────────
    if method == "tools/list":
        return {
            "tools": [
                _tool_def("list_nodes",
                          "List all agent nodes in the Cortex swarm mesh and their status.",
                          {}),
                _tool_def("get_app_info",
                          "Get metadata about this Cortex Hub instance.",
                          {}),
                _tool_def("get_node_details",
                          "Get full details for a specific agent node.",
                          {"node_id": {"type": "string", "description": "Unique node ID"}},
                          required=["node_id"]),
                _tool_def("list_agents",
                          "List all autonomous agents configured in the system.",
                          {}),
                _tool_def("list_skills",
                          "List all skill folders (tool libraries) registered in the system.",
                          {}),
            ]
        }

    # ── Tool Execution ────────────────────────────────────────────────────────
    if method == "tools/call":
        if not isinstance(params, dict):
            raise ValueError("tools/call params must be a JSON object.")
        name = params.get("name", "")
        # `arguments` is optional: accept an explicit null as "no arguments"
        # instead of failing later with an opaque AttributeError.
        args = params.get("arguments")
        if args is None:
            args = {}
        elif not isinstance(args, dict):
            raise ValueError("Tool arguments must be a JSON object.")
        return await _call_tool(name, args, token, services)

    raise ValueError(f"Unknown method: '{method}'")


def _tool_def(name: str, description: str, properties: dict, required: list = None) -> dict:
    schema = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return {"name": name, "description": description, "inputSchema": schema}


# ─── Tool Implementations ─────────────────────────────────────────────────────

async def _call_tool(name: str, args: dict, token: Optional[str], services: ServiceContainer) -> dict:
    """Execute a named tool and return a standard MCP content block.

    Args:
        name: Tool name as advertised by ``tools/list``.
        args: The ``arguments`` object of ``tools/call``; ``None`` is treated as ``{}``.
        token: Optional caller token (not used by any current tool).
        services: Service container (not used by any current tool).

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``, where the text is the
        tool's output as pretty-printed JSON.

    Raises:
        ValueError: unknown tool, non-object ``args``, missing ``node_id``, or
            unknown node.
    """
    query = _TOOL_QUERIES.get(name)
    if query is None:
        raise ValueError(f"Unknown tool: '{name}'")
    if args is None:
        args = {}
    elif not isinstance(args, dict):
        raise ValueError("Tool arguments must be a JSON object.")

    # The queries are blocking ORM calls; run them in the default thread pool
    # so they don't block the event loop.
    loop = asyncio.get_running_loop()
    data = await loop.run_in_executor(None, query, args)
    return _text_content(data)


def _text_content(data) -> dict:
    """Wrap tool output in a single MCP text content block.

    Strings pass through unchanged. Anything else is rendered as indented JSON,
    with ``str()`` as the fallback for values ``json`` cannot encode.
    """
    text = data if isinstance(data, str) else json.dumps(data, indent=2, default=str)
    return {"content": [{"type": "text", "text": text}]}


def _db_api():
    """Return ``(get_db_session, models)``, imported on first use.

    The imports are kept function-local, like the per-tool imports this
    replaces (presumably to avoid import-time cycles or DB setup — TODO confirm).
    """
    from app.db.session import get_db_session
    from app.db import models
    return get_db_session, models


def _query_list_nodes(args: dict) -> dict:
    """Summarize every AgentNode row: id, name, status, OS and active flag."""
    get_db_session, models = _db_api()
    with get_db_session() as db:
        rows = db.query(models.AgentNode).all()
        return {
            "nodes": [
                {
                    "id": n.node_id,
                    "name": n.display_name,
                    "status": n.last_status,
                    "os": (n.capabilities or {}).get("os"),
                    "is_active": n.is_active,
                }
                for n in rows
            ]
        }


def _query_app_info(args: dict) -> dict:
    """Return static hub metadata plus total/online AgentNode counts."""
    get_db_session, models = _db_api()
    with get_db_session() as db:
        total = db.query(models.AgentNode).count()
        online = (
            db.query(models.AgentNode)
            .filter(models.AgentNode.last_status == "online")
            .count()
        )
    return {
        "name": "Cortex Hub",
        "version": "1.0.0",
        "capabilities": ["swarms", "webmcp", "mcp-sse", "voice-chat", "rag"],
        "nodes": {"total": total, "online": online},
        "mcp_transport": "sse",
        "sse_endpoint": f"{settings.HUB_PUBLIC_URL}/api/v1/mcp/sse",
    }


def _query_node_details(args: dict) -> dict:
    """Return the full record of the AgentNode whose id is ``args["node_id"]``.

    Raises:
        ValueError: if ``node_id`` is missing/empty or no such node exists.
    """
    node_id = args.get("node_id")
    if not node_id:
        raise ValueError("node_id is required.")
    get_db_session, models = _db_api()
    with get_db_session() as db:
        n = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first()
        details = None if n is None else {
            "node_id": n.node_id,
            "display_name": n.display_name,
            "description": n.description,
            "status": n.last_status,
            "is_active": n.is_active,
            "capabilities": n.capabilities,
            "skill_config": n.skill_config,
            "registered_by": n.registered_by,
            "last_seen_at": str(n.last_seen_at) if n.last_seen_at else None,
        }
    # Raise only after the session context has exited normally.
    if details is None:
        raise ValueError(f"Node '{node_id}' not found.")
    return details


def _query_list_agents(args: dict) -> dict:
    """Summarize every AgentInstance row, including its template name and run stats."""
    get_db_session, models = _db_api()
    with get_db_session() as db:
        rows = db.query(models.AgentInstance).all()
        # Built inside the session because `a.template` is a related object.
        return {
            "agents": [
                {
                    "id": str(a.id),
                    "name": a.template.name if a.template else None,
                    "status": a.status,
                    "node": a.mesh_node_id,
                    "last_heartbeat": str(a.last_heartbeat) if a.last_heartbeat else None,
                    "total_runs": a.total_runs,
                    "quality_score": a.latest_quality_score,
                }
                for a in rows
            ]
        }


def _query_list_skills(args: dict) -> dict:
    """List enabled Skill rows (id, name, description, type)."""
    get_db_session, models = _db_api()
    with get_db_session() as db:
        # `== True` builds a SQL column comparison; it is not a Python truth test.
        rows = db.query(models.Skill).filter(models.Skill.is_enabled == True).all()  # noqa: E712
        return {
            "skills": [
                {
                    "id": s.id,
                    "name": s.name,
                    "description": s.description,
                    "type": s.skill_type,
                }
                for s in rows
            ]
        }


# Tool name → blocking query implementing it (takes the tool's arguments dict).
# Keep in sync with the `tools/list` catalog.
_TOOL_QUERIES = {
    "list_nodes": _query_list_nodes,
    "get_app_info": _query_app_info,
    "get_node_details": _query_node_details,
    "list_agents": _query_list_agents,
    "list_skills": _query_list_skills,
}
