import datetime
import hashlib
import hmac
import json
import os
import platform
import queue
import shlex
import subprocess
import sys
import threading
import time
from concurrent import futures

import grpc
import jwt
from playwright.sync_api import sync_playwright

import agent_pb2
import agent_pb2_grpc

# Shared secret for the registration JWT and the per-task HMAC signatures.
# Set CORTEX_SECRET_KEY in the environment so deployed nodes do not depend on the
# hard-coded development default (an empty value also falls back to the default).
SECRET_KEY = os.environ.get("CORTEX_SECRET_KEY") or "cortex-secret-shared-key"

class BaseSkill:
    """Contract shared by every pluggable node capability.

    Subclasses implement ``execute`` (run ``task`` and report the outcome through
    the ``on_complete`` callback, optionally streaming via ``on_event``) and may
    override ``cancel``.
    """

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Run ``task``; concrete skills must override this."""
        raise NotImplementedError

    def cancel(self, task_id):
        """Try to stop ``task_id``; the base implementation never cancels anything."""
        return False

class ShellSkill(BaseSkill):
    """Default Skill: Executing shell commands.

    ``task.payload_json`` is treated as a shell command line, gated by
    ``sandbox.verify`` and run with ``shell=True``. Running processes are tracked
    by task id so ``cancel`` can kill them from another thread.
    """
    def __init__(self):
        self.processes = {} # task_id -> Popen
        self.lock = threading.Lock()

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Run ``task.payload_json`` and report the outcome via ``on_complete``.

        ``on_complete(task_id, result, trace_id)`` receives ``stdout``/``stderr`` and
        ``status`` 1 when the command exits 0, else 2 (also used for sandbox
        rejection, timeout and launch errors). ``task.timeout_ms <= 0`` means no
        timeout. ``on_event`` is accepted for interface compatibility but unused.
        """
        p = None
        try:
            cmd = task.payload_json

            allowed, status_msg = sandbox.verify(cmd)
            if not allowed:
                return on_complete(task.task_id, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id)

            print(f"    [🐚] Executing Shell: {cmd}", flush=True)
            p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

            # Register so cancel() from another thread can find and kill it.
            with self.lock: self.processes[task.task_id] = p

            timeout = task.timeout_ms / 1000.0 if task.timeout_ms > 0 else None
            stdout, stderr = p.communicate(timeout=timeout)
            print(f"    [🐚] Shell Done: {cmd} | Stdout Size: {len(stdout)}", flush=True)

            on_complete(task.task_id, {"stdout": stdout, "stderr": stderr, "status": 1 if p.returncode == 0 else 2}, task.trace_id)
        except subprocess.TimeoutExpired:
            self.cancel(task.task_id)
            # Reap the killed shell so it does not linger as a zombie. Use wait(), not
            # communicate(): a grandchild that outlives the shell may keep the pipes
            # open and would block communicate() (this mirrors subprocess.run on POSIX).
            if p is not None:
                p.wait()
            on_complete(task.task_id, {"stderr": "TIMEOUT", "status": 2}, task.trace_id)
        except Exception as e:
            on_complete(task.task_id, {"stderr": str(e), "status": 2}, task.trace_id)
        finally:
            with self.lock: self.processes.pop(task.task_id, None)

    def cancel(self, task_id):
        """Kill the running process for ``task_id``; return True if one was found."""
        with self.lock:
            p = self.processes.get(task_id)
            if p:
                print(f"[🛑] Killing Shell Process: {task_id}", flush=True)
                p.kill()
                return True
        return False

