# cortex-hub / agent-node / src / agent_node / node.py
import hashlib
import json
import logging
import os
import queue
import random
import shutil
import socket
import threading
import time
import zlib
from concurrent.futures import ThreadPoolExecutor

try:
    import psutil
except ImportError:
    psutil = None

from protos import agent_pb2, agent_pb2_grpc
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature, verify_server_message_signature
from agent_node.utils.network import get_secure_stub
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.core.regex_patterns import ANSI_ESCAPE

logger = logging.getLogger(__name__)

class AgentNode:
    """
    Agent Core: Orchestrates local skills and maintains gRPC connectivity.
    Refactored for structural clarity and modular message handling.
    """
    def __init__(self):
        """Wires up all node subsystems and opens the initial gRPC channel."""
        # Identity and core subsystems.
        self.node_id = config.NODE_ID
        self.sandbox = SandboxEngine()
        self.sync_mgr = NodeSyncManager()
        self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
        self.watcher = WorkspaceWatcher(self._on_sync_delta)

        # Outbound message queue feeding the bi-directional task stream.
        self.task_queue = queue.Queue(maxsize=250)

        # Bounded worker pool for chunked file writes; the semaphore caps the
        # number of in-flight chunks, and per-path locks serialize writers.
        self.io_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="NodeIO")
        self.io_semaphore = threading.Semaphore(50)
        self.write_locks = {}
        self.lock_map_mutex = threading.Lock()

        # Connectivity state; _refresh_stub() establishes the first channel.
        self.stub = None
        self.channel = None
        self._stop_event = threading.Event()
        self._refresh_stub()

    def _refresh_stub(self):
        """Force refreshes the gRPC channel and stub.

        Closes any existing channel first (best-effort) so connections are not
        leaked across reconnects, then re-subscribes the state watcher.
        """
        if self.channel:
            try:
                self.channel.close()
            except Exception:
                # Best-effort: an already-dead channel may raise on close.
                logger.debug("Ignoring error while closing stale channel", exc_info=True)
        self.stub, self.channel = get_secure_stub()
        self._setup_connectivity_watcher()

    def _setup_connectivity_watcher(self):
        """Monitors gRPC channel state."""
        import grpc
        self._last_grpc_state = None
        def _on_state_change(state):
            if not self._stop_event.is_set() and state != self._last_grpc_state:
                print(f"[*] [gRPC-State] {state}", flush=True)
                self._last_grpc_state = state
        self.channel.subscribe(_on_state_change, try_to_connect=True)

    def sync_configuration(self):
        """Handshake with the Orchestrator to sync policy and metadata.

        Blocks until registration succeeds; retries every 5 seconds on either
        an explicit rejection or a transport failure.
        """
        while True:
            config.reload()
            self.node_id = config.NODE_ID
            if not self.stub:
                self._refresh_stub()

            print(f"[*] Handshake with Orchestrator: {self.node_id}")
            capabilities = self._collect_capabilities()
            # Capability values travel as strings; booleans become "true"/"false".
            cap_strings = {}
            for key, value in capabilities.items():
                cap_strings[key] = str(value).lower() if isinstance(value, bool) else str(value)
            request = agent_pb2.RegistrationRequest(
                node_id=self.node_id,
                auth_token=config.AUTH_TOKEN,
                node_description=config.NODE_DESC,
                capabilities=cap_strings,
            )

            try:
                reply = self.stub.SyncConfiguration(request, timeout=10)
                if reply.success:
                    self.sandbox.sync(reply.policy)
                    self._apply_skill_config(reply.policy.skill_config_json)
                    print("[OK] Handshake successful. Policy Synced.")
                    return
                print(f"[!] Rejection: {reply.error_message}. Retrying in 5s...")
            except Exception as exc:
                print(f"[!] Connection Fail: {str(exc)}. Retrying in 5s...")
            time.sleep(5)

    def _apply_skill_config(self, config_json):
        """Applies dynamic skill configurations pushed by the server.

        Skills exposing an ``apply_config`` hook receive the parsed dict;
        others are left untouched. Malformed JSON is logged, never raised.
        """
        if not config_json:
            return
        try:
            parsed = json.loads(config_json)
            for skill in self.skills.skills.values():
                if hasattr(skill, "apply_config"):
                    skill.apply_config(parsed)
        except Exception as exc:
            logger.error(f"Error applying skill config: {exc}")

    def _collect_capabilities(self) -> dict:
        """Collects hardware and OS metadata plus the node's primary local IP.

        The UDP "connect" trick never sends a packet; it only asks the OS
        which interface would route to an external address. On any failure
        the IP is reported as "unknown". The socket is always closed (the
        original leaked it whenever connect/getsockname raised).
        """
        from agent_node.utils.platform_metrics import get_platform_metrics
        caps = get_platform_metrics().collect_capabilities()
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.settimeout(0)
            sock.connect(('10.254.254.254', 1))
            caps["local_ip"] = sock.getsockname()[0]
        except Exception:
            caps["local_ip"] = "unknown"
        finally:
            # Release the descriptor on every path, not just success.
            if sock is not None:
                sock.close()
        return caps

    def start_health_reporting(self):
        """Launches the background health reporting stream.

        Spawns a daemon thread that maintains a client-streaming RPC of
        Heartbeat messages until the node's stop event is set. When psutil is
        unavailable, CPU/memory figures are reported as zero. Every server
        acknowledgement feeds the liveness watchdog.
        """
        from agent_node.utils.platform_metrics import get_platform_metrics
        metrics_tool = get_platform_metrics()

        def _report():
            # Outer loop re-establishes the stream after any RPC failure.
            while not self._stop_event.is_set():
                try:
                    def _gen():
                        # Inner generator emits one heartbeat per interval
                        # until shutdown is signalled.
                        while not self._stop_event.is_set():
                            ids = self.skills.get_active_ids()
                            vmem = psutil.virtual_memory() if psutil else None
                            yield agent_pb2.Heartbeat(
                                node_id=self.node_id, 
                                cpu_usage_percent=psutil.cpu_percent() if psutil else 0,
                                memory_usage_percent=vmem.percent if vmem else 0,
                                active_worker_count=len(ids), 
                                max_worker_capacity=config.MAX_SKILL_WORKERS, 
                                running_task_ids=ids,
                                cpu_count=psutil.cpu_count() if psutil else 0,
                                memory_used_gb=vmem.used/(1024**3) if vmem else 0,
                                memory_total_gb=vmem.total/(1024**3) if vmem else 0,
                                load_avg=metrics_tool.get_load_avg()
                            )
                            # Sleeps one second less than the configured
                            # interval; presumably to absorb collection/RPC
                            # overhead -- TODO confirm intent.
                            time.sleep(max(0, config.HEALTH_REPORT_INTERVAL - 1.0))

                    # Each server response ticks the watchdog to prove liveness.
                    for _ in self.stub.ReportHealth(_gen()): watchdog.tick() 
                except Exception as e:
                    # Transient stream failure: back off briefly, then retry.
                    time.sleep(5)

        threading.Thread(target=_report, daemon=True, name="HealthReporter").start()

    def run_task_stream(self):
        """Main bi-directional task stream with auto-reconnection.

        Announces the node first, then relays queued client messages upstream
        while dispatching every inbound server message. On any stream failure
        the channel is rebuilt and the loop reconnects after a short pause.
        """
        while True:
            try:
                def _outbound():
                    # First frame identifies this node; everything afterwards
                    # is forwarded straight from the outgoing queue.
                    announce = agent_pb2.NodeAnnounce(node_id=self.node_id)
                    yield agent_pb2.ClientTaskMessage(announce=announce)
                    while True:
                        yield self.task_queue.get()

                stream = self.stub.TaskStream(_outbound())
                print(f"[*] Task stream connected ({self.node_id}).")
                for inbound in stream:
                    watchdog.tick()
                    self._process_server_message(inbound)
            except Exception as exc:
                print(f"[!] Task stream error: {exc}. Reconnecting...")
                self._refresh_stub()
                time.sleep(5)

    def _process_server_message(self, msg):
        """Routes inbound server messages to their respective handlers.

        Messages failing signature verification are dropped with a warning.
        """
        if not verify_server_message_signature(msg):
            logger.warning("Invalid server message signature. Dropping.")
            return

        kind = msg.WhichOneof('payload')
        if kind == 'policy_update':
            policy = msg.policy_update
            self.sandbox.sync(policy)
            self._apply_skill_config(policy.skill_config_json)
            return
        handlers = {
            'task_request': self._handle_task,
            'task_cancel': self._handle_cancel,
            'work_pool_update': self._handle_work_pool,
            'file_sync': self._handle_file_sync,
        }
        handler = handlers.get(kind)
        if handler:
            handler(getattr(msg, kind))

    def _handle_cancel(self, cancel_req):
        """Cancels an active task and acknowledges the cancellation upstream."""
        task_id = cancel_req.task_id
        cancelled = self.skills.cancel(task_id)
        if cancelled:
            self._send_response(task_id, None, agent_pb2.TaskResponse.CANCELLED)

    def _handle_work_pool(self, update):
        """Claims tasks from the global work pool with randomized backoff.

        The random 0.1-0.5s pause before each claim de-synchronizes competing
        nodes so they do not all race for the same task id. (The `import
        random` was hoisted out of the loop to module level.)
        """
        if len(self.skills.get_active_ids()) < config.MAX_SKILL_WORKERS:
            for tid in update.available_task_ids:
                time.sleep(random.uniform(0.1, 0.5))
                self.task_queue.put(agent_pb2.ClientTaskMessage(
                    task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
                ))

    def _handle_file_sync(self, fs):
        """Dispatches file sync messages to specialized sub-handlers."""
        session = fs.session_id
        if fs.HasField("manifest"):
            self._on_sync_manifest(session, fs.manifest)
            return
        if fs.HasField("file_data"):
            self._on_sync_data(session, fs.file_data)
            return
        if fs.HasField("control"):
            self._on_sync_control(session, fs.control, fs.task_id)

    def _on_sync_manifest(self, sid, manifest):
        """Reconciles local state with a remote manifest and reports drift."""
        def _purge_ack(path):
            # Tell the watcher a remote-driven delete happened so it is not
            # re-reported as a local change.
            self.watcher.acknowledge_remote_delete(sid, path)

        drift = self.sync_mgr.handle_manifest(sid, manifest, on_purge_callback=_purge_ack)
        if drift:
            status = agent_pb2.SyncStatus(
                code=agent_pb2.SyncStatus.RECONCILE_REQUIRED,
                message=f"Drift in {len(drift)} files",
                reconcile_paths=drift,
            )
        else:
            status = agent_pb2.SyncStatus(
                code=agent_pb2.SyncStatus.OK,
                message="Synchronized",
                reconcile_paths=drift,
            )
        reply = agent_pb2.FileSyncMessage(session_id=sid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=reply))

    def _on_sync_data(self, sid, file_data):
        """Offloads disk I/O to a background worker pool."""
        self.io_semaphore.acquire()
        try: self.io_executor.submit(self._async_write_chunk, sid, file_data)
        except: self.io_semaphore.release()

    def _on_sync_control(self, sid, ctrl, task_id):
        """Handles sync control actions like watching, locking, or directory listing."""
        SC = agent_pb2.SyncControl
        act = ctrl.action
        if act == SC.START_WATCHING:
            path = ctrl.path
            if not os.path.isabs(path):
                path = os.path.join(self.sync_mgr.get_session_dir(sid), path)
            self.watcher.start_watching(sid, path)
        elif act == SC.STOP_WATCHING:
            self.watcher.stop_watching(sid)
        elif act == SC.LOCK:
            self.watcher.set_lock(sid, True)
        elif act == SC.UNLOCK:
            self.watcher.set_lock(sid, False)
        elif act in (SC.REFRESH_MANIFEST, SC.RESYNC):
            if ctrl.request_paths:
                # Targeted resync: push only the requested files.
                for path in ctrl.request_paths:
                    self.io_executor.submit(self._push_file, sid, path)
            else:
                self._push_full_manifest(sid, ctrl.path)
        elif act == SC.PURGE:
            # Stop filesystem events first so the purge is not echoed back.
            self.watcher.stop_watching(sid)
            self.sync_mgr.purge(sid)
        elif act == SC.LIST:
            self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True)
        elif act == SC.READ:
            self._push_file(sid, ctrl.path, task_id=task_id)
        elif act == SC.WRITE:
            self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id)
        elif act == SC.DELETE:
            self._handle_fs_delete(sid, ctrl.path, task_id=task_id)

    def _get_base_dir(self, session_id, create=False):
        """Resolves the session's effective root directory.

        Priority: the FS explorer pseudo-session maps to the configured root;
        an actively watched path wins next; otherwise the managed session dir.
        """
        if session_id == "__fs_explorer__":
            return config.FS_ROOT
        watched = self.watcher.get_watch_path(session_id)
        if watched:
            return watched
        return self.sync_mgr.get_session_dir(session_id, create=create)

    def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False):
        """Generates and pushes a local file manifest to the server.

        Args:
            session_id: Sync session id; the "__fs_explorer__" pseudo-session
                browses the configured FS root and never auto-creates paths.
            rel_path: Path (relative to the session base) to enumerate.
            task_id: Optional request-correlation id echoed back in replies.
            shallow: When True, list only immediate directory entries with no
                hashes; otherwise walk recursively and hash every file.
        """
        base_dir = self._get_base_dir(session_id, create=True)
        # Regular sessions are jailed to their base dir (leading "/" stripped);
        # the explorer session may address paths as given.
        safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path
        watch_path = os.path.normpath(os.path.join(base_dir, safe_rel))
        
        if not os.path.exists(watch_path):
            # Sync sessions lazily create missing directories; the explorer
            # session reports an error instead and bails out.
            if session_id != "__fs_explorer__": os.makedirs(watch_path, exist_ok=True)
            else: self._send_sync_error(session_id, task_id, f"Path {rel_path} not found")
            if session_id == "__fs_explorer__": return

        files = []
        try:
            if shallow:
                # One-level listing; the internal sync metadata dir is hidden.
                with os.scandir(watch_path) as it:
                    for entry in it:
                        if entry.name == ".cortex_sync": continue
                        # Resolve symlinks so linked directories list as dirs.
                        is_dir = entry.is_dir() if not entry.is_symlink() else os.path.isdir(entry.path)
                        item_rel = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/")
                        files.append(agent_pb2.FileInfo(path=item_rel, size=entry.stat().st_size if not is_dir else 0, hash="", is_dir=is_dir))
            else:
                # Full recursive walk with per-file content hashes.
                for root, dirs, filenames in os.walk(watch_path):
                    for name in filenames:
                        abs_p = os.path.join(root, name)
                        h = self.sync_mgr.get_file_hash(abs_p)
                        # Files whose hash cannot be computed are skipped.
                        if h: files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=os.path.getsize(abs_p), hash=h, is_dir=False))
                    for d in dirs:
                        files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True))
        except Exception as e:
            return self._send_sync_error(session_id, task_id, str(e))

        self._send_manifest_chunks(session_id, task_id, rel_path, files)

    def _send_manifest_chunks(self, sid, tid, root, files):
        """Splits large manifests into 1000-entry chunks for gRPC streaming."""
        CHUNK = 1000
        if not files:
            # Still emit one empty, final manifest so the server can settle.
            empty = agent_pb2.DirectoryManifest(root_path=root, is_final=True)
            self.task_queue.put(agent_pb2.ClientTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=empty)))
            return
        total = len(files)
        for start in range(0, total, CHUNK):
            part = files[start:start + CHUNK]
            manifest = agent_pb2.DirectoryManifest(
                root_path=root,
                files=part,
                chunk_index=start // CHUNK,
                is_final=(start + CHUNK >= total),
            )
            self.task_queue.put(agent_pb2.ClientTaskMessage(
                file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=manifest)))

    def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""):
        """Handles single file or directory creation."""
        try:
            base = os.path.normpath(self._get_base_dir(session_id, create=True))
            target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
                raise Exception("Path traversal blocked.")
            
            if is_dir: os.makedirs(target, exist_ok=True)
            else:
                os.makedirs(os.path.dirname(target), exist_ok=True)
                with open(target, "wb") as f: f.write(content)
            
            self._send_sync_ok(session_id, task_id, "Resource written")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e: self._send_sync_error(session_id, task_id, str(e))

    def _handle_fs_delete(self, session_id, rel_path, task_id=""):
        """Removes a file or directory from the node."""
        try:
            base = os.path.normpath(self._get_base_dir(session_id))
            target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
                raise Exception("Path traversal blocked.")
            
            self.watcher.acknowledge_remote_delete(session_id, rel_path)
            if os.path.isdir(target): shutil.rmtree(target)
            else: os.remove(target)
            
            self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}")
            self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
        except Exception as e: self._send_sync_error(session_id, task_id, str(e))

    def _async_write_chunk(self, sid, payload):
        """Worker for segmented file writing with path-level locking.

        Runs on the NodeIO pool. A per-(session, path) lock serializes chunks
        of the same file; the io_semaphore slot acquired in _on_sync_data is
        always released here via the finally block, even on failure.
        """
        path = payload.path
        # Lazily create the per-path lock under the map mutex.
        with self.lock_map_mutex:
            if (sid, path) not in self.write_locks: self.write_locks[(sid, path)] = threading.Lock()
            lock = self.write_locks[(sid, path)]
        
        with lock:
            try:
                # Suppress watcher echoes for the duration of the transfer.
                if payload.chunk_index == 0: self.watcher.suppress_path(sid, path)
                success = self.sync_mgr.write_chunk(sid, payload)
                if payload.is_final:
                    # Record the server-declared hash so the watcher does not
                    # re-upload the file we just received.
                    if success and payload.hash: self.watcher.acknowledge_remote_write(sid, path, payload.hash)
                    self.watcher.unsuppress_path(sid, path)
                    # NOTE(review): the lock entry is only removed on the final
                    # chunk; an aborted transfer leaves it in write_locks and
                    # the path suppressed -- confirm whether that is intended.
                    with self.lock_map_mutex: self.write_locks.pop((sid, path), None)
                    self._send_sync_ok(sid, "", f"File {path} synced")
            finally: self.io_semaphore.release()

    def _push_file(self, session_id, rel_path, task_id=""):
        """Streams a file from node to server using 4MB chunks."""
        base = os.path.normpath(self._get_base_dir(session_id, create=False))
        target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
        if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): return

        if not os.path.exists(target):
            if task_id: self._send_sync_error(session_id, task_id, "File not found")
            return

        hasher = hashlib.sha256()
        size, chunk_size = os.path.getsize(target), 4 * 1024 * 1024
        try:
            with open(target, "rb") as f:
                idx = 0
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk and idx > 0: break
                    hasher.update(chunk)
                    is_final = f.tell() >= size
                    self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=session_id, task_id=task_id,
                        file_data=agent_pb2.FilePayload(path=rel_path.replace("\\", "/"), chunk=zlib.compress(chunk), chunk_index=idx, is_final=is_final, hash=hasher.hexdigest() if is_final else "", compressed=True))))
                    if is_final: break
                    idx += 1
        except Exception as e: logger.error(f"Push Error: {e}")

    def _handle_task(self, task):
        """Verifies a task's HMAC signature and submits it for execution."""
        tid = task.task_id
        if not verify_task_signature(task):
            failure = agent_pb2.TaskResponse(
                task_id=tid, status=agent_pb2.TaskResponse.ERROR,
                stderr="HMAC signature mismatch")
            self._send_response(tid, failure)
            return

        accepted, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
        if not accepted:
            rejection = agent_pb2.TaskResponse(
                task_id=tid, status=agent_pb2.TaskResponse.ERROR,
                stderr=f"Rejection: {reason}")
            self._send_response(tid, rejection)

    def _on_event(self, event):
        """Forwards skill events to the gRPC stream, wrapping raw events."""
        if isinstance(event, agent_pb2.ClientTaskMessage):
            outbound = event
        else:
            outbound = agent_pb2.ClientTaskMessage(skill_event=event)
        self.task_queue.put(outbound)

    def _on_finish(self, task_id, result, trace_id):
        """Finalizes a task and sends the response back to the Hub."""
        response = agent_pb2.TaskResponse(
            task_id=task_id,
            stdout=result.get("stdout", ""),
            stderr=result.get("stderr", ""),
            status=result.get("status", 0),
            trace_id=trace_id,
        )
        self._send_response(task_id, response)

    def _send_response(self, task_id, response, status_override=None):
        """Helper to queue a standard task response.

        Args:
            task_id: Task the response belongs to (used when response is None).
            response: Prebuilt TaskResponse, or None to send a bare one.
            status_override: When given, replaces the response status. Compared
                against None explicitly (the old truthiness test would have
                silently dropped a zero-valued status override).
        """
        res = response or agent_pb2.TaskResponse(task_id=task_id)
        if status_override is not None:
            res.status = status_override
        self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=res))

    def _send_sync_ok(self, sid, tid, msg):
        """Queues an OK-status sync acknowledgement for the server."""
        status = agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg)
        wrapper = agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=wrapper))

    def _send_sync_error(self, sid, tid, msg):
        """Queues an ERROR-status sync report for the server."""
        status = agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg)
        wrapper = agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=status)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=wrapper))

    def _on_sync_delta(self, session_id, payload):
        """Forwards watcher-detected local changes upstream.

        Accepts either a ready-made FileSyncMessage or a bare file payload,
        wrapping the latter before queuing.
        """
        if isinstance(payload, agent_pb2.FileSyncMessage):
            message = payload
        else:
            message = agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload)
        self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=message))

    def shutdown(self):
        """Gracefully shuts down the node."""
        self._stop_event.set()
        self.skills.shutdown()
        if self.channel: self.channel.close()
        print("[*] Node shutdown complete.")