import datetime
import hashlib
import hmac
import json
import os
import platform
import shlex
import subprocess
import sys
import threading
import time

import grpc
import jwt

import agent_pb2
import agent_pb2_grpc

# Shared secret used both to sign registration JWTs and to verify task HMACs.
# Read from the environment so the key does not have to live in source
# control; an unset or empty variable falls back to the built-in default so
# existing deployments keep working.
SECRET_KEY = os.environ.get("CORTEX_SECRET_KEY") or "cortex-secret-shared-key"

# --- Sandbox Policy Configuration ---
SANDBOX_POLICY = {
    # Toggle between "STRICT" and "PERMISSIVE". Only the exact value "STRICT"
    # enables whitelist checking; any other value behaves as permissive.
    "MODE": "PERMISSIVE",
    # STRICT mode: only these (plus SENSITIVE_COMMANDS) may run.
    # NOTE(review): "python" permits arbitrary code execution, which can
    # reach everything in DENIED_COMMANDS; consider removing it.
    "ALLOWED_COMMANDS": ["ls", "grep", "cat", "pwd", "git", "echo", "python", "whoami", "uname"],
    # Allowed in both modes, but flagged as needing user consent.
    "SENSITIVE_COMMANDS": ["rm", "cp", "mv", "chmod", "chown", "pkill"],
    # Rejected in every mode.
    "DENIED_COMMANDS": ["sudo", "mkfs", "dd", "sh", "bash", "zsh"],
    # Captured once, at import time, from the process's current directory.
    "WORKING_DIR": os.getcwd(),
}