class BrowserSkill(BaseSkill):
    """The 'Antigravity Bridge': Persistent Browser Skill using a dedicated Actor thread.

    Playwright's sync API is driven only from the ``BrowserActor`` thread;
    ``execute`` merely enqueues the task for it. Each ``session_id`` maps to one
    browser context + page that persists across tasks until a CLOSE action.
    """
    def __init__(self):
        self.task_queue = queue.Queue()
        self.sessions = {} # session_id -> { "context": Context, "page": Page }
        self.lock = threading.Lock()
        threading.Thread(target=self._browser_actor, daemon=True, name="BrowserActor").start()

    def _setup_listeners(self, sid, page, on_event):
        """Forward ``page`` console messages and finished requests to ``on_event`` as BrowserEvents."""
        if not on_event: return

        def _on_console(msg):
            on_event(agent_pb2.BrowserEvent(
                session_id=sid, console_msg=agent_pb2.ConsoleMessage(
                    level=msg.type, text=msg.text, timestamp_ms=int(time.time()*1000)
                )
            ))

        def _on_request_finished(req):
            resp = req.response()  # query once and reuse for the status field
            on_event(agent_pb2.BrowserEvent(
                session_id=sid, network_req=agent_pb2.NetworkRequest(
                    method=req.method, url=req.url, status=resp.status if resp else 0,
                    resource_type=req.resource_type, latency_ms=0
                )
            ))

        page.on("console", _on_console)
        page.on("requestfinished", _on_request_finished)

    def _browser_actor(self):
        """Actor loop: own the Playwright browser and process queued tasks one at a time.

        If the browser cannot be started, the loop still runs and fails every task
        with an explicit error, so callers are never left waiting forever.
        """
        print("[🌐] Browser Actor Starting...", flush=True)
        browser, startup_error = None, None
        try:
            pw = sync_playwright().start()
            browser = pw.chromium.launch(headless=True, args=[
                '--no-sandbox', '--disable-setuid-sandbox', '--disable-dev-shm-usage', '--disable-gpu'
            ])
            print("[🌐] Browser Engine Online.", flush=True)
        except Exception as e:
            print(f"[!] Browser Actor Startup Fail: {e}", flush=True)
            startup_error = f"BROWSER_UNAVAILABLE: {e}"

        while True:
            task, _sandbox, on_complete, on_event = self.task_queue.get()
            try:
                if browser is None:
                    raise RuntimeError(startup_error)
                br_res = self._run_action(browser, task, on_event)
                on_complete(task.task_id, {"status": 1, "browser_result": br_res}, task.trace_id)
            except Exception as e:
                print(f"    [!] Browser Actor Error: {e}", flush=True)
                on_complete(task.task_id, {"stderr": str(e), "status": 2}, task.trace_id)

    def _session_page(self, browser, sid, on_event):
        """Return the page for session ``sid``, creating its context and page on first use."""
        with self.lock:
            if sid not in self.sessions:
                context = browser.new_context()
                page = context.new_page()
                self._setup_listeners(sid, page, on_event)
                self.sessions[sid] = {"context": context, "page": page}
            return self.sessions[sid]["page"]

    def _close_session(self, sid):
        """Close and forget session ``sid``; return a BrowserResponse with its last URL."""
        with self.lock:
            sess = self.sessions.pop(sid, None)
            if not sess:
                return agent_pb2.BrowserResponse()
            last_url = sess["page"].url
            sess["context"].close()
        # The page is closed now, so page.title() would raise; report only the URL.
        return agent_pb2.BrowserResponse(url=last_url)

    def _run_action(self, browser, task, on_event):
        """Perform ``task.browser_action`` on its session page and build the BrowserResponse."""
        action = task.browser_action
        sid = action.session_id or "default"
        print(f"    [🌐] Browser Actor Processing: {agent_pb2.BrowserAction.ActionType.Name(action.action)} | Session: {sid}", flush=True)

        if action.action == agent_pb2.BrowserAction.CLOSE:
            return self._close_session(sid)

        page = self._session_page(browser, sid, on_event)
        res_data = {}
        if action.action == agent_pb2.BrowserAction.NAVIGATE:
            page.goto(action.url, wait_until="commit")
        elif action.action == agent_pb2.BrowserAction.CLICK:
            page.click(action.selector)
        elif action.action == agent_pb2.BrowserAction.TYPE:
            page.fill(action.selector, action.text)
        elif action.action == agent_pb2.BrowserAction.SCREENSHOT:
            res_data["snapshot"] = page.screenshot()
        elif action.action == agent_pb2.BrowserAction.GET_DOM:
            res_data["dom_content"] = page.content()
        elif action.action == agent_pb2.BrowserAction.HOVER:
            page.hover(action.selector)
        elif action.action == agent_pb2.BrowserAction.SCROLL:
            page.mouse.wheel(x=0, y=action.y)
        elif action.action == agent_pb2.BrowserAction.EVAL:
            res_data["eval_result"] = str(page.evaluate(action.text))
        elif action.action == agent_pb2.BrowserAction.GET_A11Y:
            res_data["a11y_tree"] = json.dumps(page.accessibility.snapshot())

        # Refresh page metadata after the action
        return agent_pb2.BrowserResponse(
            url=page.url, title=page.title(),
            snapshot=res_data.get("snapshot", b""),
            dom_content=res_data.get("dom_content", ""),
            a11y_tree=res_data.get("a11y_tree", ""),
            eval_result=res_data.get("eval_result", "")
        )

    def execute(self, task, sandbox, on_complete, on_event=None):
        """Enqueue ``task`` for the actor thread; the result arrives via ``on_complete``."""
        self.task_queue.put((task, sandbox, on_complete, on_event))

    def cancel(self, task_id):
        """Browser tasks cannot be cancelled once queued."""
        return False

