# Source: cortex-hub / agent-node / src / agent_node / node.py
# (Stripped non-code viewer chrome — "Newer"/"Older" navigation links and
# revision metadata — which is not Python and would break the parser.)
import threading
import queue
import time
import os
import hashlib
import logging
import json
import zlib
import shutil
import socket
import traceback
from concurrent.futures import ThreadPoolExecutor

try:
    import psutil
except ImportError:
    psutil = None

from protos import agent_pb2, agent_pb2_grpc
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature, verify_server_message_signature
from agent_node.utils.network import get_secure_stub
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.core.regex_patterns import ANSI_ESCAPE

logger = logging.getLogger(__name__)

class AgentNode:
    """
    Agent Core: Orchestrates local skills and maintains gRPC connectivity.
    Refactored for structural clarity and modular message handling.
    """
    def __init__(self):
        """Initialize subsystems, the outbound message queue, and gRPC connectivity.

        NOTE: ordering matters — `_refresh_stub()` is called only after
        `_stop_event` exists, because the connectivity watcher closure it
        installs reads `self._stop_event`.
        """
        self.node_id = config.NODE_ID
        self.sandbox = SandboxEngine()
        self.sync_mgr = NodeSyncManager()
        self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
        self.watcher = WorkspaceWatcher(self._on_sync_delta)
        # Outbound frames for the bi-directional TaskStream; bounded to cap memory.
        self.task_queue = queue.Queue(maxsize=250)
        self.stub = None
        self.channel = None
        self._stop_event = threading.Event()
        self._refresh_stub()

        # Background pool for disk I/O so file-sync work never blocks the gRPC reader.
        self.io_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="NodeIO")
        # Caps in-flight chunk writes queued on io_executor (released in _async_write_chunk).
        self.io_semaphore = threading.Semaphore(50)
        # Per-(session, path) locks so chunks of the same file are written serially.
        self.write_locks = {}
        self.lock_map_mutex = threading.Lock()

    def _refresh_stub(self):
        """Force refreshes the gRPC channel and stub.

        Best-effort closes any existing channel, then creates a fresh secure
        stub/channel pair and re-attaches the connectivity state watcher.
        """
        if self.channel:
            try:
                self.channel.close()
            except Exception:
                # Fixed: was a bare `except:` (caught SystemExit/KeyboardInterrupt).
                # A dead channel must never block reconnection.
                pass
        self.stub, self.channel = get_secure_stub()
        self._setup_connectivity_watcher()

    def _setup_connectivity_watcher(self):
        """Monitors gRPC channel state."""
        import grpc
        self._last_grpc_state = None
        def _on_state_change(state):
            if not self._stop_event.is_set() and state != self._last_grpc_state:
                print(f"[*] [gRPC-State] {state}", flush=True)
                self._last_grpc_state = state
        self.channel.subscribe(_on_state_change, try_to_connect=True)

    def sync_configuration(self):
        """Handshake with the Orchestrator to sync policy and metadata.

        Blocks, retrying every 5 seconds, until the Orchestrator accepts the
        registration and the returned policy has been applied locally.
        """
        while True:
            config.reload()
            self.node_id = config.NODE_ID
            if not self.stub:
                self._refresh_stub()

            print(f"[*] Handshake with Orchestrator: {self.node_id}")
            capabilities = self._collect_capabilities()
            # Protobuf map fields are string->string; booleans become "true"/"false".
            stringified = {
                key: (str(val).lower() if isinstance(val, bool) else str(val))
                for key, val in capabilities.items()
            }
            request = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=config.AUTH_TOKEN,
                node_description=config.NODE_DESC,
                capabilities=stringified,
            )

            try:
                reply = self.stub.SyncConfiguration(request, timeout=10)
                if reply.success:
                    self.sandbox.sync(reply.policy)
                    self._apply_skill_config(reply.policy.skill_config_json)
                    print("[OK] Handshake successful. Policy Synced.")
                    return
                print(f"[Error] Rejection: {reply.error_message}. Retrying in 5s...")
                time.sleep(5)
            except Exception as exc:
                print(f"[Error] Connection Fail: {str(exc)}. Retrying in 5s...")
                time.sleep(5)

    def _apply_skill_config(self, config_json):
        """Applies dynamic skill configurations from the server."""
        if not config_json: return
        try:
            cfg = json.loads(config_json)
            for skill in self.skills.skills.values():
                if hasattr(skill, "apply_config"): skill.apply_config(cfg)
        except Exception as e:
            logger.error(f"Error applying skill config: {e}")

    def _collect_capabilities(self) -> dict:
        """Collects hardware and OS metadata.

        Returns:
            dict: Platform capability map, augmented with the node's primary
            outbound IPv4 address under ``local_ip`` ("unknown" on failure).
        """
        from agent_node.utils.platform_metrics import get_platform_metrics
        caps = get_platform_metrics().collect_capabilities()
        try:
            # Connecting a UDP socket sends no packets; it only asks the
            # kernel which local interface would route toward this address.
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                s.settimeout(0)
                s.connect(('10.254.254.254', 1))
                caps["local_ip"] = s.getsockname()[0]
            finally:
                # Fixed: the socket leaked when connect/getsockname raised.
                s.close()
        except Exception:
            # Fixed: was a bare `except:`.
            caps["local_ip"] = "unknown"
        return caps

    def start_health_reporting(self):
        """Launches the background health reporting stream.

        Spawns a daemon thread that opens a client-streaming ReportHealth RPC,
        yielding one Heartbeat roughly every HEALTH_REPORT_INTERVAL seconds.
        On any stream failure the RPC is rebuilt after a 5s backoff. When
        psutil is unavailable, CPU/memory fields are reported as zero.
        """
        from agent_node.utils.platform_metrics import get_platform_metrics
        metrics_tool = get_platform_metrics()

        def _report():
            # Outer loop: re-establish the streaming RPC after any failure.
            while not self._stop_event.is_set():
                try:
                    def _gen():
                        # Heartbeat producer: runs until shutdown is signalled.
                        while not self._stop_event.is_set():
                            ids = self.skills.get_active_ids()
                            vmem = psutil.virtual_memory() if psutil else None
                            yield agent_pb2.Heartbeat(
                                node_id=self.node_id,
                                cpu_usage_percent=psutil.cpu_percent() if psutil else 0,
                                memory_usage_percent=vmem.percent if vmem else 0,
                                active_worker_count=len(ids),
                                max_worker_capacity=config.MAX_SKILL_WORKERS,
                                running_task_ids=ids,
                                cpu_count=psutil.cpu_count() if psutil else 0,
                                memory_used_gb=vmem.used/(1024**3) if vmem else 0,
                                memory_total_gb=vmem.total/(1024**3) if vmem else 0,
                                load_avg=metrics_tool.get_load_avg()
                            )
                            # The -1.0 roughly compensates for collection overhead.
                            time.sleep(max(0, config.HEALTH_REPORT_INTERVAL - 1.0))

                    # Each server ack feeds the watchdog so stream hangs are detected.
                    for _ in self.stub.ReportHealth(_gen()): watchdog.tick()
                except Exception as e:
                    # NOTE(review): failures are silently retried after 5s —
                    # consider logging `e` for diagnosability.
                    time.sleep(5)

        threading.Thread(target=_report, daemon=True, name="HealthReporter").start()

    def run_task_stream(self):
        """Main bi-directional task stream with auto-reconnection.

        Runs forever: opens the TaskStream RPC, pumps queued outbound
        messages into it, and dispatches inbound server messages. On any
        stream error the stub is rebuilt and the stream reopened after 5s.
        """
        while True:
            try:
                def _gen():
                    # First frame announces this node to the server.
                    yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))
                    last_heartbeat = time.time()
                    while not self._stop_event.is_set():
                        # Use a small timeout to ensure we check the heartbeat timer regardless of traffic
                        try:
                            msg = self.task_queue.get(timeout=1.0)
                            try:
                                yield msg
                            except Exception as ye:
                                # Yield fails only if the stream died; exit the
                                # generator so the outer loop reconnects.
                                # NOTE(review): the dequeued msg is dropped here.
                                print(f"[!] Critical Error yielding TaskMessage: {ye}")
                                break
                        except queue.Empty:
                            pass

                        # Absolute heartbeat check (every 10s)
                        if time.time() - last_heartbeat >= 10.0:
                            try:
                                yield agent_pb2.ClientTaskMessage(
                                    skill_event=agent_pb2.SkillEvent(keep_alive=True)
                                )
                                last_heartbeat = time.time()
                            except Exception as ye:
                                print(f"[!] Critical Error yielding KeepAlive: {ye}")
                                break

                responses = self.stub.TaskStream(_gen())
                print(f"[*] Task stream connected ({self.node_id}).")
                for msg in responses:
                    watchdog.tick()  # proves the inbound side is still alive
                    self._process_server_message(msg)
            except Exception as e:
                print(f"[Error] Task stream error: {e}")
                self._refresh_stub()
                time.sleep(5)

    def _process_server_message(self, msg):
        """Routes inbound server messages to their respective handlers.

        Signature failures are currently logged but NOT enforced (DEBUG);
        file-sync work is offloaded so disk I/O never blocks the gRPC reader.
        """
        # Fixed: pre-bind `kind` so the except clause cannot raise NameError
        # when WhichOneof() itself fails (which previously masked the error).
        kind = None
        try:
            kind = msg.WhichOneof('payload')
            if not verify_server_message_signature(msg):
                print(f"[!] Signature mismatch for {kind}. Proceeding anyway (DEBUG).")

            if kind == 'task_request': self._handle_task(msg.task_request)
            elif kind == 'task_cancel': self._handle_cancel(msg.task_cancel)
            elif kind == 'work_pool_update': self._handle_work_pool(msg.work_pool_update)
            elif kind == 'file_sync':
                # M6: Offload ALL file sync processing to executor to avoid blocking gRPC stream
                self.io_executor.submit(self._handle_file_sync, msg.file_sync)
            elif kind == 'policy_update':
                self.sandbox.sync(msg.policy_update)
                self._apply_skill_config(msg.policy_update.skill_config_json)
        except Exception as e:
            print(f"[!] Error processing server message '{kind}': {e}")
            traceback.print_exc()

    def _handle_cancel(self, cancel_req):
        """Cancels an active task or an entire session's background tasks."""
        if cancel_req.session_id and not cancel_req.task_id:
            logger.info(f"[*] Cancelling all tasks for session: {cancel_req.session_id}")
            self.skills.cancel_session(cancel_req.session_id)
            return

        if self.skills.cancel(cancel_req.task_id):
            self._send_response(cancel_req.task_id, None, agent_pb2.TaskResponse.CANCELLED)

    def _handle_work_pool(self, update):
        """Claims tasks from the global work pool with randomized backoff.

        Fixed: capacity is now re-checked per task — previously it was
        checked once up front and then EVERY advertised task was claimed,
        allowing the node to over-claim far past MAX_SKILL_WORKERS. Also
        hoisted `import random` out of the loop.
        """
        import random
        for tid in update.available_task_ids:
            # Stop claiming as soon as we are at capacity.
            if len(self.skills.get_active_ids()) >= config.MAX_SKILL_WORKERS:
                break
            # Random jitter spreads claims across nodes to avoid stampedes.
            time.sleep(random.uniform(0.1, 0.5))
            self.task_queue.put(agent_pb2.ClientTaskMessage(
                task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
            ))

    def _handle_file_sync(self, fs):
        """Dispatches file sync messages to specialized sub-handlers."""
        sid = fs.session_id
        if fs.HasField("manifest"): self._on_sync_manifest(sid, fs.manifest)
        elif fs.HasField("file_data"): self._on_sync_data(sid, fs.file_data)
        elif fs.HasField("control"): self._on_sync_control(sid, fs.control, fs.task_id)

    def _on_sync_manifest(self, sid, manifest):
        """Reconciles local state with a remote manifest and reports drift back."""
        drift = self.sync_mgr.handle_manifest(
            sid,
            manifest,
            on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p),
        )
        if drift:
            status = agent_pb2.SyncStatus(
                code=agent_pb2.SyncStatus.RECONCILE_REQUIRED,
                message=f"Drift in {len(drift)} files",
                reconcile_paths=drift,
            )
        else:
            status = agent_pb2.SyncStatus(
                code=agent_pb2.SyncStatus.OK,
                message="Synchronized",
                reconcile_paths=drift,
            )
        reply = agent_pb2.FileSyncMessage(session_id=sid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=reply))

    def _on_sync_data(self, sid, file_data):
        """Offloads disk I/O to a background worker pool."""
        self.io_semaphore.acquire()
        try: self.io_executor.submit(self._async_write_chunk, sid, file_data)
        except: self.io_semaphore.release()

    def _on_sync_control(self, sid, ctrl, task_id):
        """Handles sync control actions like watching, locking, or directory listing.

        Dispatches on `ctrl.action`. LIST/READ/WRITE/DELETE carry a task_id
        so replies can be correlated with the originating request.
        """
        action = ctrl.action
        if action == agent_pb2.SyncControl.START_WATCHING:
            # Relative watch paths are anchored at the session directory.
            self.watcher.start_watching(sid, ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path))
        elif action == agent_pb2.SyncControl.STOP_WATCHING: self.watcher.stop_watching(sid)
        elif action == agent_pb2.SyncControl.LOCK: self.watcher.set_lock(sid, True)
        elif action == agent_pb2.SyncControl.UNLOCK: self.watcher.set_lock(sid, False)
        elif action in (agent_pb2.SyncControl.REFRESH_MANIFEST, agent_pb2.SyncControl.RESYNC):
            # Specific paths are re-pushed individually (off-thread); with no
            # paths given, a full manifest is regenerated synchronously.
            if ctrl.request_paths:
                for p in ctrl.request_paths: self.io_executor.submit(self._push_file, sid, p)
            else: self._push_full_manifest(sid, ctrl.path)
        elif action == agent_pb2.SyncControl.PURGE:
            # Stop watching first so the purge isn't echoed back as deltas.
            self.watcher.stop_watching(sid)
            self.sync_mgr.purge(sid)
        elif action == agent_pb2.SyncControl.LIST: self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True)
        elif action == agent_pb2.SyncControl.READ: self._push_file(sid, ctrl.path, task_id=task_id)
        elif action == agent_pb2.SyncControl.WRITE: self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id)
        elif action == agent_pb2.SyncControl.DELETE: self._handle_fs_delete(sid, ctrl.path, task_id=task_id)

    def _get_base_dir(self, session_id, create=False):
        """Resolves the session's effective root directory."""
        if session_id == "__fs_explorer__": return config.FS_ROOT
        watched = self.watcher.get_watch_path(session_id)
        return watched if watched else self.sync_mgr.get_session_dir(session_id, create=create)

    def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False):
        """Generates and pushes a local file manifest to the server.

        Args:
            session_id: Sync session id (or "__fs_explorer__" for the global root).
            rel_path: Directory to enumerate, relative to the session base.
            task_id: Optional correlation id echoed back in the reply.
            shallow: List one directory level without hashes instead of a full walk.
        """
        base_dir = self._get_base_dir(session_id, create=True)
        safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path
        watch_path = os.path.normpath(os.path.join(base_dir, safe_rel))

        if not os.path.exists(watch_path):
            # Restructured (was a convoluted double-if): explorer mode must not
            # create directories on the node; sessions lazily materialize theirs.
            if session_id == "__fs_explorer__":
                self._send_sync_error(session_id, task_id, f"Path {rel_path} not found")
                return
            os.makedirs(watch_path, exist_ok=True)

        files = []
        try:
            if shallow:
                with os.scandir(watch_path) as it:
                    for entry in it:
                        if entry.name == ".cortex_sync": continue  # internal sync metadata
                        # For symlinks, classify by the link target instead.
                        is_dir = entry.is_dir() if not entry.is_symlink() else os.path.isdir(entry.path)
                        item_rel = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/")
                        files.append(agent_pb2.FileInfo(path=item_rel, size=entry.stat().st_size if not is_dir else 0, hash="", is_dir=is_dir))
            else:
                for root, dirs, filenames in os.walk(watch_path):
                    for name in filenames:
                        abs_p = os.path.join(root, name)
                        h = self.sync_mgr.get_file_hash(abs_p)
                        # Unhashable files (vanished/unreadable) are skipped entirely.
                        if h: files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=os.path.getsize(abs_p), hash=h, is_dir=False))
                    for d in dirs:
                        files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True))
        except Exception as e:
            return self._send_sync_error(session_id, task_id, str(e))

        self._send_manifest_chunks(session_id, task_id, rel_path, files)

    def _send_manifest_chunks(self, sid, tid, root, files):
        """Splits large manifests into 1000-entry chunks for gRPC streaming.

        An empty listing still emits a single final (empty) manifest so the
        server learns the directory exists but has no entries.
        """
        CHUNK = 1000
        if not files:
            empty = agent_pb2.DirectoryManifest(root_path=root, is_final=True)
            self.task_queue.put(agent_pb2.ClientTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=empty)))
            return
        total = len(files)
        for index, start in enumerate(range(0, total, CHUNK)):
            manifest = agent_pb2.DirectoryManifest(
                root_path=root,
                files=files[start:start + CHUNK],
                chunk_index=index,
                is_final=(start + CHUNK >= total),
            )
            self.task_queue.put(agent_pb2.ClientTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=manifest)))

    def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""):
        """Handles single file or directory creation."""
        try:
            base = os.path.normpath(self._get_base_dir(session_id, create=True))
            target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            
            # Security: Only enforce jail if not in the global file explorer mode
            if session_id != "__fs_explorer__":
                if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
                    raise Exception("Path traversal blocked.")
            
            if is_dir: os.makedirs(target, exist_ok=True)
            else:
                os.makedirs(os.path.dirname(target), exist_ok=True)
                with open(target, "wb") as f: f.write(content)
            
            self._send_sync_ok(session_id, task_id, "Resource written")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e: self._send_sync_error(session_id, task_id, str(e))

    def _handle_fs_delete(self, session_id, rel_path, task_id=""):
        """Removes a file or directory from the node."""
        try:
            base = os.path.normpath(self._get_base_dir(session_id))
            target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            
            if session_id != "__fs_explorer__":
                if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
                    raise Exception("Path traversal blocked.")
            
            self.watcher.acknowledge_remote_delete(session_id, rel_path)
            if os.path.isdir(target): shutil.rmtree(target)
            else: os.remove(target)
            
            self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e: self._send_sync_error(session_id, task_id, str(e))

    def _async_write_chunk(self, sid, payload):
        """Worker for segmented file writing with path-level locking.

        Runs on io_executor. A per-(session, path) lock serializes chunks of
        the same file; the io_semaphore slot taken in `_on_sync_data` is
        always released here, even when the write fails.
        """
        path = payload.path
        with self.lock_map_mutex:
            # Lazily create the per-file lock; it is removed again when the
            # final chunk arrives.
            if (sid, path) not in self.write_locks: self.write_locks[(sid, path)] = threading.Lock()
            lock = self.write_locks[(sid, path)]

        with lock:
            try:
                # Suppress watcher events for this path while we write it, so
                # our own writes aren't echoed back as local deltas.
                if payload.chunk_index == 0: self.watcher.suppress_path(sid, path)
                success = self.sync_mgr.write_chunk(sid, payload)
                if payload.is_final:
                    if success and payload.hash: self.watcher.acknowledge_remote_write(sid, path, payload.hash)
                    self.watcher.unsuppress_path(sid, path)
                    with self.lock_map_mutex: self.write_locks.pop((sid, path), None)
                    # NOTE(review): the OK ack is sent even when write_chunk
                    # reported failure — confirm this is intentional.
                    self._send_sync_ok(sid, "", f"File {path} synced")
            finally: self.io_semaphore.release()

    def _push_file(self, session_id, rel_path, task_id=""):
        """Streams a file from node to server using 4MB chunks."""
        base = os.path.normpath(self._get_base_dir(session_id, create=False))
        target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
        
        if session_id != "__fs_explorer__":
            if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): return

        if not os.path.exists(target):
            if task_id: self._send_sync_error(session_id, task_id, "File not found")
            return

        hasher = hashlib.sha256()
        size, chunk_size = os.path.getsize(target), 4 * 1024 * 1024
        try:
            with open(target, "rb") as f:
                idx = 0
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk and idx > 0: break
                    hasher.update(chunk)
                    is_final = f.tell() >= size
                    self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=session_id, task_id=task_id,
                        file_data=agent_pb2.FilePayload(path=rel_path.replace("\\", "/"), chunk=zlib.compress(chunk), chunk_index=idx, is_final=is_final, hash=hasher.hexdigest() if is_final else "", compressed=True))))
                    if is_final: break
                    idx += 1
        except Exception as e: logger.error(f"Push Error: {e}")

    def _handle_task(self, task):
        """Verifies and submits a skill task for execution.

        Signature mismatches are currently tolerated (DEBUG mode); submission
        rejections are reported back to the hub as ERROR responses.
        """
        if not verify_task_signature(task):
            print(f"[Warn] Task signature mismatch for {task.task_id}. Proceeding anyway (DEBUG).")

        accepted, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if accepted:
            return
        rejection = agent_pb2.TaskResponse(
            task_id=task.task_id,
            status=agent_pb2.TaskResponse.ERROR,
            stderr=f"Rejection: {reason}",
        )
        self._send_response(task.task_id, rejection)

    def _on_event(self, event):
        """Forwards a skill event to the gRPC stream.

        Accepts either a pre-wrapped ClientTaskMessage or a bare SkillEvent.
        """
        if isinstance(event, agent_pb2.ClientTaskMessage):
            outbound = event
        else:
            outbound = agent_pb2.ClientTaskMessage(skill_event=event)
        self.task_queue.put(outbound)

    def _on_finish(self, task_id, result, trace_id):
        """Finalizes a task and sends the response back to the Hub."""
        response = agent_pb2.TaskResponse(
            task_id=task_id,
            stdout=result.get("stdout", ""),
            stderr=result.get("stderr", ""),
            status=result.get("status", 0),
            trace_id=trace_id,
        )
        self._send_response(task_id, response)

    def _send_response(self, task_id, response, status_override=None):
        """Helper to queue a standard task response with memory safety caps.

        Args:
            task_id: Task the response belongs to (used when `response` is None).
            response: Prepared TaskResponse, or None to send an empty one.
            status_override: Optional status value to force onto the response.
        """
        res = response or agent_pb2.TaskResponse(task_id=task_id)
        # Fixed: `is not None` — the old truthiness test silently dropped a
        # zero-valued status enum passed as an override.
        if status_override is not None:
            res.status = status_override

        # Redundant safety cap (5MB) to protect gRPC/Queue memory.
        # NOTE: len() counts characters, not bytes — close enough for a cap.
        SAFETY_CAP = 5 * 1024 * 1024
        for field in ['stdout', 'stderr']:
            val = getattr(res, field, "")
            if len(val) > SAFETY_CAP:
                logger.warning(f"Truncating excessive {field} ({len(val):,} bytes) for task {task_id}")
                # Keep head and tail so both the start of the output and the
                # failure trailer survive truncation.
                setattr(res, field, val[:SAFETY_CAP // 2] + "\n... [TRUNCATED DUE TO SIZE] ...\n" + val[-(SAFETY_CAP // 2):])

        self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=res))

    def _send_sync_ok(self, sid, tid, msg):
        """Queues an OK sync-status reply for the given session/task."""
        status = agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg)
        wrapper = agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=wrapper))

    def _send_sync_error(self, sid, tid, msg):
        """Queues an ERROR sync-status reply for the given session/task."""
        status = agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg)
        wrapper = agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=wrapper))

    def _on_sync_delta(self, session_id, payload):
        """Watcher callback: forwards a local change (delta) to the server.

        Accepts a pre-built FileSyncMessage or a raw FilePayload to wrap.
        """
        if isinstance(payload, agent_pb2.FileSyncMessage):
            outbound = payload
        else:
            outbound = agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=outbound))

    def shutdown(self):
        """Gracefully shuts down the node."""
        self._stop_event.set()
        self.skills.shutdown()
        if self.channel: self.channel.close()
        print("[*] Node shutdown complete.")