class AgentNode:
    """Agent-side client: node identity, mTLS stub to the orchestrator, sandbox check.

    Holds the node identity and a gRPC stub over a mutually-authenticated TLS
    channel. It also provides the policy check applied to each command before
    execution.
    """

    # Substrings that let a shell chain, pipe, background or substitute extra
    # commands. Checking only the first word is meaningless when any of these
    # are present, so such commands are rejected outright. This is
    # deliberately conservative: it also rejects them inside quotes
    # (e.g. `python -c "a; b"`).
    _SHELL_CONTROL_SEQUENCES = (";", "&", "|", "`", "$(", "<(", ">(", "\n", "\r")

    def __init__(self, node_id="agent-node-007", server_address="localhost:50051", certs_dir="certs"):
        """Load the mTLS material and open a secure channel to the orchestrator.

        Args:
            node_id: Identifier this node reports to the orchestrator.
            server_address: ``host:port`` of the orchestrator's gRPC endpoint.
            certs_dir: Directory containing ``client.key``, ``client.crt`` and
                ``ca.crt``.

        Raises:
            SystemExit: With status 1 if any certificate file is missing.
        """
        self.node_id = node_id

        print("[🔐] Loading mTLS certificates...")
        try:
            with open(os.path.join(certs_dir, "client.key"), "rb") as f:
                private_key = f.read()
            with open(os.path.join(certs_dir, "client.crt"), "rb") as f:
                certificate_chain = f.read()
            with open(os.path.join(certs_dir, "ca.crt"), "rb") as f:
                root_certificates = f.read()
        except FileNotFoundError as e:
            print(f"[!] Error: Certificates not found. Ensure generate_certs.sh was run. | {e}")
            # Equivalent to sys.exit(1), without depending on `sys` being in scope.
            raise SystemExit(1) from e

        # Client key/cert authenticate this node; the CA cert authenticates the server.
        credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificates,
            private_key=private_key,
            certificate_chain=certificate_chain,
        )
        self.channel = grpc.secure_channel(server_address, credentials)
        self.stub = agent_pb2_grpc.AgentOrchestratorStub(self.channel)
        print(f"[*] Agent Node {self.node_id} initialized with secure mTLS channel.")

    def create_registration_token(self, ttl_minutes=10, workspace_id="ws-production-001"):
        """Return an HS256 JWT, signed with ``SECRET_KEY``, identifying this node.

        Args:
            ttl_minutes: Lifetime of the token (``exp`` - ``iat``).
            workspace_id: Value placed in the ``workspace_id`` claim.

        Returns:
            The encoded token as returned by ``jwt.encode``.
        """
        # One timezone-aware timestamp so iat and exp are consistent.
        now = datetime.datetime.now(datetime.timezone.utc)
        payload = {
            # Subject must match the node_id sent in the registration request.
            "sub": self.node_id,
            "workspace_id": workspace_id,
            "iat": now,
            "exp": now + datetime.timedelta(minutes=ttl_minutes),
        }
        return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

    def verify_sandbox_policy(self, command_str, policy=None):
        """Check whether ``command_str`` may run under the sandbox policy.

        Args:
            command_str: Raw command line from the orchestrator. ``None`` or a
                non-string value is treated as an empty command.
            policy: Mapping with the same keys as ``SANDBOX_POLICY``. Defaults
                to the module-level ``SANDBOX_POLICY``.

        Returns:
            ``(allowed, message)``. When allowed, ``message`` is ``"OK"`` or
            ``"SENSITIVE_CONSENT_REQUIRED"``; otherwise it explains the rejection.

        NOTE(review): in PERMISSIVE mode the deny list remains best-effort.
        Wrapper programs (``env``, ``xargs``, ``python -c``...) can still start
        a denied command.
        """
        if policy is None:
            policy = SANDBOX_POLICY

        if not isinstance(command_str, str) or not command_str.strip():
            return False, "Empty command"

        # 1. Reject shell chaining / substitution before looking at any words.
        for seq in self._SHELL_CONTROL_SEQUENCES:
            if seq in command_str:
                return False, f"Shell control syntax {seq!r} is not allowed."

        # 2. Tokenize like a shell would, so quoting ("sudo") cannot hide the command.
        try:
            parts = shlex.split(command_str)
        except ValueError as e:
            return False, f"Unparseable command: {e}"
        if not parts or not parts[0]:
            return False, "Empty command"

        raw_cmd = parts[0]
        if "=" in raw_cmd:
            # `VAR=x cmd` would make the real command the second word.
            return False, f"Leading environment assignment '{raw_cmd}' is not allowed."
        # Bare program name, so /bin/bash or ./sudo still match the lists.
        program = os.path.basename(raw_cmd) or raw_cmd

        # 3. ALWAYS block the denied list (strictly forbidden in all modes).
        denied = policy["DENIED_COMMANDS"]
        if raw_cmd in denied or program in denied:
            return False, f"Command '{raw_cmd}' is strictly FORBIDDEN."

        # 4. Path guard (simple string check for escaping ..).
        if ".." in command_str:
            return False, "Path traversal attempt detected (.. disallowed)."

        # 5. STRICT mode: exact match on the token as written, so an arbitrary
        #    path ending in a whitelisted name (/tmp/x/python) is NOT accepted.
        sensitive = policy["SENSITIVE_COMMANDS"]
        if policy["MODE"] == "STRICT":
            if raw_cmd not in policy["ALLOWED_COMMANDS"] and raw_cmd not in sensitive:
                return False, f"STRICT MODE: Command '{raw_cmd}' is NOT whitelisted."

        # 6. Sensitive / consent check (applied to both modes).
        if raw_cmd in sensitive or program in sensitive:
            return True, "SENSITIVE_CONSENT_REQUIRED"

        return True, "OK"

    # ... (rest of methods)

def _task_error(task, message):
    """Build a NodeMessage carrying an ERROR TaskResponse for ``task``."""
    return agent_pb2.NodeMessage(
        task_response=agent_pb2.TaskResponse(
            task_id=task.task_id,
            status=agent_pb2.TaskResponse.ERROR,
            stderr=message,
            trace_id=task.trace_id,
        )
    )


