"""
NodeRegistry Service — AI Hub Integration Layer
Manages live Agent Node registrations, their in-memory gRPC queues,
and a rich event bus for real-time UI streaming (split-pane terminal UX).

Persistence strategy (M2):
- The in-memory dict is the live connection cache (fast, gRPC-queue holder).
- The DB (via get_db_session) is the source of truth for node identity across reboots.
- On every register/deregister/heartbeat, we update the node's existing AgentNode DB record
  (no row is inserted here; an admin must pre-create the node).
- On Hub startup all DB nodes are "offline"; they go "online" when they reconnect.
"""
import threading
import queue
import time
import logging
import re
import uuid
import asyncio
from datetime import datetime
from typing import Dict, Optional, List, Any
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)


# Catalogue of event types emitted across the system, mapping each event name
# to the human-readable label shown in the live UI.
EVENT_TYPES = {
    # Node lifecycle
    "node_online": "🟢 Node Online",
    "node_offline": "⚫ Node Offline",
    "node_stale": "🟡 Node Stale",
    "heartbeat": "💓 Heartbeat",
    # Task execution
    "task_assigned": "📦 Task Assigned",
    "task_start": "🚀 Task Starting",
    "task_stdout": "📤 Output",
    "task_done": "✅ Task Done",
    "task_error": "❌ Task Error",
    "task_cancelled": "🛑 Task Cancelled",
    "task_retry": "🔄 Retry",
    # Workspace / file synchronisation
    "sync_manifest": "📋 Manifest Sync",
    "sync_push": "📁 File Push",
    "sync_drift": "⚠️ Drift Detected",
    "sync_recovery": "🏥 Recovery",
    "sync_locked": "🔒 Workspace Locked",
    "sync_unlocked": "🔓 Workspace Unlocked",
    # Miscellaneous
    "info": "ℹ️ Info",
}


class LiveNodeRecord:
    """Represents a single connected Agent Node and its associated live state.

    Holds the node's outbound priority queue, its most recent stats, the
    connection/heartbeat timestamps, and a buffer of recent terminal output.
    """

    # Seconds without a heartbeat after which the node is reported as "stale".
    STALE_AFTER_S = 30

    def __init__(self, node_id: str, user_id: str, metadata: dict):
        self.node_id = node_id
        self.user_id = user_id           # Owner — maps node to a Hub user
        self.metadata = metadata         # desc, caps (capabilities dict)
        # Increased queue size to 1000 to handle high-concurrency file sync without dropping interactive tasks
        self.queue: queue.PriorityQueue = queue.PriorityQueue(maxsize=1000)
        self.stats: dict = {
            "active_worker_count": 0,
            "cpu_usage_percent": 0.0,
            "memory_usage_percent": 0.0,
            "running": [],
        }
        self.connected_at: datetime = datetime.utcnow()
        self.last_heartbeat_at: datetime = datetime.utcnow()
        self.session_id: str = str(uuid.uuid4())
        self.terminal_history: List[str] = []  # Recent PTY lines for AI reading
        self._registry_executor = None  # Set by registry
        # Guards _last_stamp so every queued item gets a unique, increasing stamp.
        self._stamp_lock = threading.Lock()
        self._last_stamp: float = 0.0

    def _next_stamp(self) -> float:
        """Return a wall-clock timestamp that is strictly greater than the previous one.

        Queue entries are ``(priority, stamp, msg)`` tuples. If two entries tied
        on both priority and stamp, the heap would fall back to comparing the
        messages themselves, which raises ``TypeError`` for unorderable payloads.
        Unique stamps make that comparison unreachable and keep FIFO order
        within a priority level.
        """
        with self._stamp_lock:
            stamp = time.time()
            if stamp <= self._last_stamp:
                # Clock did not advance (or went backwards): nudge past the last stamp.
                stamp = self._last_stamp + 1e-6
            self._last_stamp = stamp
            return stamp

    def send_message(self, msg: Any, priority: int = 2):
        """
        Thread-safe and Async-safe message dispatcher.
        priority: 0 (Admin/Control), 1 (Terminal/Interactive), 2 (File Sync)

        The message is enqueued as ``(priority, timestamp, msg)``. When called
        from a thread with a running asyncio loop, the (potentially blocking)
        put is handed off to the registry executor or a daemon thread so the
        loop is never blocked; otherwise the put happens inline. If the queue
        stays full for 2 seconds the message is dropped and a warning is logged.
        """
        item = (priority, self._next_stamp(), msg)

        def _blocking_put():
            try:
                self.queue.put(item, block=True, timeout=2.0)
            except queue.Full:
                logger.warning(f"[📋⚠️] Message dropped for {self.node_id}: outbound queue FULL. Node may be unresponsive.")
            except Exception as e:
                logger.error(f"[📋❌] Sync error sending to {self.node_id}: {e}")

        try:
            # Raises RuntimeError when this thread has no running event loop.
            asyncio.get_running_loop()
        except RuntimeError:
            # Standard sync put (from gRPC thread)
            _blocking_put()
            return

        # We are inside an event loop (FastAPI context): never block it.
        if self._registry_executor:
            try:
                self._registry_executor.submit(_blocking_put)
                return
            except RuntimeError:
                # Executor no longer accepts work (e.g. shut down); use a thread instead.
                pass
        # Fallback to fire-and-forget thread if executor not yet ready
        threading.Thread(target=_blocking_put, daemon=True).start()

    def update_stats(self, stats: dict):
        """Merge a heartbeat's stats into the live stats and refresh the heartbeat time."""
        self.stats.update(stats)
        self.last_heartbeat_at = datetime.utcnow()

    def to_dict(self) -> dict:
        """Return a serialisable snapshot of this record.

        ``stats`` is a shallow copy, so a snapshot that was already handed out
        does not change when a later heartbeat updates the live stats.
        """
        return {
            "node_id": self.node_id,
            "user_id": self.user_id,
            "description": self.metadata.get("desc", ""),
            "capabilities": self.metadata.get("caps", {}),
            "stats": dict(self.stats),
            "connected_at": self.connected_at.isoformat(),
            "last_heartbeat_at": self.last_heartbeat_at.isoformat(),
            "status": self._compute_status(),
        }

    def _compute_status(self) -> str:
        """Return "stale" when the last heartbeat is older than STALE_AFTER_S, else "online"."""
        delta = (datetime.utcnow() - self.last_heartbeat_at).total_seconds()
        if delta > self.STALE_AFTER_S:
            return "stale"
        return "online"

    def is_healthy(self) -> bool:
        """True if the node has sent a heartbeat recently (i.e. its status is "online")."""
        return self._compute_status() == "online"