class SkillManager:
    """Orchestrates multiple skills and manages the worker thread pool."""
    def __init__(self, max_workers=5):
        self.executor = futures.ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="skill-worker")
        self.active_tasks = {} # task_id -> future
        self.skills = {
            "shell": ShellSkill(),
            "browser": BrowserSkill()
        }
        self.max_workers = max_workers
        self.lock = threading.Lock()

    def submit(self, task, sandbox, on_complete, on_event=None):
        """Dispatch ``task`` to the browser or shell skill on the worker pool.

        Returns ``(True, "Accepted")``, or ``(False, "Node Capacity Reached")`` when
        ``max_workers`` tasks are already tracked.
        """
        with self.lock:
            if len(self.active_tasks) >= self.max_workers:
                return False, "Node Capacity Reached"

            # Decide Skill
            if task.HasField("browser_action") or task.task_type == "browser":
                skill = self.skills["browser"]
            else:
                skill = self.skills["shell"]

            future = self.executor.submit(skill.execute, task, sandbox, on_complete, on_event)
            self.active_tasks[task.task_id] = future

        # Register the cleanup hook only after releasing the lock. add_done_callback
        # invokes the callback immediately in this thread when the future has already
        # finished (likely for skills whose execute() returns quickly), and _cleanup
        # re-acquires the non-reentrant lock, which would deadlock if it were held here.
        task_id = task.task_id
        future.add_done_callback(lambda _f: self._cleanup(task_id))
        return True, "Accepted"

    def cancel(self, task_id):
        """Ask each skill to cancel ``task_id``; True as soon as one reports success."""
        with self.lock:
            # Tell all skills to try and cancel this ID
            cancelled = any(s.cancel(task_id) for s in self.skills.values())
            return cancelled

    def get_active_ids(self):
        """Snapshot of task ids whose worker futures have not completed yet."""
        with self.lock:
            return list(self.active_tasks.keys())

    def _cleanup(self, task_id):
        """Forget ``task_id`` once its worker future completes."""
        with self.lock: self.active_tasks.pop(task_id, None)

