import hashlib
import hmac
import os
import queue
import threading
import time
from concurrent import futures

import grpc
import jwt

import agent_pb2
import agent_pb2_grpc

# Shared secret used both to verify node JWTs (HS256) and to HMAC-sign task
# payloads. It is read from the CORTEX_SECRET_KEY environment variable when
# that is set to a non-empty value; the literal fallback only keeps this POC
# runnable out of the box and must not be relied on in production.
SECRET_KEY = os.environ.get("CORTEX_SECRET_KEY") or "cortex-secret-shared-key"

class AgentOrchestratorServicer(agent_pb2_grpc.AgentOrchestratorServicer):
    """Servicer for the bidirectional ``Connect`` stream (proof of concept).

    For each connection, a background thread consumes the client's messages
    while the RPC generator drains a per-connection queue back to the client.
    A registration carrying a valid JWT is acknowledged and followed by a fixed
    sequence of signed shell tasks intended to exercise the node's sandbox
    policy.
    """

    # (task_id, trace_id, payload_json) for the sandbox-policy test sequence,
    # dispatched in this order after a successful registration. The labels
    # describe the outcome the node-side policy is expected to produce.
    _POLICY_TEST_TASKS = (
        # Test 1: Allowed Command
        ("task-001-ALLOWED", "trace-001", '{"command": "whoami"}'),
        # Test 2: Sensitive Command (Consent Required)
        ("task-002-SENSITIVE", "trace-002", '{"command": "rm -rf /tmp/node-test"}'),
        # Test 3: Path Traversal Attempt (JAILBREAK)
        ("task-003-TRAVERSAL", "trace-003", '{"command": "cat ../.env.gitbucket"}'),
        # Test 4: STRICTLY Forbidden
        ("task-004-FORBIDDEN", "trace-004", '{"command": "sudo apt update"}'),
        # Test 5: Non-whitelisted but Allowed in PERMISSIVE
        ("task-005-PERMISSIVE", "trace-005", '{"command": "df -h"}'),
    )

    def __init__(self):
        # Intended as node_id -> queue of messages for that node.
        # NOTE(review): nothing in this class populates it yet.
        self.nodes = {}

    def Connect(self, request_iterator, context):
        """Serve one bidirectional stream.

        Args:
            request_iterator: Iterator of messages sent by the client.
            context: gRPC servicer context for this RPC.

        Yields:
            Server messages placed on this connection's queue, for as long as
            ``context.is_active()`` is true.
        """
        # Messages for the client are produced on another thread (in a real app
        # this would be an AI task planner), so they are handed over via a queue.
        send_queue = queue.Queue()

        # Consume the client's messages on a separate thread. It is a daemon so
        # a stuck request stream cannot keep the interpreter alive at shutdown.
        incoming_thread = threading.Thread(
            target=self._handle_incoming,
            args=(request_iterator, context, send_queue),
            daemon=True,
        )
        incoming_thread.start()

        # Poll with a timeout so the loop re-checks the context regularly
        # instead of blocking forever on an empty queue.
        while context.is_active():
            try:
                msg = send_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            yield msg

    @staticmethod
    def _sign_payload(payload_json):
        """Return the hex HMAC-SHA256 of ``payload_json`` keyed with SECRET_KEY."""
        return hmac.new(
            SECRET_KEY.encode(), payload_json.encode(), hashlib.sha256
        ).hexdigest()

    @classmethod
    def _build_shell_task(cls, task_id, trace_id, payload_json):
        """Wrap a signed ``shell`` task request in a ServerMessage."""
        return agent_pb2.ServerMessage(
            task_request=agent_pb2.TaskRequest(
                task_id=task_id,
                task_type="shell",
                payload_json=payload_json,
                trace_id=trace_id,
                signature=cls._sign_payload(payload_json),
            )
        )

    @staticmethod
    def _registration_ack(**fields):
        """Wrap a RegistrationResponse built from ``fields`` in a ServerMessage."""
        return agent_pb2.ServerMessage(
            registration_ack=agent_pb2.RegistrationResponse(**fields)
        )

    def _handle_registration(self, reg, send_queue):
        """Verify a registration token, acknowledge it, and queue the test tasks.

        On an expired or otherwise invalid token, a failure acknowledgement is
        queued and no tasks are dispatched.
        """
        print(f"[*] Node {reg.node_id} registration request. Verifying token...")

        try:
            # Verify JWT
            # In real app, we check if node_id matches token subject
            decoded = jwt.decode(reg.auth_token, SECRET_KEY, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError, which would otherwise also match.
            print(f"    [FAIL] Token for {reg.node_id} expired.")
            send_queue.put(self._registration_ack(
                success=False,
                error_message="Authentication token expired."
            ))
            return
        except jwt.InvalidTokenError as e:
            print(f"    [FAIL] Invalid token for {reg.node_id}: {e}")
            send_queue.put(self._registration_ack(
                success=False,
                error_message=f"Invalid authentication token: {e}"
            ))
            return

        print(f"    [OK] Token verified for workspace: {decoded.get('workspace_id')}")

        # Send ACK first so the node sees it before any task request.
        send_queue.put(self._registration_ack(
            success=True,
            session_id="session-secure-123"
        ))

        for task_id, trace_id, payload_json in self._POLICY_TEST_TASKS:
            send_queue.put(self._build_shell_task(task_id, trace_id, payload_json))

        print("[*] Sequence of 5 test tasks dispatched to verify Sandbox Policy (PERMISSIVE mode).")

    def _handle_incoming(self, request_iterator, context, send_queue):
        """Consume client messages until the request stream ends or fails.

        Runs on the thread started by ``Connect``. Registrations are verified
        and answered through ``send_queue``, heartbeats are ignored, and task
        responses are printed. Any exception is printed rather than raised,
        because nothing above this thread's entry point would handle it.
        """
        try:
            for message in request_iterator:
                payload_type = message.WhichOneof('payload')
                if payload_type == 'registration':
                    self._handle_registration(message.registration, send_queue)
                elif payload_type == 'heartbeat':
                    # Heartbeats are accepted but not acted on yet.
                    pass
                elif payload_type == 'task_response':
                    res = message.task_response
                    print(f"[!] Task Finished: {res.task_id} | Status: {res.status}")
                    print(f"    Stdout: {res.stdout.strip()}")
        except Exception as e:
            print(f"[!] Error handling incoming stream: {e}")

def serve(port=50051, cert_dir='certs', max_workers=10):
    """Start the mTLS-secured orchestrator server and block until it terminates.

    Args:
        port: TCP port to listen on (all interfaces, ``[::]``).
        cert_dir: Directory containing ``server.key``, ``server.crt`` and
            ``ca.crt`` in PEM form.
        max_workers: Size of the thread pool that serves RPCs.

    Raises:
        OSError: If a certificate file cannot be read.
        RuntimeError: If the secure port cannot be bound.
    """
    def read_pem(filename):
        # Read one PEM file from the certificate directory as raw bytes.
        with open(os.path.join(cert_dir, filename), 'rb') as f:
            return f.read()

    # Load certificates for mTLS
    print("[🔐] Loading mTLS certificates...")
    private_key = read_pem('server.key')
    certificate_chain = read_pem('server.crt')
    root_certificates = read_pem('ca.crt')

    # Create server credentials
    # require_client_auth=True enforces bidirectional verification (mTLS)
    server_credentials = grpc.ssl_server_credentials(
        [(private_key, certificate_chain)],
        root_certificates=root_certificates,
        require_client_auth=True
    )

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    agent_pb2_grpc.add_AgentOrchestratorServicer_to_server(AgentOrchestratorServicer(), server)

    # Use secure_port instead of insecure_port
    bound_port = server.add_secure_port(f'[::]:{port}', server_credentials)
    if not bound_port:
        # Some grpcio versions signal a failed bind by returning 0 rather than
        # raising; fail loudly instead of reporting a server that is not listening.
        raise RuntimeError(f"Failed to bind secure port {port}")

    print(f"[*] Cortex Secure Server POC listening on port {bound_port} (mTLS enabled)...")
    server.start()
    server.wait_for_termination()

# Start the server only when this file is run as a script, not when imported.
if __name__ == '__main__':
    serve()
