import threading
import queue
import time
import sys
from protos import agent_pb2, agent_pb2_grpc
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.utils.auth import create_auth_token, verify_task_signature
from agent_node.utils.network import get_secure_stub
from agent_node.config import NODE_ID, NODE_DESC, HEALTH_REPORT_INTERVAL, MAX_SKILL_WORKERS

class AgentNode:
    """The 'Agent Core': Orchestrates Local Skills and Maintains gRPC Connection.

    Outbound stream messages (task claims, skill events, task responses) are
    funneled through ``self.task_queue``, which feeds the request generator of
    the bi-directional ``TaskStream`` RPC opened in ``run_task_stream``.
    """

    def __init__(self, node_id=NODE_ID):
        self.node_id = node_id
        self.skills = SkillManager(max_workers=MAX_SKILL_WORKERS)
        self.sandbox = SandboxEngine()
        # Outbound ClientTaskMessage queue; drained by run_task_stream's generator.
        self.task_queue = queue.Queue()
        self.stub = get_secure_stub()

    def sync_configuration(self):
        """Initial handshake to retrieve policy and metadata.

        Sends a RegistrationRequest and applies the returned sandbox policy.
        Exits the process with status 1 if the orchestrator rejects the node
        or the call fails.
        """
        print(f"[*] Handshake with Orchestrator: {self.node_id}")
        reg_req = agent_pb2.RegistrationRequest(
            node_id=self.node_id,
            auth_token=create_auth_token(self.node_id),
            node_description=NODE_DESC,
            capabilities={"shell": "v1", "browser": "playwright-sync-bridge"},
        )

        try:
            res = self.stub.SyncConfiguration(reg_req)
            if res.success:
                self.sandbox.sync(res.policy)
                print("[OK] Sandbox Policy Synced.")
            else:
                print(f"[!] Rejection: {res.error_message}")
                # SystemExit is not an Exception subclass, so it escapes the handler below.
                sys.exit(1)
        except Exception as e:
            print(f"[!] Connection Fail: {e}")
            sys.exit(1)

    def start_health_reporting(self):
        """Stream node metrics to the orchestrator for load balancing.

        Starts a daemon thread that sends a Heartbeat every
        ``HEALTH_REPORT_INTERVAL`` seconds over the ``ReportHealth`` stream.
        """
        def _gen():
            while True:
                ids = self.skills.get_active_ids()
                yield agent_pb2.Heartbeat(
                    node_id=self.node_id,
                    # NOTE(review): fixed placeholder, not a measured CPU value.
                    cpu_usage_percent=1.0,
                    active_worker_count=len(ids),
                    max_worker_capacity=MAX_SKILL_WORKERS,
                    running_task_ids=ids,
                )
                time.sleep(HEALTH_REPORT_INTERVAL)

        def _run():
            # Consume and discard server replies one at a time. Collecting them
            # (e.g. with list()) would retain every reply for the lifetime of
            # this never-ending stream.
            try:
                for _ in self.stub.ReportHealth(_gen()):
                    pass
            except Exception as e:
                # Surface failures instead of letting the daemon thread die silently.
                print(f"[!] Health Stream Failure: {e}", flush=True)

        # Non-blocking thread for health heartbeat
        threading.Thread(
            target=_run,
            daemon=True,
            name=f"Health-{self.node_id}",
        ).start()

    def run_task_stream(self):
        """Main persistent bi-directional stream for task management.

        Blocks while iterating server messages; returns (after logging) when
        the stream ends or raises.
        """
        def _gen():
            # Initial announcement for routing identity
            yield agent_pb2.ClientTaskMessage(
                announce=agent_pb2.NodeAnnounce(node_id=self.node_id)
            )
            # Forward everything enqueued by callbacks; blocks until a message arrives.
            while True:
                yield self.task_queue.get()

        responses = self.stub.TaskStream(_gen())
        print(f"[*] Task Stream Online: {self.node_id}", flush=True)

        try:
            for msg in responses:
                # _process_server_message logs the payload kind itself.
                self._process_server_message(msg)
        except Exception as e:
            print(f"[!] Task Stream Failure: {e}", flush=True)

    def _process_server_message(self, msg):
        """Dispatch one server message by its 'payload' oneof field."""
        kind = msg.WhichOneof('payload')
        print(f"[*] Inbound: {kind}", flush=True)

        if kind == 'task_request':
            self._handle_task(msg.task_request)

        elif kind == 'task_cancel':
            if self.skills.cancel(msg.task_cancel.task_id):
                self._send_response(msg.task_cancel.task_id, None, agent_pb2.TaskResponse.CANCELLED)

        elif kind == 'work_pool_update':
            # Claim idle tasks from the global pool, but never more than there are
            # free worker slots. Claiming beyond capacity would pull in tasks that
            # skills.submit then rejects.
            free_slots = MAX_SKILL_WORKERS - len(self.skills.get_active_ids())
            if free_slots > 0:
                for tid in list(msg.work_pool_update.available_task_ids)[:free_slots]:
                    self.task_queue.put(agent_pb2.ClientTaskMessage(
                        task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
                    ))

    def _handle_task(self, task):
        """Verify a task's signature and hand it to the skill manager."""
        print(f"[*] Task Launch: {task.task_id}", flush=True)
        # 1. Cryptographic Signature Verification
        if not verify_task_signature(task):
            # NOTE(review): no response is sent to the server for rejected tasks.
            print(f"[!] Signature Validation Failed for {task.task_id}", flush=True)
            return

        print(f"[✅] Validated task {task.task_id}", flush=True)

        # 2. Skill Manager Submission
        success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if not success:
            print(f"[!] Execution Rejected: {reason}", flush=True)

    def _on_event(self, event):
        """Live event tunneler: enqueue a browser/skill event for the main stream."""
        self.task_queue.put(agent_pb2.ClientTaskMessage(browser_event=event))

    def _on_finish(self, tid, res, trace):
        """Final completion callback: route a task result back to the server.

        Args:
            tid: Task id.
            res: Result mapping; ``status == 1`` means success. ``stdout``,
                ``stderr`` and ``browser_result`` are optional.
            trace: Trace id attached to the response.
        """
        print(f"[*] Completion: {tid}", flush=True)
        # A missing status is reported as ERROR rather than raising KeyError,
        # which would otherwise drop the response entirely.
        status = (
            agent_pb2.TaskResponse.SUCCESS
            if res.get('status') == 1
            else agent_pb2.TaskResponse.ERROR
        )

        tr = agent_pb2.TaskResponse(
            task_id=tid,
            status=status,
            stdout=res.get('stdout', ''),
            stderr=res.get('stderr', ''),
            trace_id=trace,
            browser_result=res.get("browser_result"),
        )
        self._send_response(tid, tr)

    def _send_response(self, tid, tr=None, status=None):
        """Place a task response into the gRPC outbound queue.

        If ``tr`` is given it is sent as-is. Otherwise a minimal TaskResponse is
        built from ``tid`` and ``status``.
        """
        if tr is None:
            tr = agent_pb2.TaskResponse(task_id=tid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=tr))

    def stop(self):
        """Gracefully stop the skill manager.

        The gRPC channel is not closed here. This class only holds a stub
        (from get_secure_stub), not a channel reference.
        """
        print(f"\n[🛑] Stopping Agent Node: {self.node_id}")
        self.skills.shutdown()