class SandboxEngine:
    """Best-effort command gate driven by the policy synced from the orchestrator.

    Shell commands are executed with ``shell=True``, so ``verify`` checks the
    command word of every simple command in the line (commands joined by ``;``,
    ``&&``, ``||``, ``|``, ``&``, newlines, subshell parentheses, or following
    shell keywords such as ``then``/``do``), not just the first word.

    This is a policy filter, not an isolation boundary: wrapper programs such as
    ``env``, ``xargs``, ``nice`` or ``sh -c`` must themselves be listed in DENIED
    (or left out of ALLOWED) to be blocked.
    """

    # Characters the shell treats as operators; the lexer emits them as separate tokens.
    _OPERATOR_CHARS = ";&|()<>\n"
    # Operator characters that end one command and start another. "&" is handled
    # separately because it also appears inside redirections such as "2>&1".
    _SEPARATOR_CHARS = ";|()\n"
    # Reserved words after which the next word is again in command position.
    _KEYWORDS = frozenset({"!", "{", "}", "if", "then", "else", "elif", "fi",
                           "do", "done", "while", "until", "time"})

    def __init__(self):
        self.policy = None

    def sync(self, p):
        """Replace the local policy with the orchestrator's ``SandboxPolicy`` message ``p``."""
        self.policy = {"MODE": "STRICT" if p.mode == agent_pb2.SandboxPolicy.STRICT else "PERMISSIVE",
                       "ALLOWED": list(p.allowed_commands), "DENIED": list(p.denied_commands), "SENSITIVE": list(p.sensitive_commands)}

    @classmethod
    def _command_names(cls, command_str):
        """Return the command word of every simple command in ``command_str``.

        Quotes are resolved the way ``sh`` would (``"rm"`` -> ``rm``), leading
        ``VAR=value`` assignments and redirection targets are skipped.
        Raises ValueError if the line cannot be tokenized (e.g. unbalanced quotes).
        """
        lexer = shlex.shlex(command_str, posix=True, punctuation_chars=cls._OPERATOR_CHARS)
        lexer.whitespace = " \t\r"     # a newline separates commands, so lex it as an operator
        lexer.whitespace_split = True  # keep words such as "g++" or "a:b" intact
        lexer.commenters = ""          # "#" inside a word is not a comment in sh
        names = []
        expect_cmd = True    # the next plain word is a command name
        skip_target = False  # the next plain word is a redirection target
        for tok in lexer:
            if tok and all(c in cls._OPERATOR_CHARS for c in tok):
                if any(c in cls._SEPARATOR_CHARS for c in tok):
                    expect_cmd, skip_target = True, False
                elif "<" in tok or ">" in tok:
                    skip_target = True
                else:  # "&" / "&&"
                    expect_cmd, skip_target = True, False
                continue
            if skip_target:
                skip_target = False
                continue
            if not expect_cmd:
                continue
            var, eq, _ = tok.partition("=")
            if eq and var.isidentifier():
                continue  # leading VAR=value assignment; the command word follows
            names.append(tok)
            expect_cmd = tok in cls._KEYWORDS
        return names

    @staticmethod
    def _is_dynamic(name):
        """True if the shell would expand ``name`` (parameter or glob) before running it."""
        return any(c in name for c in "$*?") or ("[" in name and "]" in name)

    def verify(self, command_str):
        """Check a shell command line against the synced policy.

        Returns ``(allowed, reason)``; ``reason`` is "OK" when allowed, otherwise one
        of "No Policy", "Unparseable", "Empty", "Unverifiable", "Forbidden",
        "Not Whitelisted".
        """
        if not self.policy: return False, "No Policy"
        # "\<newline>" is a line continuation in sh; drop it so "r\<nl>m" lexes as "rm".
        command = (command_str or "").replace("\\\n", "")
        try:
            names = self._command_names(command)
        except ValueError:
            # Refuse rather than guess how the shell would split a malformed line.
            return False, "Unparseable"
        if not names: return False, "Empty"

        strict = self.policy["MODE"] == "STRICT"
        allowed, denied = self.policy["ALLOWED"], self.policy["DENIED"]
        if strict or denied:
            # Command substitution and expanded command words are resolved by the shell
            # at run time, so the program that actually runs cannot be checked here.
            if "`" in command or "$(" in command or any(self._is_dynamic(n) for n in names):
                return False, "Unverifiable"
        for name in names:
            if name in denied or os.path.basename(name) in denied: return False, "Forbidden"
            if strict and name not in allowed:
                return False, "Not Whitelisted"
        return True, "OK"

