import hashlib
import hmac
import json
import os
import queue
import threading
import time
import uuid
from concurrent import futures

import grpc
import jwt

import agent_pb2
import agent_pb2_grpc

# Shared HMAC-SHA256 key used to sign every task sent to nodes. It can be
# overridden with the CORTEX_SECRET_KEY environment variable so deployments do
# not have to rely on the key committed to source; the historical literal stays
# the default so existing nodes (which must use the same key) keep working.
SECRET_KEY = os.environ.get("CORTEX_SECRET_KEY", "cortex-secret-shared-key")

class TaskJournal:
    """Lock-guarded table of in-flight tasks and their results.

    Every entry is keyed by task id and records a ``threading.Event`` (set
    once a result is stored), the result itself (``None`` until then) and the
    id of the node the task was sent to.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # task_id -> {"event": threading.Event, "result": object, "node_id": str | None}
        self.tasks = {}

    def register(self, task_id, node_id=None):
        """Open a new entry for ``task_id`` and return its completion event.

        Any existing entry under the same id is silently replaced.
        """
        done = threading.Event()
        entry = {"event": done, "result": None, "node_id": node_id}
        with self.lock:
            self.tasks[task_id] = entry
        return done

    def fulfill(self, task_id, result):
        """Store ``result`` for ``task_id`` and wake its waiter.

        Returns ``True`` on success and ``False`` (storing nothing) when the
        task is not registered, e.g. because its waiter already removed it.
        """
        with self.lock:
            entry = self.tasks.get(task_id)
            if entry is None:
                return False
            entry["result"] = result
            entry["event"].set()
        return True

    def get_result(self, task_id):
        """Return the stored result for ``task_id`` (``None`` if absent or pending)."""
        with self.lock:
            entry = self.tasks.get(task_id)
            result = None if entry is None else entry["result"]
        return result

    def pop(self, task_id):
        """Remove and return the whole entry for ``task_id``, or ``None``."""
        with self.lock:
            removed = self.tasks.pop(task_id, None)
        return removed

class AbstractNodeRegistry:
    """Interface for finding and tracking Managers (nodes).

    Implementations keep, per node, an outbound message queue, registration
    metadata and the latest reported stats. Every method here raises
    ``NotImplementedError``; concrete registries must override all of them.
    """

    def register(self, node_id, q, metadata):
        """Add or replace ``node_id`` with outbound queue ``q`` and ``metadata``."""
        raise NotImplementedError

    def update_stats(self, node_id, stats):
        """Merge the ``stats`` dict into the stored stats for ``node_id``."""
        raise NotImplementedError

    def get_best(self):
        """Return the id of the node best suited for new work, or ``None``."""
        raise NotImplementedError

    def get_node(self, node_id):
        """Return the stored record for ``node_id``, or ``None`` if unknown."""
        raise NotImplementedError

class MemoryNodeRegistry(AbstractNodeRegistry):
    """Lock-guarded, in-process node registry.

    Each entry maps ``node_id`` to a dict with ``"stats"`` (merged health
    figures), ``"queue"`` (outbound messages for that node's stream) and
    ``"metadata"`` (registration info).
    """

    # Load assumed for nodes that have not reported ``active_worker_count``
    # yet; they rank behind any node reporting fewer active workers.
    _UNKNOWN_LOAD = 999

    def __init__(self):
        self.lock = threading.Lock()
        self.nodes = {}  # node_id -> {"stats": {}, "queue": q, "metadata": {}}

    def register(self, node_id, q, metadata):
        """Add or replace ``node_id``; replacing resets its stats to empty."""
        with self.lock:
            self.nodes[node_id] = {"stats": {}, "queue": q, "metadata": metadata}
            print(f"[📋] Registered: {node_id}")

    def update_stats(self, node_id, stats):
        """Merge ``stats`` into the node's stats; unknown nodes are ignored."""
        with self.lock:
            node = self.nodes.get(node_id)
            if node is not None:
                node["stats"].update(stats)

    def get_best(self):
        """Return the id of the node with the fewest active workers.

        Returns ``None`` when no node is registered. Ties are broken by
        dictionary insertion order (``min`` keeps the first minimal item).
        """
        with self.lock:
            if not self.nodes:
                return None
            # O(n) scan; a full sort is unnecessary to find one minimum.
            best_id, _ = min(
                self.nodes.items(),
                key=lambda item: item[1]["stats"].get("active_worker_count", self._UNKNOWN_LOAD),
            )
            return best_id

    def get_node(self, node_id):
        """Return the live (mutable, shared) record for ``node_id`` or ``None``."""
        with self.lock:
            return self.nodes.get(node_id)

class GlobalWorkPool:
    """Shared pool of unassigned tasks; each task can be claimed only once."""

    def __init__(self):
        self.lock = threading.Lock()
        # task_id -> command payload; an entry is removed when it is claimed.
        self.available = {"shared-001": "uname -a", "shared-002": "uptime"}

    def claim(self, task_id, node_id):
        """Remove ``task_id`` from the pool under the lock.

        Returns ``(True, payload)`` for the first successful claim and
        ``(False, None)`` if the task is unknown or already taken.
        ``node_id`` is accepted but not currently recorded.
        """
        with self.lock:
            try:
                payload = self.available.pop(task_id)
            except KeyError:
                return False, None
        return True, payload

class TaskAssistant:
    """The High-Level AI API: send signed tasks to nodes and await results.

    Each dispatch registers a waiter in the journal, enqueues a
    ``ServerTaskMessage`` on the target node's outbound queue and blocks until
    ``journal.fulfill`` is called for the task id or the timeout elapses.
    """

    def __init__(self, registry, journal, pool):
        self.registry = registry
        self.journal = journal
        self.pool = pool  # stored but not used by the dispatch methods below

    @staticmethod
    def _new_task_id(prefix):
        """Return a unique task id such as ``task-1712345678901-3f2a9c1b``.

        The millisecond timestamp keeps ids readable and roughly ordered; the
        random suffix stops two dispatches in the same millisecond from
        sharing an id (``journal.register`` would overwrite the first waiter).
        """
        return f"{prefix}-{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"

    def _enqueue_and_wait(self, node, tid, event, req, timeout):
        """Queue ``req`` to ``node`` and block for the result of ``tid``.

        Returns the journaled result, or ``{"error": "Timeout"}`` if ``event``
        is not set within ``timeout`` seconds. The journal entry for ``tid``
        is always removed before returning.
        """
        try:
            node["queue"].put(req)
            if event.wait(timeout):
                return self.journal.get_result(tid)
            return {"error": "Timeout"}
        finally:
            self.journal.pop(tid)

    def dispatch_single(self, node_id, cmd, timeout=30):
        """Send shell command ``cmd`` to ``node_id`` and wait for its result.

        The raw command string is used both as ``payload_json`` and as the
        HMAC-SHA256 input for ``signature`` (no JSON wrapping).

        Returns:
            The result passed to ``journal.fulfill`` for this task,
            ``{"error": "Offline"}`` if the node is not registered, or
            ``{"error": "Timeout"}`` if nothing arrives within ``timeout`` s.
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}

        tid = self._new_task_id("task")
        # Register before enqueueing so a fast reply cannot arrive unseen.
        event = self.journal.register(tid, node_id)

        sig = hmac.new(SECRET_KEY.encode(), cmd.encode(), hashlib.sha256).hexdigest()
        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, payload_json=cmd, signature=sig))

        print(f"[📤] Dispatching {tid} to {node_id}")
        return self._enqueue_and_wait(node, tid, event, req, timeout)

    def dispatch_browser(self, node_id, action, timeout=60):
        """Send a ``BrowserAction`` to ``node_id`` and wait for its result.

        Returns the same shapes as ``dispatch_single`` (result,
        ``{"error": "Offline"}`` or ``{"error": "Timeout"}``).
        """
        node = self.registry.get_node(node_id)
        if not node:
            return {"error": "Offline"}

        tid = self._new_task_id("br")
        event = self.journal.register(tid, node_id)

        # Basic signature for POC: action enum value, URL and session id.
        # NOTE(review): other fields (e.g. ``text``, the EVAL script) are not
        # covered, so they are not integrity-protected. The node presumably
        # rebuilds this exact string to verify; change both sides together.
        sign_base = f"{action.action}:{action.url}:{action.session_id}"
        sig = hmac.new(SECRET_KEY.encode(), sign_base.encode(), hashlib.sha256).hexdigest()

        req = agent_pb2.ServerTaskMessage(task_request=agent_pb2.TaskRequest(
            task_id=tid, browser_action=action, signature=sig))

        return self._enqueue_and_wait(node, tid, event, req, timeout)

class AgentOrchestrator(agent_pb2_grpc.AgentOrchestratorServicer):
    """gRPC servicer: node registration, bidirectional task streaming, health.

    Owns the in-memory node registry, the task journal, the shared work pool
    and a ``TaskAssistant`` wired to all three.
    """

    def __init__(self):
        self.registry = MemoryNodeRegistry()
        self.journal = TaskJournal()
        self.pool = GlobalWorkPool()
        self.assistant = TaskAssistant(self.registry, self.journal, self.pool)

    def SyncConfiguration(self, request, context):
        """Register the calling node and return its sandbox policy.

        Creates a fresh outbound queue for the node; registering again
        replaces the previous entry (and its queue).
        """
        # Pre-registration for metadata search
        self.registry.register(
            request.node_id,
            queue.Queue(),
            {"desc": request.node_description, "caps": dict(request.capabilities)},
        )
        return agent_pb2.RegistrationResponse(
            success=True,
            policy=agent_pb2.SandboxPolicy(
                mode=agent_pb2.SandboxPolicy.STRICT,
                allowed_commands=["ls", "uname", "echo", "sleep"],
            ),
        )

    def TaskStream(self, request_iterator, context):
        """Bidirectional stream: push queued work to a node, read its replies.

        The first client message must be a ``NodeAnnounce`` from a node that
        already called ``SyncConfiguration``; otherwise the stream is closed.
        After that, client messages are consumed on a daemon thread while this
        generator advertises the shared work pool once and then yields
        messages from the node's outbound queue until the RPC ends.
        """
        try:
            # 1. Blocking wait for identity
            first_msg = next(request_iterator)
            if first_msg.WhichOneof('payload') != 'announce':
                print("[!] Stream rejected: No NodeAnnounce")
                return

            node_id = first_msg.announce.node_id
            node = self.registry.get_node(node_id)
            if not node:
                print(f"[!] Stream rejected: Node {node_id} not registered via SyncConfiguration")
                return

            print(f"[📶] Stream established for {node_id}")

            # 2. Results listener on its own thread, so this generator is free
            #    to block on the outbound queue.
            threading.Thread(
                target=self._consume_client_messages,
                args=(request_iterator, node_id, node),
                daemon=True,
            ).start()

            # 3. Advertise claimable shared work so the node can send task_claim.
            #    Snapshot under the pool lock; claims may mutate it concurrently.
            with self.pool.lock:
                available_ids = list(self.pool.available)
            if available_ids:
                yield agent_pb2.ServerTaskMessage(
                    work_pool_update=agent_pb2.WorkPoolUpdate(available_task_ids=available_ids))

            # 4. Work dispatcher (main stream)
            while context.is_active():
                try:
                    msg = node["queue"].get(timeout=1.0)
                except queue.Empty:
                    continue  # re-check context.is_active() about once a second
                yield msg
                print(f"[🚀] Pushed message from queue to stream: {node_id}")
        except StopIteration:
            # Client closed the stream before sending its announce message.
            pass
        except Exception as e:
            # Log only: dequeuing and yielding tasks after an unexpected error
            # would hand them to a stream that may already be broken.
            print(f"[!] TaskStream Error: {e}")

    def _consume_client_messages(self, request_iterator, node_id, node):
        """Handle every remaining client message on a stream (daemon thread body)."""
        try:
            for msg in request_iterator:
                kind = msg.WhichOneof('payload')
                if kind == 'task_claim':
                    self._handle_task_claim(msg.task_claim, node_id, node)
                elif kind == 'task_response':
                    self.journal.fulfill(
                        msg.task_response.task_id, self._task_response_to_dict(msg.task_response))
                elif kind == 'browser_event':
                    self._log_browser_event(msg.browser_event)
        except Exception as e:
            # NOTE(review): the request iterator presumably raises when the
            # client disconnects or the RPC is cancelled; log it instead of
            # letting the thread die with an unhandled traceback.
            print(f"[!] Result reader for {node_id} stopped: {e}")

    def _handle_task_claim(self, claim, node_id, node):
        """Claim a shared task for ``node_id`` and queue it, signed, to that node.

        Claims for unknown or already-claimed tasks are ignored.
        """
        success, payload = self.pool.claim(claim.task_id, node_id)
        if not success:
            return
        sig = hmac.new(SECRET_KEY.encode(), payload.encode(), hashlib.sha256).hexdigest()
        node["queue"].put(agent_pb2.ServerTaskMessage(
            task_request=agent_pb2.TaskRequest(task_id=claim.task_id, payload_json=payload, signature=sig)))

    @staticmethod
    def _task_response_to_dict(resp):
        """Flatten a ``TaskResponse`` into the result dict stored in the journal."""
        res_obj = {"stdout": resp.stdout, "status": resp.status}
        if resp.HasField("browser_result"):
            br = resp.browser_result
            res_obj["browser"] = {
                "url": br.url,
                "title": br.title,
                "has_snapshot": len(br.snapshot) > 0,
                # Only a 100-character preview of the accessibility tree is kept.
                "a11y": (br.a11y_tree[:100] + "...") if br.a11y_tree else None,
                "eval": br.eval_result,
            }
        return res_obj

    @staticmethod
    def _log_browser_event(e):
        """Print a live console message or network request reported by a node."""
        if e.HasField("console_msg"):
            print(f"    [🖥️] Live Browser Console: {e.console_msg.text}", flush=True)
        elif e.HasField("network_req"):
            print(f"    [🌐] Live Network Request: {e.network_req.method} {e.network_req.url}", flush=True)

    def ReportHealth(self, request_iterator, context):
        """Record each heartbeat's load figures and reply with server time in ms."""
        for hb in request_iterator:
            self.registry.update_stats(
                hb.node_id,
                {"active_worker_count": hb.active_worker_count, "running": list(hb.running_task_ids)},
            )
            yield agent_pb2.HealthCheckResponse(server_time_ms=int(time.time() * 1000))

def _run_ai_simulation(orch, node_id):
    """Drive a scripted sequence of shell and browser tasks against ``node_id``.

    Results are printed. The final EVAL only exists to make the node emit live
    console/network events (printed by the stream reader), so its result is
    discarded.
    """
    print("\n[🧠] AI Simulation Start...", flush=True)
    res_single = orch.assistant.dispatch_single(node_id, 'uname -a')
    print(f"    Uname: {res_single}", flush=True)

    # Browser phase
    print("\n[🧠] AI Phase 4: Navigating Browser (Antigravity Bridge)...")
    nav_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.NAVIGATE,
        url="https://example.com",
        session_id="antigravity-session-1",
    )
    res_nav = orch.assistant.dispatch_browser(node_id, nav_action)
    print(f"    Nav Result: {res_nav}")

    print("\n[🧠] AI Phase 5: Multi-Action Persistence (Screenshot)...")
    snap_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.SCREENSHOT,
        session_id="antigravity-session-1",
    )
    res_snap = orch.assistant.dispatch_browser(node_id, snap_action)
    snap_browser = res_snap.get('browser', {})
    print(f"    Snap Result: {snap_browser.get('title')} | Snapshot captured: {snap_browser.get('has_snapshot')}")

    # Phase 4 Pro: perception & advanced logic
    print("\n[🧠] AI Phase 4 Pro: Perception & Advanced Logic...")
    a11y_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.GET_A11Y,
        session_id="antigravity-session-1",
    )
    res_a11y = orch.assistant.dispatch_browser(node_id, a11y_action)
    print(f"    A11y Result: {res_a11y.get('browser', {}).get('a11y')}")

    eval_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.EVAL,
        text="window.performance.now()",
        session_id="antigravity-session-1",
    )
    res_eval = orch.assistant.dispatch_browser(node_id, eval_action)
    print(f"    Eval Result (Timestamp): {res_eval.get('browser', {}).get('eval')}")

    # Phase 4 Pro: real-time events (tunneling)
    print("\n[🧠] AI Phase 4 Pro: Triggering Real-time Events (Tunneling)...")
    trigger_action = agent_pb2.BrowserAction(
        action=agent_pb2.BrowserAction.EVAL,
        text="console.log('Hello from Antigravity Bridge!'); fetch('https://example.com/api/ping');",
        session_id="antigravity-session-1",
    )
    orch.assistant.dispatch_browser(node_id, trigger_action)


