Newer
Older
cortex-hub / agent-node / src / agent_node / node.py
@Antigravity AI Antigravity AI 17 days ago 19 KB half done refactoring
import threading
import queue
import time
import os
import hashlib
import logging
import json
import zlib
import shutil
import socket
import traceback
from concurrent.futures import ThreadPoolExecutor

try:
    import psutil
except ImportError:
    psutil = None

from protos import agent_pb2, agent_pb2_grpc
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature, verify_server_message_signature
from agent_node.utils.network import get_secure_stub
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.core.regex_patterns import ANSI_ESCAPE

from mesh_core.transport_grpc import GrpcMeshTransport
from mesh_core.node_engine import MeshNodeCore
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.utils.network import get_secure_stub

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class AgentNode(MeshNodeCore):
    """
    Agent Core: Orchestrates local skills and maintains connectivity.
    Now leverages MeshNodeCore for transport-agnostic orchestration.

    On top of the core engine this class provides:
      * task execution through SkillManager (submit, cancel, result reporting),
      * two-way workspace file sync (manifests, chunked transfers, and remote
        list/read/write/delete requests),
      * periodic health heartbeats.

    Path safety: Hub-supplied paths outside the file-explorer session are
    confined to the session's base directory or to ``config.FS_ROOT``.
    """

    # Session id the Hub uses for the global file explorer. Paths in this
    # session resolve against config.FS_ROOT and are not jailed.
    _FS_EXPLORER_SID = "__fs_explorer__"
    # Max FileInfo entries per DirectoryManifest message.
    _MANIFEST_CHUNK_SIZE = 1000
    # Raw bytes read per FilePayload when streaming a file to the Hub.
    _PUSH_CHUNK_SIZE = 4 * 1024 * 1024
    # Per-field cap (as measured by len()) on TaskResponse stdout/stderr.
    _RESPONSE_SAFETY_CAP = 5 * 1024 * 1024

    def __init__(self):
        # 1. Initialize Transport (gRPC by default for production).
        #    The secure stub factory keeps the existing security logic.
        self.transport = GrpcMeshTransport(config.NODE_ID, get_secure_stub)

        # 2. Initialize Core Engine.
        super().__init__(config.NODE_ID, self.transport)

        # 3. Initialize Agent-Specific Modules.
        self.sandbox = SandboxEngine()
        self.sync_mgr = NodeSyncManager()
        self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
        self.watcher = WorkspaceWatcher(self._on_sync_delta)
        self._stop_event = threading.Event()

        # Incoming file chunks are written off the message thread. The
        # semaphore (acquired in _on_sync_data, released when the write
        # finishes) bounds how many chunks can be queued at once.
        self.io_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="NodeIO")
        self.io_semaphore = threading.Semaphore(50)
        # (session_id, path) -> Lock serializing chunk writes to one file.
        self.write_locks = {}
        self.lock_map_mutex = threading.Lock()

        # 4. Bind Mesh Events.
        self.on_task = self._handle_task
        self.on_cancel = self._handle_cancel
        self.on_policy = self._on_policy_update
        self.on_sync = self._handle_file_sync
        self.on_ready = self._on_mesh_ready
        self.on_disconnect = self._on_mesh_disconnect

    # ------------------------------------------------------------------ mesh

    def _on_mesh_ready(self, msg):
        """Called when the mesh is ready; applies any policy carried by *msg*."""
        print("[Mesh] Connected and Authorized. Policy Synced.")
        # The policy may arrive under either field name (presumably depending
        # on the message type the transport delivers); prefer policy_update.
        if hasattr(msg, 'policy_update'):
            self._on_policy_update(msg.policy_update)
        elif hasattr(msg, 'policy'):
            self._on_policy_update(msg.policy)

    def _on_mesh_disconnect(self):
        """Called when the Hub connection drops; only logs."""
        print("[Mesh] Disconnected from Hub.")

    def _on_policy_update(self, policy):
        """Push a new policy into the sandbox and the skill configuration."""
        self.sandbox.sync(policy)
        # NOTE(review): _apply_skill_config is not defined in this class;
        # presumably inherited from MeshNodeCore — confirm it still exists
        # after the refactor.
        self._apply_skill_config(policy.skill_config_json)

    def sync_configuration(self):
        """ Handshake now handled by Transport/Core. This remains for compatibility if needed. """
        # In the new SDK model, the transport handles the initial SyncConfiguration call
        # or it's wrapped in the connect() flow.
        pass

    # ---------------------------------------------------------------- health

    def _build_heartbeat(self):
        """Snapshot resource usage and running task ids into a Heartbeat.

        Metrics fall back to 0 when psutil is not installed.
        """
        ids = self.skills.get_active_ids()
        vmem = psutil.virtual_memory() if psutil else None
        return agent_pb2.Heartbeat(
            node_id=self.node_id,
            cpu_usage_percent=psutil.cpu_percent() if psutil else 0,
            memory_usage_percent=vmem.percent if vmem else 0,
            active_worker_count=len(ids),
            max_worker_capacity=config.MAX_SKILL_WORKERS,
            running_task_ids=ids,
            cpu_count=psutil.cpu_count() if psutil else 0,
            memory_used_gb=vmem.used / (1024 ** 3) if vmem else 0,
            memory_total_gb=vmem.total / (1024 ** 3) if vmem else 0,
        )

    def start_health_reporting(self):
        """Launch a daemon thread that reports a Heartbeat every
        ``config.HEALTH_REPORT_INTERVAL`` seconds while connected.

        The watchdog is ticked after every report attempt that did not raise.
        The loop exits once ``shutdown()`` sets the stop event.
        """
        def _report():
            while not self._stop_event.is_set():
                if self.transport.is_connected():
                    try:
                        hb = self._build_heartbeat()
                        # Health is sent via a separate stream in gRPC,
                        # so use the transport's specialized method if it exists.
                        if hasattr(self.transport, 'send_health'):
                            self.transport.send_health(hb)
                        watchdog.tick()
                    except Exception as e:
                        logger.error(f"Health report error: {e}")
                # wait() instead of sleep() so shutdown() interrupts the pause.
                self._stop_event.wait(config.HEALTH_REPORT_INTERVAL)

        threading.Thread(target=_report, daemon=True, name="HealthReporter").start()

    def run_task_stream(self):
        """Start the core engine (MeshNodeCore.start) and block until shutdown()."""
        self.start()  # From MeshNodeCore
        # Timed wait keeps the calling thread waking up periodically
        # (interruptible) while returning promptly once the event is set.
        while not self._stop_event.wait(1):
            pass

    # ----------------------------------------------------------------- tasks

    def _handle_cancel(self, cancel_req):
        """Cancel one task, or every background task of a session.

        A request with a session_id but no task_id cancels the whole session
        and sends no response. Otherwise the single task is cancelled and, if
        SkillManager reports success, a CANCELLED response is sent.
        """
        if cancel_req.session_id and not cancel_req.task_id:
            logger.info(f"[*] Cancelling all tasks for session: {cancel_req.session_id}")
            self.skills.cancel_session(cancel_req.session_id)
            return

        if self.skills.cancel(cancel_req.task_id):
            self._send_response(cancel_req.task_id, None, agent_pb2.TaskResponse.CANCELLED)

    def _handle_work_pool(self, update):
        """Claims tasks from the global work pool with randomized backoff.

        Claims are only attempted while this node has spare skill-worker
        capacity (checked once per update). Each claim is preceded by a random
        0.1-0.5 s delay, presumably to de-synchronize competing nodes.
        NOTE(review): not bound in __init__ — confirm MeshNodeCore dispatches it.
        """
        import random  # only needed for claim jitter

        if len(self.skills.get_active_ids()) >= config.MAX_SKILL_WORKERS:
            return
        for tid in update.available_task_ids:
            time.sleep(random.uniform(0.1, 0.5))
            self.send_message(agent_pb2.ClientTaskMessage(
                task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
            ))

    def _handle_task(self, task):
        """Verify *task*'s signature and submit it to the SkillManager.

        A rejected submission is answered immediately with an ERROR response
        whose stderr carries the rejection reason.
        """
        if not verify_task_signature(task):
            # NOTE(review): a signature mismatch is only logged and the task
            # still runs (debug bypass). Forged tasks will execute — enforce
            # rejection before production use.
            print(f"[Warn] Task signature mismatch for {task.task_id}. Proceeding anyway (DEBUG).")

        success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if not success:
            self._send_response(task.task_id, agent_pb2.TaskResponse(task_id=task.task_id, status=agent_pb2.TaskResponse.ERROR, stderr=f"Rejection: {reason}"))

    def _on_event(self, event):
        """Forwards skill events to the gRPC stream, wrapping bare events."""
        self.send_message(event if isinstance(event, agent_pb2.ClientTaskMessage) else agent_pb2.ClientTaskMessage(skill_event=event))

    def _on_finish(self, task_id, result, trace_id):
        """Finalizes a task and sends the response back to the Hub.

        *result* is read as a dict with optional "stdout", "stderr" and
        "status" keys (defaults "", "" and 0).
        """
        self._send_response(task_id, agent_pb2.TaskResponse(task_id=task_id, stdout=result.get("stdout", ""), stderr=result.get("stderr", ""), status=result.get("status", 0), trace_id=trace_id))

    def _send_response(self, task_id, response, status_override=None):
        """Send a TaskResponse for *task_id*, truncating oversized output.

        If *response* is None an empty TaskResponse for *task_id* is created.
        *status_override*, when given (including a zero-valued enum), replaces
        the status. stdout/stderr longer than the safety cap keep their head
        and tail halves around a truncation marker.
        """
        res = response if response is not None else agent_pb2.TaskResponse(task_id=task_id)
        if status_override is not None:
            res.status = status_override

        # Redundant safety cap to protect gRPC/Queue memory.
        cap = self._RESPONSE_SAFETY_CAP
        half = cap // 2
        for field in ('stdout', 'stderr'):
            val = getattr(res, field, "")
            if len(val) > cap:
                logger.warning(f"Truncating excessive {field} ({len(val):,} bytes) for task {task_id}")
                setattr(res, field, val[:half] + "\n... [TRUNCATED DUE TO SIZE] ...\n" + val[-half:])

        self.send_message(agent_pb2.ClientTaskMessage(task_response=res))

    # ------------------------------------------------------------- file sync

    def _handle_file_sync(self, fs):
        """Dispatches file sync messages to specialized sub-handlers."""
        sid = fs.session_id
        if fs.HasField("manifest"):
            self._on_sync_manifest(sid, fs.manifest)
        elif fs.HasField("file_data"):
            self._on_sync_data(sid, fs.file_data)
        elif fs.HasField("control"):
            self._on_sync_control(sid, fs.control, fs.task_id)

    def _on_sync_manifest(self, sid, manifest):
        """Reconcile local state with a remote manifest and report the drift.

        Files purged by the sync manager are acknowledged to the watcher so
        the deletions are not echoed back to the Hub.
        """
        drift = self.sync_mgr.handle_manifest(sid, manifest, on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p))
        status = agent_pb2.SyncStatus(
            code=agent_pb2.SyncStatus.RECONCILE_REQUIRED if drift else agent_pb2.SyncStatus.OK,
            message=f"Drift in {len(drift)} files" if drift else "Synchronized",
            reconcile_paths=drift or [],
        )
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, status=status)))

    def _on_sync_data(self, sid, file_data):
        """Offload a file chunk write to the I/O pool.

        Blocks while the maximum number of chunks is already queued. If the
        chunk cannot be scheduled (e.g. the pool is shut down) the semaphore
        slot is returned and the failure is logged.
        """
        self.io_semaphore.acquire()
        try:
            self.io_executor.submit(self._async_write_chunk, sid, file_data)
        except Exception as e:
            self.io_semaphore.release()
            logger.error(f"Failed to schedule chunk write for {file_data.path} (session {sid}): {e}")

    def _on_sync_control(self, sid, ctrl, task_id):
        """Handles sync control actions like watching, locking, or directory listing."""
        action = ctrl.action
        sc = agent_pb2.SyncControl
        if action == sc.START_WATCHING:
            # Relative paths are resolved against the session's sync directory.
            path = ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path)
            self.watcher.start_watching(sid, path)
        elif action == sc.STOP_WATCHING:
            self.watcher.stop_watching(sid)
        elif action == sc.LOCK:
            self.watcher.set_lock(sid, True)
        elif action == sc.UNLOCK:
            self.watcher.set_lock(sid, False)
        elif action in (sc.REFRESH_MANIFEST, sc.RESYNC):
            if ctrl.request_paths:
                for p in ctrl.request_paths:
                    self.io_executor.submit(self._push_file, sid, p)
            else:
                self._push_full_manifest(sid, ctrl.path)
        elif action == sc.PURGE:
            self.watcher.stop_watching(sid)
            self.sync_mgr.purge(sid)
        elif action == sc.LIST:
            self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True)
        elif action == sc.READ:
            self._push_file(sid, ctrl.path, task_id=task_id)
        elif action == sc.WRITE:
            self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id)
        elif action == sc.DELETE:
            self._handle_fs_delete(sid, ctrl.path, task_id=task_id)
        else:
            logger.warning(f"Unhandled sync control action {action} for session {sid}")

    # ------------------------------------------------------------ path utils

    def _get_base_dir(self, session_id, create=False):
        """Resolve the session's effective root directory.

        The file-explorer session maps to config.FS_ROOT; a watched session
        maps to its watch path; otherwise the sync manager's session directory
        is used (*create* is passed through to get_session_dir).
        """
        if session_id == self._FS_EXPLORER_SID:
            return config.FS_ROOT
        watched = self.watcher.get_watch_path(session_id)
        return watched if watched else self.sync_mgr.get_session_dir(session_id, create=create)

    @staticmethod
    def _is_within(path, root):
        """Return True if *path* is *root* itself or lies beneath it.

        Unlike ``str.startswith`` this respects path-component boundaries, so
        ``/data/s10`` is NOT inside ``/data/s1``. Paths are made absolute and
        normalized first (symlinks are not resolved). A falsy *root* never
        matches.
        """
        if not root:
            return False
        root = os.path.abspath(root)
        try:
            return os.path.commonpath([root, os.path.abspath(path)]) == root
        except ValueError:
            # e.g. paths on different drives (Windows).
            return False

    def _resolve_target(self, session_id, rel_path, create=False):
        """Map a Hub-supplied path to ``(base_dir, target)``, both normalized.

        Outside the file-explorer session leading '/' characters are stripped
        so the path is always interpreted relative to the session root.
        """
        base = os.path.normpath(self._get_base_dir(session_id, create=create))
        if session_id != self._FS_EXPLORER_SID:
            rel_path = rel_path.lstrip("/")
        return base, os.path.normpath(os.path.join(base, rel_path))

    def _is_path_allowed(self, session_id, base, target):
        """Path jail: non-explorer sessions may only touch *target* when it is
        under *base* or under config.FS_ROOT (if configured)."""
        if session_id == self._FS_EXPLORER_SID:
            return True
        return self._is_within(target, base) or self._is_within(target, config.FS_ROOT)

    # ------------------------------------------------------------- manifests

    @staticmethod
    def _scan_shallow(watch_path, base_dir):
        """List the direct children of *watch_path* (no hashes).

        Entries named ".cortex_sync" are skipped. Paths are relative to
        *base_dir* with forward slashes; directories report size 0.
        """
        files = []
        with os.scandir(watch_path) as it:
            for entry in it:
                if entry.name == ".cortex_sync":
                    continue
                # DirEntry.is_dir() follows symlinks by default.
                is_dir = entry.is_dir()
                size = 0
                if not is_dir:
                    try:
                        size = entry.stat().st_size
                    except OSError:
                        # e.g. a dangling symlink: list it rather than fail
                        # the whole directory.
                        size = 0
                files.append(agent_pb2.FileInfo(
                    path=os.path.relpath(entry.path, base_dir).replace("\\", "/"),
                    size=size, hash="", is_dir=is_dir))
        return files

    def _scan_recursive(self, watch_path, base_dir):
        """Walk *watch_path* recursively.

        Files are included only when the sync manager returns a hash for them;
        every directory is included with size 0. Files that disappear between
        hashing and stat are skipped.
        """
        files = []
        for root, dirs, filenames in os.walk(watch_path):
            for name in filenames:
                abs_p = os.path.join(root, name)
                h = self.sync_mgr.get_file_hash(abs_p)
                if not h:
                    continue
                try:
                    size = os.path.getsize(abs_p)
                except OSError:
                    continue
                files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=size, hash=h, is_dir=False))
            for d in dirs:
                files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True))
        return files

    def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False):
        """Build a manifest for *rel_path* and stream it to the Hub.

        shallow=True lists only immediate children (sizes, no hashes);
        otherwise the tree is walked and files are hashed. For regular
        sessions a missing directory is created; in the file-explorer session
        a missing path is reported as an error. Paths escaping the session
        jail are rejected.
        """
        base_dir, watch_path = self._resolve_target(session_id, rel_path, create=True)
        if not self._is_path_allowed(session_id, base_dir, watch_path):
            self._send_sync_error(session_id, task_id, "Path traversal blocked.")
            return

        if session_id == self._FS_EXPLORER_SID and not os.path.exists(watch_path):
            self._send_sync_error(session_id, task_id, f"Path {rel_path} not found")
            return

        try:
            if not os.path.exists(watch_path):
                os.makedirs(watch_path, exist_ok=True)
            if shallow:
                files = self._scan_shallow(watch_path, base_dir)
            else:
                files = self._scan_recursive(watch_path, base_dir)
        except Exception as e:
            self._send_sync_error(session_id, task_id, str(e))
            return

        self._send_manifest_chunks(session_id, task_id, rel_path, files)

    def _send_manifest_chunks(self, sid, tid, root, files):
        """Split a manifest into DirectoryManifest messages for gRPC streaming.

        An empty manifest is still sent as one message with is_final=True;
        otherwise only the last chunk carries is_final=True.
        """
        size = self._MANIFEST_CHUNK_SIZE
        if not files:
            self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=agent_pb2.DirectoryManifest(root_path=root, is_final=True))))
            return
        for i in range(0, len(files), size):
            self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid,
                manifest=agent_pb2.DirectoryManifest(root_path=root, files=files[i:i + size], chunk_index=i // size, is_final=(i + size >= len(files))))))

    # ---------------------------------------------------------- file ops

    def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""):
        """Create a directory or write a whole file, then push a refreshed
        shallow manifest of its parent. Any failure is reported as a sync error."""
        try:
            base, target = self._resolve_target(session_id, rel_path, create=True)
            # Security: only the global file explorer bypasses the jail.
            if not self._is_path_allowed(session_id, base, target):
                raise PermissionError("Path traversal blocked.")

            if is_dir:
                os.makedirs(target, exist_ok=True)
            else:
                os.makedirs(os.path.dirname(target), exist_ok=True)
                with open(target, "wb") as f:
                    f.write(content)

            self._send_sync_ok(session_id, task_id, "Resource written")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e:
            self._send_sync_error(session_id, task_id, str(e))

    def _handle_fs_delete(self, session_id, rel_path, task_id=""):
        """Remove a file, symlink or directory tree, then push a refreshed
        shallow manifest of its parent. Any failure is reported as a sync error."""
        try:
            base, target = self._resolve_target(session_id, rel_path)
            if not self._is_path_allowed(session_id, base, target):
                raise PermissionError("Path traversal blocked.")

            # Tell the watcher first so the deletion is not echoed back.
            self.watcher.acknowledge_remote_delete(session_id, rel_path)
            # rmtree refuses symlinks; a link to a directory is just unlinked.
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target)
            else:
                os.remove(target)

            self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e:
            self._send_sync_error(session_id, task_id, str(e))

    def _async_write_chunk(self, sid, payload):
        """I/O-pool worker: write one chunk under a per-(session, path) lock.

        The path is suppressed in the watcher from the first chunk so our own
        write is not echoed back. On the final chunk the write is acknowledged
        (when it succeeded and carries a hash), suppression is lifted and the
        lock entry dropped — even if the write failed — and an OK status is
        sent if no exception occurred. The io_semaphore slot taken in
        _on_sync_data is always released.

        NOTE(review): with 2 pool workers, chunks of one file may acquire the
        lock out of order — confirm sync_mgr.write_chunk tolerates that.
        """
        path = payload.path
        key = (sid, path)
        write_ok = False
        try:
            with self.lock_map_mutex:
                lock = self.write_locks.setdefault(key, threading.Lock())
            with lock:
                try:
                    if payload.chunk_index == 0:
                        self.watcher.suppress_path(sid, path)
                    success = self.sync_mgr.write_chunk(sid, payload)
                    if payload.is_final and success and payload.hash:
                        self.watcher.acknowledge_remote_write(sid, path, payload.hash)
                    write_ok = True
                except Exception as e:
                    logger.error(f"Chunk write failed for {path} (session {sid}): {e}")
                finally:
                    if payload.is_final:
                        self.watcher.unsuppress_path(sid, path)
                        with self.lock_map_mutex:
                            self.write_locks.pop(key, None)
        finally:
            self.io_semaphore.release()

        if payload.is_final and write_ok:
            self._send_sync_ok(sid, "", f"File {path} synced")

    def _push_file(self, session_id, rel_path, task_id=""):
        """Stream one file to the Hub as zlib-compressed 4MB chunks.

        The final chunk carries the SHA-256 of the uncompressed content. An
        empty file is sent as a single empty final chunk. Failures are logged;
        when *task_id* is set they are also reported to the Hub so the
        requesting task is not left waiting.
        """
        base, target = self._resolve_target(session_id, rel_path, create=False)
        if not self._is_path_allowed(session_id, base, target):
            if task_id:
                self._send_sync_error(session_id, task_id, "Path traversal blocked.")
            return

        if not os.path.exists(target):
            if task_id:
                self._send_sync_error(session_id, task_id, "File not found")
            return

        wire_path = rel_path.replace("\\", "/")
        hasher = hashlib.sha256()
        try:
            with open(target, "rb") as f:
                # Size from the open handle avoids a stat/open race.
                size = os.fstat(f.fileno()).st_size
                idx = 0
                while True:
                    chunk = f.read(self._PUSH_CHUNK_SIZE)
                    hasher.update(chunk)
                    # An empty read means EOF (e.g. the file shrank): end the
                    # stream so the receiver still gets an is_final chunk.
                    is_final = not chunk or f.tell() >= size
                    self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=session_id, task_id=task_id,
                        file_data=agent_pb2.FilePayload(path=wire_path, chunk=zlib.compress(chunk), chunk_index=idx, is_final=is_final, hash=hasher.hexdigest() if is_final else "", compressed=True))))
                    if is_final:
                        break
                    idx += 1
        except Exception as e:
            logger.error(f"Push Error: {e}")
            if task_id:
                self._send_sync_error(session_id, task_id, str(e))

    # ---------------------------------------------------------- status msgs

    def _send_sync_ok(self, sid, tid, msg):
        """Send a FileSyncMessage with an OK SyncStatus."""
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg))))

    def _send_sync_error(self, sid, tid, msg):
        """Send a FileSyncMessage with an ERROR SyncStatus."""
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg))))

    def _on_sync_delta(self, session_id, payload):
        """Watcher callback: forward a local change to the Hub.

        *payload* is either a complete FileSyncMessage or a file-data payload
        that is wrapped into one for *session_id*.
        """
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=payload if isinstance(payload, agent_pb2.FileSyncMessage) else agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload)))

    # ------------------------------------------------------------- lifecycle

    def shutdown(self):
        """Gracefully shuts down the node.

        Signals the health and stream loops to stop, shuts down skills and the
        file-I/O pool, and closes a legacy ``channel`` attribute if present.
        """
        self._stop_event.set()
        self.skills.shutdown()
        # Already-queued chunk writes still run; later submissions are
        # rejected (and handled gracefully in _on_sync_data).
        self.io_executor.shutdown(wait=False)
        # After the MeshNodeCore refactor the connection may live in the
        # transport, so don't assume this attribute exists.
        channel = getattr(self, "channel", None)
        if channel:
            channel.close()
        print("[*] Node shutdown complete.")