def _handle_task(node, task):
    """Authenticate, sandbox-check and execute one task request.

    Returns the NodeMessage to send back, or ``None`` when the task is dropped
    because its HMAC signature does not match (no reply is sent in that case).
    Malformed payloads, sandbox violations and commands that cannot be
    launched produce an ERROR TaskResponse instead of crashing the client.
    """
    print(f"[*] Task Received: {task.task_id}. Verifying signature...")

    # HMAC-SHA256 of the raw JSON payload keyed with the shared secret.
    # Compared as bytes so a non-ASCII signature is a mismatch, not a TypeError.
    expected_sig = hmac.new(SECRET_KEY.encode(), task.payload_json.encode(), hashlib.sha256).hexdigest()
    if not hmac.compare_digest(task.signature.encode(), expected_sig.encode()):
        print(f"    [FAIL] Invalid signature for Task {task.task_id}! REJECTING.")
        return None

    print("    [OK] Signature verified. Checking sandbox policy...")

    try:
        payload = json.loads(task.payload_json)
    except ValueError as e:
        print(f"    [⛔] Malformed task payload: {e}")
        return _task_error(task, f"INVALID_PAYLOAD: {e}")

    cmd = payload.get("command") if isinstance(payload, dict) else None
    if not isinstance(cmd, str):
        # Missing / non-string command: the sandbox rejects "" as an empty command.
        cmd = ""

    # --- Sandbox Enforcement ---
    allowed, status_msg = node.verify_sandbox_policy(cmd)
    if not allowed:
        print(f"    [⛔] Sandbox Violation: {status_msg}")
        return _task_error(task, f"SANDBOX_VIOLATION: {status_msg}")

    if status_msg == "SENSITIVE_CONSENT_REQUIRED":
        # In production: Wait for UI prompt. In POC: Log and proceed with a warning tag.
        print(f"    [⚠️] Sensitive Command Encountered: {cmd}. Automated approval assumed in POC.")
    # -------------------------------

    # Execute WITHOUT a shell: the sandbox vets only the first word, so a shell
    # would let `allowed_cmd; anything` run arbitrary extra commands.
    try:
        argv = shlex.split(cmd)
    except ValueError as e:
        return _task_error(task, f"SANDBOX_VIOLATION: Unparseable command: {e}")
    if not argv:
        return _task_error(task, "SANDBOX_VIOLATION: Empty command")

    print(f"    [OK] Execution starts: {cmd}")
    started = time.monotonic()
    try:
        res = subprocess.run(
            argv,
            capture_output=True,
            text=True,
            errors="replace",  # undecodable output must not crash the loop
            cwd=SANDBOX_POLICY["WORKING_DIR"],
        )
    except OSError as e:
        # e.g. program not found / not executable
        print(f"    [⛔] Execution failed: {e}")
        return _task_error(task, f"EXECUTION_FAILED: {e}")
    duration_ms = int((time.monotonic() - started) * 1000)

    return agent_pb2.NodeMessage(
        task_response=agent_pb2.TaskResponse(
            task_id=task.task_id,
            status=agent_pb2.TaskResponse.SUCCESS,
            stdout=res.stdout,
            stderr=res.stderr,
            duration_ms=duration_ms,
            trace_id=task.trace_id,
        )
    )


def main():
    """Register with the orchestrator, then serve task requests until the stream ends.

    Exits with status 1 if registration is rejected or the RPC fails.
    """
    # We'll use a queue-based generator for better concurrency support
    import queue
    import sys

    msg_queue = queue.Queue()
    node = AgentNode()

    # 1. Registration (Pre-handshake credentials with JWT)
    token = node.create_registration_token()
    msg_queue.put(
        agent_pb2.NodeMessage(
            registration=agent_pb2.RegistrationRequest(
                node_id=node.node_id,
                version="1.2.0",
                platform=platform.system() + "-" + platform.machine(),
                capabilities={"shell": True, "browser": False, "secure": True},
                auth_token=token,
            )
        )
    )

    def heartbeat_thread():
        # Placeholder metrics: CPU usage and task count are hard-coded, not measured.
        while True:
            time.sleep(10)
            msg_queue.put(
                agent_pb2.NodeMessage(
                    heartbeat=agent_pb2.Heartbeat(
                        node_id=node.node_id,
                        cpu_usage_percent=1.2,
                        active_task_count=0,
                    )
                )
            )

    threading.Thread(target=heartbeat_thread, daemon=True).start()

    def generator():
        # Outbound request stream: blocks until the next queued message.
        while True:
            yield msg_queue.get()

    try:
        for response in node.stub.Connect(generator()):
            payload_type = response.WhichOneof('payload')
            if payload_type == 'registration_ack':
                ack = response.registration_ack
                if ack.success:
                    print(f"[*] Registered successfully. Session: {ack.session_id}")
                else:
                    print(f"[!] Registration REJECTED: {ack.error_message}")
                    sys.exit(1)
            elif payload_type == 'task_request':
                reply = _handle_task(node, response.task_request)
                if reply is not None:
                    msg_queue.put(reply)
    except grpc.RpcError as e:
        print(f"[!] RPC Error: {e.code()} | {e.details()}")
        if e.code() == grpc.StatusCode.UNAVAILABLE:
            print("    Is the server running and reachable?")
        elif e.code() == grpc.StatusCode.UNAUTHENTICATED:
            print("    Authentication failed. Check certificates.")
        # Non-zero exit so a supervisor can detect the failure.
        sys.exit(1)


if __name__ == '__main__':
    main()