class AgentNode:
    """Worker node: registers with the orchestrator, reports health and runs streamed tasks.

    Everything the node sends on the ``TaskStream`` RPC (claims, task responses,
    browser events) is put on ``self.task_queue`` and drained by the request
    generator in ``run_task_stream``.
    """
    def __init__(self, node_id="agent-node-007", server_address="localhost:50051", cert_dir="certs"):
        """Create skills and sandbox, and open a TLS gRPC channel using a client certificate.

        ``server_address`` and ``cert_dir`` default to the previously hard-coded
        values; ``cert_dir`` must contain ``client.key``, ``client.crt`` and ``ca.crt``.
        """
        self.node_id = node_id
        self.skills = SkillManager()
        self.sandbox = SandboxEngine()
        self.task_queue = queue.Queue()

        # gRPC Setup
        with open(os.path.join(cert_dir, 'client.key'), 'rb') as f: pkey = f.read()
        with open(os.path.join(cert_dir, 'client.crt'), 'rb') as f: cert = f.read()
        with open(os.path.join(cert_dir, 'ca.crt'), 'rb') as f: ca = f.read()
        creds = grpc.ssl_channel_credentials(ca, pkey, cert)
        self.channel = grpc.secure_channel(server_address, creds)
        self.stub = agent_pb2_grpc.AgentOrchestratorStub(self.channel)

    def _create_token(self):
        """Return an HS256 JWT for this node, issued now and valid for 10 minutes."""
        # Timezone-aware UTC; datetime.utcnow() is deprecated and yields naive values.
        now = datetime.datetime.now(datetime.timezone.utc)
        return jwt.encode({"sub": self.node_id, "iat": now,
                           "exp": now + datetime.timedelta(minutes=10)}, SECRET_KEY, algorithm="HS256")

    def sync_configuration(self):
        """Register with the orchestrator and install its sandbox policy; exit(1) if rejected."""
        print(f"[*] Handshake: {self.node_id}")
        reg = agent_pb2.RegistrationRequest(node_id=self.node_id, auth_token=self._create_token(),
                                           node_description="Refactored Stateful Node with Browser Skill",
                                           capabilities={"shell": "v1", "browser": "playwright-1.42"})
        res = self.stub.SyncConfiguration(reg)
        if not res.success:
            print(f"[!] Rejected: {res.error_message}")
            sys.exit(1)
        self.sandbox.sync(res.policy)
        print("[OK] Policy Synced.")

    def start_health_reporting(self):
        """Start a daemon thread streaming a Heartbeat every 10 seconds."""
        def _gen():
            while True:
                ids = self.skills.get_active_ids()
                # NOTE(review): cpu_usage_percent is a fixed placeholder, not a measurement.
                yield agent_pb2.Heartbeat(node_id=self.node_id, cpu_usage_percent=1.0,
                                         active_worker_count=len(ids),
                                         max_worker_capacity=self.skills.max_workers,
                                         running_task_ids=ids)
                time.sleep(10)

        def _run():
            # Drain server replies without retaining them: collecting them into a list
            # grows memory for as long as the stream stays open.
            for _ in self.stub.ReportHealth(_gen()):
                pass

        threading.Thread(target=_run, daemon=True).start()

    def run_task_stream(self):
        """Open the TaskStream and dispatch server messages until the stream ends.

        Returns when the server closes the stream or it errors (the error is logged).
        """
        def _gen():
            yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))
            while True: yield self.task_queue.get()
        responses = self.stub.TaskStream(_gen())
        print(f"[*] Stream processing started: {self.node_id}", flush=True)
        try:
            for msg in responses:
                kind = msg.WhichOneof('payload')
                print(f"[*] Received message from server: {kind}", flush=True)
                if kind == 'task_request':
                    self._handle_task(msg.task_request)
                elif kind == 'task_cancel':
                    if self.skills.cancel(msg.task_cancel.task_id):
                        self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=agent_pb2.TaskResponse(
                            task_id=msg.task_cancel.task_id, status=agent_pb2.TaskResponse.CANCELLED)))
                elif kind == 'work_pool_update':
                    self._claim_available(msg.work_pool_update.available_task_ids)
        except Exception as e:
            print(f"[!] Stream Error: {e}", flush=True)

    def _claim_available(self, task_ids):
        """Queue claims for at most as many of ``task_ids`` as there are idle workers."""
        # The active count only grows once a claimed task arrives and is submitted, so
        # compute the free-slot budget once instead of re-checking it per id.
        free = self.skills.max_workers - len(self.skills.get_active_ids())
        for tid in list(task_ids)[:max(free, 0)]:
            self.task_queue.put(agent_pb2.ClientTaskMessage(task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)))

    def _handle_task(self, task):
        """Verify the task's HMAC signature and hand it to the skill manager."""
        print(f"[*] Handling Task: {task.task_id}", flush=True)
        # Sig Verify logic based on payload type
        if task.HasField("browser_action"):
            a = task.browser_action
            # NOTE(review): selector/text (including EVAL scripts) are not covered by the
            # signature; extending it requires a matching change on the orchestrator.
            sign_base = f"{a.action}:{a.url}:{a.session_id}".encode()
        else:
            sign_base = task.payload_json.encode()

        expected_sig = hmac.new(SECRET_KEY.encode(), sign_base, hashlib.sha256).hexdigest()
        if not hmac.compare_digest(task.signature, expected_sig):
            print(f"[!] Sig Fail for {task.task_id} | Raw: {sign_base}", flush=True)
            return

        print(f"[✅] Signature Verified for {task.task_id}", flush=True)
        accepted, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if not accepted:
            # Report the rejection instead of silently dropping the task.
            print(f"[!] Task {task.task_id} rejected: {reason}", flush=True)
            self._on_finish(task.task_id, {"stderr": reason, "status": 2}, task.trace_id)

    def _on_event(self, event):
        """Queue a BrowserEvent for delivery on the task stream."""
        self.task_queue.put(agent_pb2.ClientTaskMessage(browser_event=event))

    def _on_finish(self, tid, res, trace):
        """Queue a TaskResponse built from a skill result dict (status 1 = success)."""
        print(f"[*] Task {tid} finished.", flush=True)
        status = agent_pb2.TaskResponse.SUCCESS if res['status'] == 1 else agent_pb2.TaskResponse.ERROR

        tr = agent_pb2.TaskResponse(
            task_id=tid, status=status,
            stdout=res.get('stdout',''),
            stderr=res.get('stderr',''),
            trace_id=trace,
            browser_result=res.get("browser_result")
        )
        self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=tr))

if __name__ == "__main__":
    # Register with the orchestrator, start heartbeats, then block on the task stream.
    agent = AgentNode()
    agent.sync_configuration()
    agent.start_health_reporting()
    agent.run_task_stream()
