import queue
import threading
from typing import Optional

class AbstractNodeRegistry:
    """Interface for finding and tracking Agent Nodes.

    A registry associates each node id with a queue, free-form metadata and
    the stats the node reports.  Every method here raises
    ``NotImplementedError``; concrete registries must override all of them.
    """

    def register(self, node_id: str, q: queue.Queue, metadata: dict) -> None:
        """Record *node_id* together with its queue *q* and *metadata*."""
        raise NotImplementedError

    def update_stats(self, node_id: str, stats: dict) -> None:
        """Record the latest *stats* reported by *node_id*."""
        raise NotImplementedError

    def get_best(self) -> Optional[str]:
        """Return the id of the preferred node, or ``None`` if none are registered."""
        raise NotImplementedError

    def get_node(self, node_id: str) -> Optional[dict]:
        """Return the stored record for *node_id*, or ``None`` if it is unknown."""
        raise NotImplementedError

class MemoryNodeRegistry(AbstractNodeRegistry):
    """In-memory implementation of the Node Registry.

    Every method acquires ``self.lock`` before reading or writing
    ``self.nodes``.
    """

    # Load assumed for a node that has not reported ``active_worker_count``.
    # Infinity ranks such nodes below every node that has reported real
    # stats, however busy (a finite sentinel would outrank very busy nodes).
    UNKNOWN_WORKER_COUNT = float("inf")

    def __init__(self):
        self.lock = threading.Lock()
        # node_id -> {"stats": dict, "queue": queue.Queue, "metadata": dict}
        self.nodes = {}

    def register(self, node_id: str, q: queue.Queue, metadata: dict) -> None:
        """Add *node_id* with queue *q* and *metadata*, starting with empty stats.

        Registering an id that already exists replaces its entry and discards
        any stats previously recorded for it.
        """
        with self.lock:
            self.nodes[node_id] = {"stats": {}, "queue": q, "metadata": metadata}
        # Log after releasing the lock so stdout I/O never blocks other callers.
        print(f"[📋] Registered Agent Node: {node_id}")

    def update_stats(self, node_id: str, stats: dict) -> None:
        """Merge *stats* into the stats stored for *node_id*.

        Updates for ids that were never registered are silently ignored.
        """
        with self.lock:
            entry = self.nodes.get(node_id)
            if entry is not None:
                entry["stats"].update(stats)

    def get_best(self) -> Optional[str]:
        """Return the id of the node with the lowest active worker count.

        Nodes without an ``active_worker_count`` stat are treated as maximally
        loaded.  Ties go to the node that appears first in registration order.
        Returns ``None`` when no nodes are registered.
        """
        with self.lock:
            if not self.nodes:
                return None
            # A single O(n) min() pass; sorting every entry is unnecessary.
            return min(
                self.nodes,
                key=lambda nid: self.nodes[nid]["stats"].get(
                    "active_worker_count", self.UNKNOWN_WORKER_COUNT
                ),
            )

    def get_node(self, node_id: str) -> Optional[dict]:
        """Return the registry entry for *node_id*, or ``None`` if unknown.

        NOTE: this is the live internal dict, not a copy; mutating it (or its
        ``"stats"`` dict) bypasses ``self.lock``.
        """
        with self.lock:
            return self.nodes.get(node_id)