def serve(address='[::]:50051', cert_dir='certs', max_workers=10,
          demo_node_id='agent-node-007', demo_delay=10):
    """Start the mutual-TLS orchestrator, run the scripted AI demo, then block.

    Args:
        address: ``host:port`` for the secure listener.
        cert_dir: directory containing ``server.key``, ``server.crt`` and
            ``ca.crt`` (the CA used to verify client certificates).
        max_workers: size of the gRPC handler thread pool. NOTE: in the sync
            gRPC server each open streaming RPC (TaskStream, ReportHealth)
            occupies a pool thread, so this bounds concurrent node connections.
        demo_node_id: node targeted by the simulated AI dispatches.
        demo_delay: seconds to wait after start before the demo runs,
            presumably to give nodes time to register and connect.
    """
    def _read_cert(name):
        with open(os.path.join(cert_dir, name), 'rb') as f:
            return f.read()

    creds = grpc.ssl_server_credentials(
        [(_read_cert('server.key'), _read_cert('server.crt'))],
        root_certificates=_read_cert('ca.crt'),
        require_client_auth=True,  # mutual TLS: clients must present a CA-signed cert
    )
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    orch = AgentOrchestrator()
    agent_pb2_grpc.add_AgentOrchestratorServicer_to_server(orch, server)
    server.add_secure_port(address, creds)

    print("[🛡️] Boss Plane Refactored & Online.")
    server.start()

    # Simple AI simulation loop
    time.sleep(demo_delay)
    _run_ai_simulation(orch, demo_node_id)

    server.wait_for_termination()

if __name__ == "__main__":
    # Script entry point: bring the orchestrator online and run the demo.
    serve()
