# cortex-hub / agent-node / agent_node/node.py
import hashlib
import os
import queue
import random
import sys
import threading
import time

import psutil

from protos import agent_pb2, agent_pb2_grpc
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature
from agent_node.utils.network import get_secure_stub
from agent_node.config import NODE_ID, NODE_DESC, AUTH_TOKEN, HEALTH_REPORT_INTERVAL, MAX_SKILL_WORKERS


class AgentNode:
    """The 'Agent Core': Orchestrates Local Skills and Maintains gRPC Connection.

    Responsibilities:
      * Handshake / sandbox-policy sync with the orchestrator (``sync_configuration``).
      * Streaming node metrics for load balancing (``start_health_reporting``).
      * A persistent bi-directional stream (``run_task_stream``) carrying task
        requests/results, work-pool claims, browser events, and file sync.

    All outbound traffic is funnelled through ``self.task_queue``; the stream
    generator in ``run_task_stream`` is the single consumer.
    """

    # 1MB framing size shared by the file-push and file-hashing paths.
    _CHUNK_SIZE = 1024 * 1024

    def __init__(self, node_id=NODE_ID):
        # Stable identity announced to the orchestrator on every stream.
        self.node_id = node_id
        self.sandbox = SandboxEngine()
        self.sync_mgr = NodeSyncManager()
        self.skills = SkillManager(max_workers=MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
        # Watcher pushes local file deltas back to the server via _on_sync_delta.
        self.watcher = WorkspaceWatcher(self._on_sync_delta)
        # Outbound ClientTaskMessage queue; producers are skill callbacks,
        # the workspace watcher, and the file-sync handlers.
        self.task_queue = queue.Queue()
        self.stub = get_secure_stub()

    def sync_configuration(self):
        """Initial handshake to retrieve policy and metadata.

        Exits the process (status 1) on rejection or transport failure: the
        node must not run without a synced sandbox policy.
        """
        print(f"[*] Handshake with Orchestrator: {self.node_id}")
        reg_req = agent_pb2.RegistrationRequest(
            node_id=self.node_id,
            auth_token=AUTH_TOKEN,
            node_description=NODE_DESC,
            capabilities={"shell": "v1", "browser": "playwright-sync-bridge"}
        )

        # Keep the try narrow: only the RPC itself is a connection failure.
        try:
            res = self.stub.SyncConfiguration(reg_req)
        except Exception as e:
            print(f"[!] Connection Fail: {e}")
            sys.exit(1)

        if res.success:
            self.sandbox.sync(res.policy)
            print("[OK] Sandbox Policy Synced.")
        else:
            print(f"[!] Rejection: {res.error_message}")
            sys.exit(1)

    def start_health_reporting(self):
        """Streaming node metrics to the orchestrator for load balancing.

        Spawns a daemon thread feeding an infinite Heartbeat generator into
        the ReportHealth streaming RPC; the thread dies with the process.
        """
        def _gen():
            while True:
                ids = self.skills.get_active_ids()
                # NOTE(review): with interval=None the very first psutil sample
                # returns 0.0; later calls report usage since the prior call —
                # acceptable for a periodic heartbeat.
                cpu = psutil.cpu_percent(interval=None)
                mem = psutil.virtual_memory().percent
                yield agent_pb2.Heartbeat(
                    node_id=self.node_id,
                    cpu_usage_percent=cpu,
                    memory_usage_percent=mem,
                    active_worker_count=len(ids),
                    max_worker_capacity=MAX_SKILL_WORKERS,
                    running_task_ids=ids
                )
                time.sleep(HEALTH_REPORT_INTERVAL)

        # Non-blocking thread for health heartbeat; list() drains any server
        # acks so the stream keeps flowing.
        threading.Thread(
            target=lambda: list(self.stub.ReportHealth(_gen())),
            daemon=True, name=f"Health-{self.node_id}"
        ).start()

    def run_task_stream(self):
        """Main Persistent Bi-directional Stream for Task Management.

        Blocks on the inbound half of the stream; outbound messages are pulled
        from ``self.task_queue`` by the request generator.
        """
        def _gen():
            # Initial announcement for routing identity.
            yield agent_pb2.ClientTaskMessage(
                announce=agent_pb2.NodeAnnounce(node_id=self.node_id)
            )
            while True:
                # Blocks until a producer enqueues an outbound message.
                yield self.task_queue.get()

        responses = self.stub.TaskStream(_gen())
        print(f"[*] Task Stream Online: {self.node_id}", flush=True)

        try:
            for msg in responses:
                kind = msg.WhichOneof('payload')
                print(f"    [📥] Received from Stream: {kind}", flush=True)
                self._process_server_message(msg)
        except Exception as e:
            # Stream teardown (server restart, network drop) lands here.
            print(f"[!] Task Stream Failure: {e}", flush=True)

    def _process_server_message(self, msg):
        """Dispatches one inbound server message by its oneof payload kind."""
        kind = msg.WhichOneof('payload')
        print(f"[*] Inbound: {kind}", flush=True)

        if kind == 'task_request':
            self._handle_task(msg.task_request)

        elif kind == 'task_cancel':
            if self.skills.cancel(msg.task_cancel.task_id):
                self._send_response(msg.task_cancel.task_id, None, agent_pb2.TaskResponse.CANCELLED)

        elif kind == 'work_pool_update':
            # Claim idle tasks from the global pool with randomized jitter to
            # prevent a thundering herd where every node claims the same task
            # at the exact same millisecond.
            if len(self.skills.get_active_ids()) < MAX_SKILL_WORKERS:
                for tid in msg.work_pool_update.available_task_ids:
                    time.sleep(random.uniform(0.1, 0.5))
                    self.task_queue.put(agent_pb2.ClientTaskMessage(
                        task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
                    ))

        elif kind == 'claim_status':
            status = "GRANTED" if msg.claim_status.granted else "DENIED"
            print(f"    [📦] Claim {msg.claim_status.task_id}: {status} ({msg.claim_status.reason})", flush=True)

        elif kind == 'file_sync':
            self._handle_file_sync(msg.file_sync)

    def _on_sync_delta(self, session_id, file_payload):
        """Callback from watcher to push local changes to server."""
        self.task_queue.put(agent_pb2.ClientTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                file_data=file_payload
            )
        ))

    def _handle_file_sync(self, fs):
        """Processes inbound file synchronization messages from the Orchestrator.

        Handles three sub-payloads:
          * manifest  — diff against local state; ack OK or request reconciliation.
          * file_data — write a chunk; ack OK/ERROR once the final chunk lands.
          * control   — watcher lifecycle (start/stop/lock/unlock) plus manifest
                        refresh / resync requests.
        """
        sid = fs.session_id
        if fs.HasField("manifest"):
            needs_update = self.sync_mgr.handle_manifest(sid, fs.manifest)
            if needs_update:
                print(f"    [📁⚠️] Drift Detected for {sid}: {len(needs_update)} files need sync")
                self.task_queue.put(agent_pb2.ClientTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=sid,
                        status=agent_pb2.SyncStatus(
                            code=agent_pb2.SyncStatus.RECONCILE_REQUIRED,
                            message=f"Drift detected in {len(needs_update)} files",
                            reconcile_paths=needs_update
                        )
                    )
                ))
            else:
                self.task_queue.put(agent_pb2.ClientTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=sid,
                        status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message="Synchronized")
                    )
                ))
        elif fs.HasField("file_data"):
            success = self.sync_mgr.write_chunk(sid, fs.file_data)
            if fs.file_data.is_final:
                print(f"    [📁] File Received: {fs.file_data.path} (Verified: {success})")
                status = agent_pb2.SyncStatus.OK if success else agent_pb2.SyncStatus.ERROR
                self.task_queue.put(agent_pb2.ClientTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=sid,
                        status=agent_pb2.SyncStatus(code=status, message=f"File {fs.file_data.path} synced")
                    )
                ))
        elif fs.HasField("control"):
            ctrl = fs.control
            if ctrl.action == agent_pb2.SyncControl.START_WATCHING:
                # Path may be absolute or relative to the session's sync dir.
                watch_path = ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path)
                self.watcher.start_watching(sid, watch_path)
            elif ctrl.action == agent_pb2.SyncControl.STOP_WATCHING:
                self.watcher.stop_watching(sid)
            elif ctrl.action == agent_pb2.SyncControl.LOCK:
                self.watcher.set_lock(sid, True)
            elif ctrl.action == agent_pb2.SyncControl.UNLOCK:
                self.watcher.set_lock(sid, False)
            elif ctrl.action == agent_pb2.SyncControl.REFRESH_MANIFEST:
                if ctrl.request_paths:
                    print(f"    [📁📤] Pushing {len(ctrl.request_paths)} Requested Files for {sid}")
                    for path in ctrl.request_paths:
                        self._push_file(sid, path)
                else:
                    # Node -> Server manifest push.
                    self._push_full_manifest(sid, ctrl.path)
            elif ctrl.action == agent_pb2.SyncControl.RESYNC:
                # RESYNC is treated as "send me your manifest" so the server
                # can diff against its own view.
                self._push_full_manifest(sid, ctrl.path)

    @staticmethod
    def _hash_file(abs_path):
        """Return the sha256 hex digest of a file, streamed in 1MB chunks.

        Streaming bounds memory use — the previous implementation read whole
        files into memory before hashing.
        """
        digest = hashlib.sha256()
        with open(abs_path, "rb") as f:
            for chunk in iter(lambda: f.read(AgentNode._CHUNK_SIZE), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def _push_full_manifest(self, session_id, rel_path="."):
        """Pushes the current local manifest back to the server.

        Walks the watched directory, hashing every file via the streaming
        helper, and enqueues a single DirectoryManifest message.
        """
        print(f"    [📁📤] Pushing Full Manifest for {session_id}")
        watch_path = rel_path if os.path.isabs(rel_path) else os.path.join(self.sync_mgr.get_session_dir(session_id), rel_path)

        files = []
        for root, _dirs, filenames in os.walk(watch_path):
            for filename in filenames:
                abs_path = os.path.join(root, filename)
                files.append(agent_pb2.FileInfo(
                    path=os.path.relpath(abs_path, watch_path),
                    size=os.path.getsize(abs_path),
                    hash=self._hash_file(abs_path)
                ))

        self.task_queue.put(agent_pb2.ClientTaskMessage(
            file_sync=agent_pb2.FileSyncMessage(
                session_id=session_id,
                manifest=agent_pb2.DirectoryManifest(root_path=rel_path, files=files)
            )
        ))

    def _push_file(self, session_id, rel_path):
        """Pushes a specific file from node to server as 1MB chunks.

        The sha256 digest is computed incrementally while chunking, so the
        file is read exactly once (previously it was fully loaded for the
        hash and then re-read for chunking). Wire messages are unchanged:
        the full-file hash rides only on the final chunk.
        """
        watch_path = self.watcher.get_watch_path(session_id)
        if not watch_path:
            # Fallback to sync dir if watcher not started.
            watch_path = self.sync_mgr.get_session_dir(session_id)

        abs_path = os.path.join(watch_path, rel_path)
        if not os.path.exists(abs_path):
            print(f"    [📁❓] Requested file {rel_path} not found on node")
            return

        digest = hashlib.sha256()
        index = 0
        with open(abs_path, "rb") as f:
            while True:
                chunk = f.read(self._CHUNK_SIZE)
                digest.update(chunk)
                # A short read marks the final chunk. Files that are an exact
                # multiple of the chunk size emit one trailing empty chunk,
                # preserving the original framing.
                is_final = len(chunk) < self._CHUNK_SIZE

                self.task_queue.put(agent_pb2.ClientTaskMessage(
                    file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id,
                        file_data=agent_pb2.FilePayload(
                            path=rel_path,
                            chunk=chunk,
                            chunk_index=index,
                            is_final=is_final,
                            hash=digest.hexdigest() if is_final else ""
                        )
                    )
                ))

                if is_final:
                    break
                index += 1

    def _handle_task(self, task):
        """Validates a task's signature, then submits it to the skill manager."""
        print(f"[*] Task Launch: {task.task_id}", flush=True)
        # 1. Cryptographic Signature Verification — unsigned/forged tasks are
        # dropped silently on this side (server learns via timeout/claim).
        if not verify_task_signature(task):
            print(f"[!] Signature Validation Failed for {task.task_id}", flush=True)
            return

        print(f"[✅] Validated task {task.task_id}", flush=True)

        # 2. Skill Manager Submission (may reject on capacity or policy).
        success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if not success:
            print(f"[!] Execution Rejected: {reason}", flush=True)

    def _on_event(self, event):
        """Live Event Tunneler: Routes browser/skill events into the main stream."""
        self.task_queue.put(agent_pb2.ClientTaskMessage(browser_event=event))

    def _on_finish(self, tid, res, trace):
        """Final Completion Callback: Routes task results back to server.

        ``res`` is the skill result dict; status 1 maps to SUCCESS, anything
        else to ERROR.
        """
        print(f"[*] Completion: {tid}", flush=True)
        status = agent_pb2.TaskResponse.SUCCESS if res['status'] == 1 else agent_pb2.TaskResponse.ERROR

        tr = agent_pb2.TaskResponse(
            task_id=tid, status=status,
            stdout=res.get('stdout', ''),
            stderr=res.get('stderr', ''),
            trace_id=trace,
            browser_result=res.get("browser_result")
        )
        self._send_response(tid, tr)

    def _send_response(self, tid, tr=None, status=None):
        """Utility for placing response messages into the gRPC outbound queue.

        Pass a prebuilt TaskResponse as ``tr``, or just ``tid`` + ``status``
        to send a bare status-only response.
        """
        if tr:
            self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=tr))
        else:
            self.task_queue.put(agent_pb2.ClientTaskMessage(
                task_response=agent_pb2.TaskResponse(task_id=tid, status=status)
            ))

    def stop(self):
        """Gracefully stops all background services and skills."""
        print(f"\n[🛑] Stopping Agent Node: {self.node_id}")
        self.skills.shutdown()
        # Optionally close gRPC channel if we want to be very clean
        # self.channel.close()