# cortex-hub/agent-node/src/agent_node/node.py
import threading
import queue
import time
import os
import hashlib
import logging
import json
import zlib
import shutil
import socket
import traceback
from concurrent.futures import ThreadPoolExecutor

try:
    import psutil
except ImportError:
    psutil = None

from mesh_core import agent_pb2, agent_pb2_grpc
from mesh_core.engines import MeshNodeCore
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature, verify_server_message_signature
from agent_node.utils.network import get_secure_stub
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.core.regex_patterns import ANSI_ESCAPE

from mesh_core.transport import GrpcMeshTransport
from mesh_core.utils import DataChunker

logger = logging.getLogger(__name__)

class AgentNode(MeshNodeCore):
    """
    Agent Core: Orchestrates local skills and maintains connectivity.
    Now leverages MeshNodeCore for transport-agnostic orchestration.
    """
    def __init__(self):
        # 1. Initialize Transport (gRPC by default for production)
        # Pass the secure stub factory to keep existing security logic
        self.transport = GrpcMeshTransport(
            config.NODE_ID, get_secure_stub,
            auth_token=config.AUTH_TOKEN,
            hub_http_url=config.HUB_URL,
            secret_key=config.SECRET_KEY,
        )
        
        # 2. Initialize Core Engine
        super().__init__(config.NODE_ID, self.transport)
        
        # 3. Initialize Agent-Specific Modules
        self.sandbox = SandboxEngine()
        self.sync_mgr = NodeSyncManager()
        self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
        self.watcher = WorkspaceWatcher(self._on_sync_delta)
        self._stop_event = threading.Event()
        
        # Back-pressure for file-sync I/O: a small writer pool plus a semaphore
        # capping queued chunk writes, and one lock per (session_id, path) so
        # chunks for the same file never interleave.
        self.io_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="NodeIO")
        self.io_semaphore = threading.Semaphore(50)
        self.write_locks = {}
        self.lock_map_mutex = threading.Lock()

        # 4. Bind Mesh Events
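        # These attributes are callback hooks that MeshNodeCore is expected to
        # invoke as matching server messages arrive (an assumption from the base
        # class contract; each hook maps to a handler defined below).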
        self.on_task = self._handle_task
        self.on_cancel = self._handle_cancel
        self.on_policy = self._on_policy_update
        self.on_sync = self._handle_file_sync
        self.on_ready = self._on_mesh_ready
        self.on_disconnect = self._on_mesh_disconnect
        self.on_work_pool = self._handle_work_pool

    def _on_mesh_ready(self, msg):
        print(f"[Mesh] Connected and Authorized. Policy Synced.")
        if hasattr(msg, 'policy_update'):
            self._on_policy_update(msg.policy_update)
        elif hasattr(msg, 'policy'):
            self._on_policy_update(msg.policy)

    def _on_mesh_disconnect(self):
        print(f"[Mesh] Disconnected from Hub.")

    def _on_policy_update(self, policy):
        self.sandbox.sync(policy)

    def _apply_skill_config(self, config_json: str):
        """Placeholder for applying skill configuration from policy."""
        print(f"    [] Applying skill config: {config_json}")
        # TODO: Implement skill specific configuration updates

    def sync_configuration(self):
        """ Handshake now handled by Transport/Core. This remains for compatibility if needed. """
        # In the new SDK model, the transport handles the initial SyncConfiguration call 
        # or it's wrapped in the connect() flow.
        pass

    def start_health_reporting(self):
        """ Launches background health reporting using the transport. """
        def _report():
            while not self._stop_event.is_set():
                # Always tick — proves this thread is alive regardless of hub
                # connectivity. Watchdog fires only if the reporter itself deadlocks,
                # not merely because the hub is temporarily unreachable.
                watchdog.tick()

                is_conn = self.transport.is_connected()
                logger.debug(f"HealthReporter: connected={is_conn}")
                if is_conn:
                    try:
                        ids = self.skills.get_active_ids()
                        vmem = psutil.virtual_memory() if psutil else None
                        hb = agent_pb2.Heartbeat(
                            node_id=self.node_id,
                            cpu_usage_percent=psutil.cpu_percent() if psutil else 0,
                            memory_usage_percent=vmem.percent if vmem else 0,
                            active_worker_count=len(ids),
                            max_worker_capacity=config.MAX_SKILL_WORKERS,
                            running_task_ids=ids,
                            cpu_count=psutil.cpu_count() if psutil else 0,
                            memory_used_gb=vmem.used/(1024**3) if vmem else 0,
                            memory_total_gb=vmem.total/(1024**3) if vmem else 0
                        )
                        if hasattr(self.transport, 'send_health'):
                            self.transport.send_health(hb)
                    except Exception as e:
                        logger.error(f"Health report error: {e}")
                else:
                    logger.warning("HealthReporter: hub unreachable, heartbeat skipped (gRPC reconnecting)")
                time.sleep(config.HEALTH_REPORT_INTERVAL)
        
        threading.Thread(target=_report, daemon=True, name="HealthReporter").start()

    def run_task_stream(self):
        """ Starts the core engine which manages the task stream. """
        if not self.start(): # From MeshNodeCore
            raise RuntimeError("Handshake failed, node cannot start.")
        
        # Wait until connected (timeout 30s)
        connected = False
        for _ in range(30):
            if self.transport.is_connected():
                connected = True
                break
            time.sleep(1)
            
        if not connected:
            raise RuntimeError("Transport failed to connect within timeout.")

        while not self._stop_event.is_set():
            time.sleep(1)

    def _handle_cancel(self, cancel_req):
        """Cancels an active task or an entire session's background tasks."""
        session_id = getattr(cancel_req, 'session_id', None)
        if session_id and not cancel_req.task_id:
            logger.info(f"[*] Cancelling all tasks for session: {session_id}")
            self.skills.cancel_session(session_id)
            return

        if self.skills.cancel(cancel_req.task_id):
            self._send_response(cancel_req.task_id, None, agent_pb2.TaskResponse.CANCELLED)

    def _handle_work_pool(self, update):
        """Claims tasks from the global work pool with randomized backoff."""
        def _claim():
            import random  # only needed here, for the claim jitter
            if len(self.skills.get_active_ids()) < config.MAX_SKILL_WORKERS:
                for tid in update.available_task_ids:
                    time.sleep(random.uniform(0.1, 0.5))
                    self.send_message(agent_pb2.ClientTaskMessage(
                        task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
                    ))

        threading.Thread(target=_claim, daemon=True, name="WorkPoolClaimer").start()

    def _handle_file_sync(self, fs):
        """Dispatches file sync messages to specialized sub-handlers."""
        sid = fs.session_id
        if fs.HasField("manifest"): self._on_sync_manifest(sid, fs.manifest)
        elif fs.HasField("file_data"): self._on_sync_data(sid, fs.file_data)
        elif fs.HasField("control"): self._on_sync_control(sid, fs.control, fs.task_id)

    def _on_sync_manifest(self, sid, manifest):
        """Reconciles local state with a remote manifest."""
        drift = self.sync_mgr.handle_manifest(sid, manifest, on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p))
        status = agent_pb2.SyncStatus(
            code=agent_pb2.SyncStatus.RECONCILE_REQUIRED if drift else agent_pb2.SyncStatus.OK,
            message=f"Drift in {len(drift)} files" if drift else "Synchronized",
            reconcile_paths=drift
        )
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, status=status)))

    def _on_sync_data(self, sid, file_data):
        """Offloads disk I/O to a background worker pool."""
        self.io_semaphore.acquire()
        try:
            self.io_executor.submit(self._async_write_chunk, sid, file_data)
        except Exception as e:
            self.io_semaphore.release()
            logger.error(f"Failed to queue chunk write for session {sid}: {e}")

    def _on_sync_control(self, sid, ctrl, task_id):
        router = {
            agent_pb2.SyncControl.START_WATCHING: self._ctrl_start_watching,
            agent_pb2.SyncControl.STOP_WATCHING: self._ctrl_stop_watching,
            agent_pb2.SyncControl.LOCK: self._ctrl_lock,
            agent_pb2.SyncControl.UNLOCK: self._ctrl_unlock,
            agent_pb2.SyncControl.REFRESH_MANIFEST: self._ctrl_refresh_manifest,
            agent_pb2.SyncControl.RESYNC: self._ctrl_refresh_manifest,
            agent_pb2.SyncControl.PURGE: self._ctrl_purge,
            agent_pb2.SyncControl.LIST: self._ctrl_list,
            agent_pb2.SyncControl.READ: self._ctrl_read,
            agent_pb2.SyncControl.WRITE: self._ctrl_write,
            agent_pb2.SyncControl.DELETE: self._ctrl_delete
        }
        
        handler = router.get(ctrl.action)
        if handler:
            handler(sid, ctrl, task_id)

    def _ctrl_start_watching(self, sid, ctrl, task_id):
        self.watcher.start_watching(sid, ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path))

    def _ctrl_stop_watching(self, sid, ctrl, task_id):
        self.watcher.stop_watching(sid)

    def _ctrl_lock(self, sid, ctrl, task_id):
        self.watcher.set_lock(sid, True)

    def _ctrl_unlock(self, sid, ctrl, task_id):
        self.watcher.set_lock(sid, False)

    def _ctrl_refresh_manifest(self, sid, ctrl, task_id):
        if ctrl.request_paths:
            for p in ctrl.request_paths: 
                self.io_executor.submit(self._push_file, sid, p)
        else: 
            self._push_full_manifest(sid, ctrl.path)

    def _ctrl_purge(self, sid, ctrl, task_id):
        self.watcher.stop_watching(sid)
        self.sync_mgr.purge(sid)

    def _ctrl_list(self, sid, ctrl, task_id):
        self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True)

    def _ctrl_read(self, sid, ctrl, task_id):
        self._push_file(sid, ctrl.path, task_id=task_id)

    def _ctrl_write(self, sid, ctrl, task_id):
        self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id)

    def _ctrl_delete(self, sid, ctrl, task_id):
        self._handle_fs_delete(sid, ctrl.path, task_id=task_id)

    def _get_base_dir(self, session_id, create=False):
        """Resolves the session's effective root directory."""
        if session_id == "__fs_explorer__": return config.FS_ROOT
        watched = self.watcher.get_watch_path(session_id)
        return watched if watched else self.sync_mgr.get_session_dir(session_id, create=create)

    def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False):
        """Generates and pushes a local file manifest to the server."""
        base_dir = self._get_base_dir(session_id, create=True)
        safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path
        watch_path = os.path.normpath(os.path.join(base_dir, safe_rel))
        
        if not os.path.exists(watch_path):
            if session_id == "__fs_explorer__":
                self._send_sync_error(session_id, task_id, f"Path {rel_path} not found")
                return
            os.makedirs(watch_path, exist_ok=True)

        files = []
        try:
            if shallow:
                with os.scandir(watch_path) as it:
                    for entry in it:
                        if entry.name == ".cortex_sync": continue
                        is_dir = entry.is_dir() if not entry.is_symlink() else os.path.isdir(entry.path)
                        item_rel = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/")
                        files.append(agent_pb2.FileInfo(path=item_rel, size=entry.stat().st_size if not is_dir else 0, hash="", is_dir=is_dir))
            else:
                for root, dirs, filenames in os.walk(watch_path):
                    for name in filenames:
                        abs_p = os.path.join(root, name)
                        h = self.sync_mgr.get_file_hash(abs_p)
                        if h: files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=os.path.getsize(abs_p), hash=h, is_dir=False))
                    for d in dirs:
                        files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True))
        except Exception as e:
            return self._send_sync_error(session_id, task_id, str(e))

        self._send_manifest_chunks(session_id, task_id, rel_path, files)

    def _send_manifest_chunks(self, sid, tid, root, files):
        """Splits large manifests into chunks for gRPC streaming."""
        chunk_size = 1000
        if not files:
            self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=agent_pb2.DirectoryManifest(root_path=root, is_final=True))))
            return
        for i in range(0, len(files), chunk_size):
            chunk = files[i:i+chunk_size]
            self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, 
                manifest=agent_pb2.DirectoryManifest(root_path=root, files=chunk, chunk_index=i//chunk_size, is_final=(i+chunk_size >= len(files))))))

    def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""):
        """Handles single file or directory creation."""
        try:
            base = os.path.normpath(self._get_base_dir(session_id, create=True))
            target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            
            # Security: Only enforce jail if not in the global file explorer mode
            if session_id != "__fs_explorer__":
                if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
                    raise Exception("Path traversal blocked.")
            
            if is_dir: os.makedirs(target, exist_ok=True)
            else:
                os.makedirs(os.path.dirname(target), exist_ok=True)
                with open(target, "wb") as f: f.write(content)
            
            self._send_sync_ok(session_id, task_id, "Resource written")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e: self._send_sync_error(session_id, task_id, str(e))

    def _handle_fs_delete(self, session_id, rel_path, task_id=""):
        """Removes a file or directory from the node."""
        try:
            base = os.path.normpath(self._get_base_dir(session_id))
            target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            
            if session_id != "__fs_explorer__":
                if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
                    raise Exception("Path traversal blocked.")
            
            self.watcher.acknowledge_remote_delete(session_id, rel_path)
            if os.path.isdir(target): shutil.rmtree(target)
            else: os.remove(target)
            
            self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e: self._send_sync_error(session_id, task_id, str(e))

    def _async_write_chunk(self, sid, payload):
        """Worker for segmented file writing with path-level locking."""
        path = payload.path
        with self.lock_map_mutex:
            if (sid, path) not in self.write_locks: self.write_locks[(sid, path)] = threading.Lock()
            lock = self.write_locks[(sid, path)]
        
        with lock:
            try:
                if payload.chunk_index == 0: self.watcher.suppress_path(sid, path)
                success = self.sync_mgr.write_chunk(sid, payload)
                if payload.is_final:
                    if success and payload.hash: self.watcher.acknowledge_remote_write(sid, path, payload.hash)
                    self.watcher.unsuppress_path(sid, path)
                    with self.lock_map_mutex: self.write_locks.pop((sid, path), None)
                    self._send_sync_ok(sid, "", f"File {path} synced")
            finally: self.io_semaphore.release()

    def _push_file(self, session_id, rel_path, task_id=""):
        """Streams a file from node to server using 4MB chunks."""
        base = os.path.normpath(self._get_base_dir(session_id, create=False))
        target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
        
        if session_id != "__fs_explorer__":
            if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): return

        if not os.path.exists(target):
            if task_id: self._send_sync_error(session_id, task_id, "File not found")
            return

        hasher = hashlib.sha256()
        size = os.path.getsize(target)
        try:
            with open(target, "rb") as f:
                for idx, chunk in enumerate(DataChunker.chunk_file(f)):
                    hasher.update(chunk)
                    is_final = f.tell() >= size
                    self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(
                        session_id=session_id, 
                        task_id=task_id,
                        file_data=agent_pb2.FilePayload(
                            path=rel_path.replace("\\", "/"), 
                            chunk=zlib.compress(chunk), 
                            chunk_index=idx, 
                            is_final=is_final, 
                            hash=hasher.hexdigest() if is_final else "", 
                            compressed=True
                        )
                    )))
        except Exception as e: logger.error(f"Push Error: {e}")

    def _handle_task(self, task):
        """Verifies and submits a skill task for execution."""
        print(f"    [📥] Received Task: {task.task_id} (Type: {task.task_type})")
        if not verify_task_signature(task):
            print(f"[Warn] Task signature mismatch for {task.task_id}. Proceeding anyway (DEBUG).")
        
        success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if not success:
            self._send_response(task.task_id, agent_pb2.TaskResponse(task_id=task.task_id, status=agent_pb2.TaskResponse.ERROR, stderr=f"Rejection: {reason}"))

    def _on_event(self, event):
        """Forwards skill events to the gRPC stream."""
        self.send_message(event if isinstance(event, agent_pb2.ClientTaskMessage) else agent_pb2.ClientTaskMessage(skill_event=event))

    def _on_finish(self, task_id, result, trace_id):
        """Finalizes a task and sends the response back to the Hub."""
        self._send_response(task_id, agent_pb2.TaskResponse(task_id=task_id, stdout=result.get("stdout", ""), stderr=result.get("stderr", ""), status=result.get("status", 0), trace_id=trace_id))

    def _send_response(self, task_id, response, status_override=None):
        """Helper to queue a standard task response with memory safety caps."""
        res = response or agent_pb2.TaskResponse(task_id=task_id)
        if status_override is not None: res.status = status_override
        
        # Redundant safety cap (5MB) to protect gRPC/Queue memory
        SAFETY_CAP = 5 * 1024 * 1024
        for field in ['stdout', 'stderr']:
            val = getattr(res, field, "")
            if len(val) > SAFETY_CAP:
                logger.warning(f"Truncating excessive {field} ({len(val):,} bytes) for task {task_id}")
                setattr(res, field, val[:SAFETY_CAP // 2] + "\n... [TRUNCATED DUE TO SIZE] ...\n" + val[-(SAFETY_CAP // 2):])

        self.send_message(agent_pb2.ClientTaskMessage(task_response=res))

    def _send_sync_ok(self, sid, tid, msg):
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg))))

    def _send_sync_error(self, sid, tid, msg):
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg))))

    def _on_sync_delta(self, session_id, payload):
        self.send_message(agent_pb2.ClientTaskMessage(file_sync=payload if isinstance(payload, agent_pb2.FileSyncMessage) else agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload)))

    def shutdown(self):
        """Gracefully shuts down the node."""
        self._stop_event.set()
        self.skills.shutdown()
        self.stop() # From MeshNodeCore
        print("[*] Node shutdown complete.")