class NodeRegistryService:
    """
    Persistent + in-memory registry of live Agent Nodes.

    Two-tier storage:
      Tier 1 — In-memory (_nodes dict):  live connections, gRPC queues, real-time stats.
      Tier 2 — Database (AgentNode model): node identity, capabilities, invite_token,
                                           last_status, last_seen_at survive Hub restarts.

    When a node reconnects after a Hub restart it calls SyncConfiguration again,
    which calls register() again → the DB record is updated to 'online'.
    Nodes that haven't reconnected stay in the DB with 'offline' status so the UI
    can still show them as known (but disconnected) nodes.
    """

    # Matches ANSI escape sequences so they can be stripped from stored terminal output.
    _ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
    # Maximum characters kept from a single output chunk in terminal_history.
    _MAX_HISTORY_CHUNK_CHARS = 100_000
    # Maximum number of chunks kept in a node's rolling terminal_history.
    _MAX_HISTORY_CHUNKS = 150

    def __init__(self):
        self._nodes: Dict[str, "LiveNodeRecord"] = {}
        self._lock = threading.Lock()
        self._connection_history: Dict[str, List[datetime]] = {}
        # Per-node WS subscribers:  node_id  -> [queue, ...]
        self._node_listeners: Dict[str, List[queue.Queue]] = {}
        # Per-user WS subscribers:  str(user_id)  -> [queue, ...] (ALL nodes for that user)
        self._user_listeners: Dict[str, List[queue.Queue]] = {}
        self._FLAP_WINDOW_S = 60
        self._FLAP_THRESHOLD = 5
        # Shared Hub-wide work executor to prevent thread-spawning leaks
        self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix="RegistryWorker")

    # ------------------------------------------------------------------ #
    #  DB Helpers                                                          #
    # ------------------------------------------------------------------ #

    def _db_upsert_node(self, node_id: str, user_id: str, metadata: dict):
        """Update the AgentNode DB record on connect. The admin must pre-create the node.

        Only an existing row is updated (capabilities, status, last_seen_at);
        a missing row is logged as a warning and nothing is inserted.
        DB errors are logged and swallowed so a connect never fails on persistence.
        """
        from app.db.models import AgentNode
        from app.db.session import get_db_session
        try:
            with get_db_session() as db:
                record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first()
                if record:
                    record.capabilities = metadata.get("caps", {})
                    record.last_status = "online"
                    record.last_seen_at = datetime.utcnow()
                    db.commit()
                else:
                    # Node not pre-registered by admin — log warning but don't crash
                    logger.warning(f"[NodeRegistry] WARNING: Node '{node_id}' connected but has no DB record. Admin must register it first.")
        except Exception as e:
            logger.error(f"[NodeRegistry] DB upsert failed for {node_id}: {e}")

    def _db_mark_offline(self, node_id: str):
        """Update last_seen_at and mark last_status = 'offline' on disconnect."""
        from app.db.models import AgentNode
        from app.db.session import get_db_session
        try:
            with get_db_session() as db:
                record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first()
                if record:
                    record.last_status = "offline"
                    record.last_seen_at = datetime.utcnow()
                    db.commit()
        except Exception as e:
            logger.error(f"[NodeRegistry] DB mark-offline failed for {node_id}: {e}")

    def _db_update_heartbeat(self, node_id: str):
        """Bump last_seen_at on each heartbeat so we know when the node was last active."""
        from app.db.models import AgentNode
        from app.db.session import get_db_session
        try:
            with get_db_session() as db:
                record = db.query(AgentNode).filter(AgentNode.node_id == node_id).first()
                if record:
                    record.last_seen_at = datetime.utcnow()
                    record.last_status = "online"
                    db.commit()
        except Exception as e:
            logger.error(f"[NodeRegistry] DB heartbeat update failed for {node_id}: {e}")

    def reset_all_statuses(self):
        """Reset all nodes in the DB to offline (call on Hub startup)."""
        from app.db.models import AgentNode
        from app.db.session import get_db_session
        try:
            with get_db_session() as db:
                db.query(AgentNode).update({"last_status": "offline"})
                db.commit()
                logger.info("[NodeRegistry] Reset all DB node statuses to 'offline'.")
        except Exception as e:
            logger.error(f"[NodeRegistry] Failed to reset DB statuses: {e}")

    def clear_memory_cache(self):
        """DANGEROUS: Clears all live connections from memory. Use for emergency resets.

        Returns the number of live node records that were dropped.
        Subscriber lists are left untouched.
        """
        with self._lock:
            count = len(self._nodes)
            self._nodes.clear()
            self._connection_history.clear()
            logger.info(f"[NodeRegistry] EMERGENCY: Cleared {count} nodes from memory cache.")
        return count

    def validate_invite_token(self, node_id: str, token: str) -> dict:
        """
        Directly validates an invite token against the DB.
        Used by the gRPC server to avoid HTTP self-call deadlocks during startup.

        Returns ``{"valid": False, "reason": ...}`` when no active node matches
        (or the lookup fails), otherwise ``{"valid": True, ...}`` with the node's
        id, display name, owner and skill config.
        """
        from app.db.models import AgentNode
        from app.db.session import get_db_session
        try:
            with get_db_session() as db:
                node = db.query(AgentNode).filter(
                    AgentNode.node_id == node_id,
                    AgentNode.invite_token == token,
                    # `== True` is intentional: it builds a SQL expression, not a Python bool.
                    AgentNode.is_active == True,  # noqa: E712
                ).first()
                if not node:
                    return {"valid": False, "reason": "Invalid token or unknown node."}
                return {
                    "valid": True,
                    "node_id": node.node_id,
                    "display_name": node.display_name,
                    "user_id": node.registered_by,
                    "skill_config": node.skill_config or {},
                }
        except Exception as e:
            logger.error(f"[NodeRegistry] Token validation exception: {e}")
            return {"valid": False, "reason": str(e)}

    # ------------------------------------------------------------------ #
    #  Registration                                                        #
    # ------------------------------------------------------------------ #

    def register(self, node_id: str, user_id: str, metadata: dict) -> "LiveNodeRecord":
        """
        Register or re-register a node.
        Called from gRPC SyncConfiguration on every node connect/reconnect.
        Persists to DB so the node survives Hub restarts as a known entity.
        Includes a flapping detection check to warn about unstable containers.

        Any previous live record for the same node_id is replaced. Returns the
        new live record after emitting a "node_online" event.
        """
        with self._lock:
            # 1. Flapping Detection
            now = datetime.utcnow()
            history = self._connection_history.get(node_id, [])
            # Filter history to only include events within the window
            history = [t for t in history if (now - t).total_seconds() < self._FLAP_WINDOW_S]
            history.append(now)
            self._connection_history[node_id] = history

            if len(history) > self._FLAP_THRESHOLD:
                logger.warning(f"[⚠️] FLAPPING DETECTED for node '{node_id}': {len(history)} connects in {self._FLAP_WINDOW_S}s.")

            # 2. Register the live connection
            record = LiveNodeRecord(node_id=node_id, user_id=user_id, metadata=metadata)
            record._registry_executor = self.executor  # Inject shared executor
            self._nodes[node_id] = record

        # Persist to DB asynchronously to avoid blocking gRPC stream setup during NFS lag
        self.executor.submit(self._db_upsert_node, node_id, user_id, metadata)

        logger.info(f"[📋] NodeRegistry: Registered {node_id} (owner: {user_id}) | Stats enabled")
        self.emit(node_id, "node_online", record.to_dict())
        return record

    def deregister(self, node_id: str, record: Optional["LiveNodeRecord"] = None):
        """
        Remove a node from live_registry when its gRPC stream closes.
        Safely only removes if the passed record matches the current live registration.

        When ``record`` is given and a different record is currently registered
        (the node already reconnected), the call is ignored. Otherwise the node
        is removed, marked offline in the DB and a "node_offline" event is emitted.
        """
        with self._lock:
            current = self._nodes.get(node_id)
            # Identity check: a stale stream must not evict a newer registration.
            if record is not None and current is not record:
                logger.debug(f"[📋] NodeRegistry: Ignoring deregister for {node_id} (session mismatch)")
                return

            node = self._nodes.pop(node_id, None)
            user_id = node.user_id if node else None

        self.executor.submit(self._db_mark_offline, node_id)
        self.emit(node_id, "node_offline", {"node_id": node_id, "user_id": user_id})
        logger.info(f"[📋] NodeRegistry: Deregistered {node_id}")

    # ------------------------------------------------------------------ #
    #  Queries                                                             #
    # ------------------------------------------------------------------ #

    def get_node(self, node_id: str) -> Optional["LiveNodeRecord"]:
        """Returns a live record only if the node is currently connected."""
        with self._lock:
            return self._nodes.get(node_id)

    def list_nodes(self, user_id: Optional[str] = None) -> List["LiveNodeRecord"]:
        """Returns only currently LIVE nodes (use the DB for the full list)."""
        with self._lock:
            if user_id:
                # Use robust string comparison to handle any object/string mismatch (e.g. UUID)
                wanted = str(user_id)
                return [n for n in self._nodes.values() if str(n.user_id) == wanted]
            return list(self._nodes.values())

    def get_best(self, user_id: Optional[str] = None) -> Optional[str]:
        """Pick the least-loaded live node for a given owner.

        Returns the node_id with the lowest "active_worker_count" (ties go to
        the earliest-registered node), or None when no live node matches.
        """
        nodes = self.list_nodes(user_id=user_id)
        if not nodes:
            return None
        return min(nodes, key=lambda n: n.stats.get("active_worker_count", 999)).node_id

    def update_stats(self, node_id: str, stats: dict):
        """Called every heartbeat interval. Updates in-memory stats and bumps DB last_seen."""
        with self._lock:
            node = self._nodes.get(node_id)
            if node:
                node.update_stats(stats)
                if stats.get("cpu_usage_percent", 0) > 0 or stats.get("memory_usage_percent", 0) > 0:
                    logger.debug(f"[💓] Heartbeat {node_id}: CPU {stats.get('cpu_usage_percent')}% | MEM {stats.get('memory_usage_percent')}%")
        # Persist heartbeat timestamp to DB asynchronously
        self.executor.submit(self._db_update_heartbeat, node_id)
        # Emit heartbeat event to live UI
        self.emit(node_id, "heartbeat", stats)

    # ------------------------------------------------------------------ #
    #  Event Bus                                                          #
    # ------------------------------------------------------------------ #

    def _sanitize_output(self, text: str) -> str:
        """Strip ANSI escape codes and cap the chunk size to prevent memory bloat."""
        clean_output = self._ANSI_ESCAPE.sub('', text)
        if len(clean_output) > self._MAX_HISTORY_CHUNK_CHARS:
            clean_output = clean_output[:self._MAX_HISTORY_CHUNK_CHARS] + "\n[... Output Truncated ...]\n"
        return clean_output

    def _record_terminal_history(self, node: Any, event_type: str, data: Any) -> None:
        """Append commands and shell output from an event to the node's terminal history.

        M6: Store terminal history locally for AI reading. Only raw shell output
        and the commands themselves are stored to keep the context clean.
        """
        if event_type == "task_assigned" and isinstance(data, dict):
            cmd = data.get("command")
            # Skip TTY keypress echos (manual typing): a command starting with
            # '{"tty"' is treated as character-by-character input.
            is_tty_char = isinstance(cmd, str) and cmd.startswith('{"tty"')
            if cmd and not is_tty_char:
                node.terminal_history.append(f"$ {cmd}\n")
        elif event_type == "task_stdout" and isinstance(data, str):
            node.terminal_history.append(self._sanitize_output(data))
        elif event_type == "skill_event" and isinstance(data, dict):
            if data.get("type") == "output":
                output_data = data.get("data", "")
                # Only text can be sanitised; a None/bytes payload must not break emit().
                if isinstance(output_data, str):
                    node.terminal_history.append(self._sanitize_output(output_data))

        # Keep a rolling buffer of the most recent terminal interaction chunks
        if len(node.terminal_history) > self._MAX_HISTORY_CHUNKS:
            node.terminal_history = node.terminal_history[-self._MAX_HISTORY_CHUNKS:]

    def emit(self, node_id: str, event_type: str, data: Any = None, task_id: str = ""):
        """
        Emit a rich structured execution event.
        Delivered to:
          - Per-node WS subscribers → powers the single-node execution pane
          - Per-user WS subscribers → powers the global multi-node execution bus

        The owner is taken from the live node record, or from ``data["user_id"]``
        when the node is no longer registered. Delivery is best-effort: a
        subscriber whose queue rejects the event is skipped.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if node is not None:
                user_id = node.user_id
            elif isinstance(data, dict):
                user_id = data.get("user_id", "")
            else:
                user_id = ""
            node_qs = list(self._node_listeners.get(node_id, []))
            # Listener keys are str(user_id), matching the string comparison in list_nodes().
            user_qs = list(self._user_listeners.get(str(user_id), [])) if user_id else []

        if user_id and not user_qs and event_type in ["node_online", "node_offline"]:
            logger.debug(f"[Registry] emit({event_type}) for node {node_id}: No user listeners found for user {user_id}")

        event = {
            "event": event_type,
            "label": EVENT_TYPES.get(event_type, event_type),
            "node_id": node_id,
            "user_id": user_id,
            "task_id": task_id,
            "timestamp": datetime.utcnow().isoformat(),
            "data": data,
        }

        if node is not None:
            self._record_terminal_history(node, event_type, data)

        # A queue subscribed both per-node and per-user receives the event once.
        seen = set()
        for q in node_qs + user_qs:
            if id(q) in seen:
                continue
            seen.add(id(q))
            try:
                q.put_nowait(event)
            except queue.Full:
                logger.debug(f"[Registry] Dropped {event_type} event for node {node_id}: subscriber queue full")
            except Exception as e:
                logger.debug(f"[Registry] Failed to deliver {event_type} event for node {node_id}: {e}")

    # ------------------------------------------------------------------ #
    #  WS Subscriptions                                                   #
    # ------------------------------------------------------------------ #

    def subscribe_node(self, node_id: str, q: queue.Queue):
        """Subscribe to execution events for a specific node (single-pane view)."""
        with self._lock:
            self._node_listeners.setdefault(node_id, []).append(q)

    def unsubscribe_node(self, node_id: str, q: queue.Queue):
        """Remove a per-node subscriber; unknown queues are ignored."""
        with self._lock:
            lst = self._node_listeners.get(node_id, [])
            if q in lst:
                lst.remove(q)
            # Drop the empty entry so the listener map does not grow without bound.
            if not lst:
                self._node_listeners.pop(node_id, None)

    def subscribe_user(self, user_id: str, q: queue.Queue):
        """Subscribe to ALL node events for a user (multi-pane global bus)."""
        with self._lock:
            self._user_listeners.setdefault(str(user_id), []).append(q)

    def unsubscribe_user(self, user_id: str, q: queue.Queue):
        """Remove a per-user subscriber; unknown queues are ignored."""
        with self._lock:
            key = str(user_id)
            lst = self._user_listeners.get(key, [])
            if q in lst:
                lst.remove(q)
            # Drop the empty entry so the listener map does not grow without bound.
            if not lst:
                self._user_listeners.pop(key, None)
