import threading
import queue
import time
import os
import hashlib
import logging
import json
import zlib
import shutil
import socket
import traceback
import random
from concurrent.futures import ThreadPoolExecutor
try:
import psutil
except ImportError:
psutil = None
from mesh_core import agent_pb2, agent_pb2_grpc
from mesh_core.engines import MeshNodeCore
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature, verify_server_message_signature
from agent_node.utils.network import get_secure_stub
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.core.regex_patterns import ANSI_ESCAPE
from mesh_core.transport import GrpcMeshTransport
from mesh_core.utils import DataChunker
logger = logging.getLogger(__name__)
class AgentNode(MeshNodeCore):
"""
Agent Core: Orchestrates local skills and maintains connectivity.
Now leverages MeshNodeCore for transport-agnostic orchestration.
"""
def __init__(self):
# 1. Initialize Transport (gRPC by default for production)
# Pass the secure stub factory to keep existing security logic
self.transport = GrpcMeshTransport(config.NODE_ID, get_secure_stub, auth_token=config.AUTH_TOKEN)
# 2. Initialize Core Engine
super().__init__(config.NODE_ID, self.transport)
# 3. Initialize Agent-Specific Modules
self.sandbox = SandboxEngine()
self.sync_mgr = NodeSyncManager()
self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
self.watcher = WorkspaceWatcher(self._on_sync_delta)
self._stop_event = threading.Event()
self.io_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="NodeIO")
self.io_semaphore = threading.Semaphore(50)
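        # Per-(session, path) locks serialize chunked writes to the same file;
        # entries are created lazily and dropped once the final chunk lands.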
self.write_locks = {}
self.lock_map_mutex = threading.Lock()
# 4. Bind Mesh Events
self.on_task = self._handle_task
self.on_cancel = self._handle_cancel
self.on_policy = self._on_policy_update
self.on_sync = self._handle_file_sync
self.on_ready = self._on_mesh_ready
self.on_disconnect = self._on_mesh_disconnect
def _on_mesh_ready(self, msg):
print(f"[Mesh] Connected and Authorized. Policy Synced.")
if hasattr(msg, 'policy_update'):
self._on_policy_update(msg.policy_update)
elif hasattr(msg, 'policy'):
self._on_policy_update(msg.policy)
def _on_mesh_disconnect(self):
print(f"[Mesh] Disconnected from Hub.")
def _on_policy_update(self, policy):
self.sandbox.sync(policy)
self._apply_skill_config(policy.skill_config_json)
def sync_configuration(self):
""" Handshake now handled by Transport/Core. This remains for compatibility if needed. """
# In the new SDK model, the transport handles the initial SyncConfiguration call
# or it's wrapped in the connect() flow.
pass
def start_health_reporting(self):
""" Launches background health reporting using the transport. """
def _report():
while not self._stop_event.is_set():
if self.transport.is_connected():
try:
ids = self.skills.get_active_ids()
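                        # psutil is an optional dependency: fall back to zeroed
                        # metrics so the Hub still receives liveness heartbeats.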
vmem = psutil.virtual_memory() if psutil else None
hb = agent_pb2.Heartbeat(
node_id=self.node_id,
cpu_usage_percent=psutil.cpu_percent() if psutil else 0,
memory_usage_percent=vmem.percent if vmem else 0,
active_worker_count=len(ids),
max_worker_capacity=config.MAX_SKILL_WORKERS,
running_task_ids=ids,
cpu_count=psutil.cpu_count() if psutil else 0,
memory_used_gb=vmem.used/(1024**3) if vmem else 0,
memory_total_gb=vmem.total/(1024**3) if vmem else 0
)
# Health is sent via a separate stream in gRPC,
# so we use the transport's specialized method if it exists
if hasattr(self.transport, 'send_health'):
self.transport.send_health(hb)
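                        # Tick the watchdog only after a successful report, so
                        # repeated failures eventually trip it.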
watchdog.tick()
except Exception as e:
logger.error(f"Health report error: {e}")
time.sleep(config.HEALTH_REPORT_INTERVAL)
threading.Thread(target=_report, daemon=True, name="HealthReporter").start()
def run_task_stream(self):
""" Starts the core engine which manages the task stream. """
if not self.start(): # From MeshNodeCore
raise RuntimeError("Handshake failed, node cannot start.")
while not self._stop_event.is_set():
if not self.transport.is_connected():
raise RuntimeError("Transport disconnected.")
time.sleep(1)
def _handle_cancel(self, cancel_req):
"""Cancels an active task or an entire session's background tasks."""
session_id = getattr(cancel_req, 'session_id', None)
if session_id and not cancel_req.task_id:
logger.info(f"[*] Cancelling all tasks for session: {session_id}")
self.skills.cancel_session(session_id)
return
if self.skills.cancel(cancel_req.task_id):
self._send_response(cancel_req.task_id, None, agent_pb2.TaskResponse.CANCELLED)
    def _handle_work_pool(self, update):
        """Claims tasks from the global work pool with randomized backoff."""
        for tid in update.available_task_ids:
            # Re-check capacity per claim so the node never over-commits its workers.
            if len(self.skills.get_active_ids()) >= config.MAX_SKILL_WORKERS:
                break
            # Jitter each claim so competing nodes don't race for the same task.
            time.sleep(random.uniform(0.1, 0.5))
            self.send_message(agent_pb2.ClientTaskMessage(
                task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
            ))
def _handle_file_sync(self, fs):
"""Dispatches file sync messages to specialized sub-handlers."""
sid = fs.session_id
if fs.HasField("manifest"): self._on_sync_manifest(sid, fs.manifest)
elif fs.HasField("file_data"): self._on_sync_data(sid, fs.file_data)
elif fs.HasField("control"): self._on_sync_control(sid, fs.control, fs.task_id)
def _on_sync_manifest(self, sid, manifest):
"""Reconciles local state with a remote manifest."""
drift = self.sync_mgr.handle_manifest(sid, manifest, on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p))
status = agent_pb2.SyncStatus(
code=agent_pb2.SyncStatus.RECONCILE_REQUIRED if drift else agent_pb2.SyncStatus.OK,
message=f"Drift in {len(drift)} files" if drift else "Synchronized",
reconcile_paths=drift
)
self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, status=status)))
def _on_sync_data(self, sid, file_data):
"""Offloads disk I/O to a background worker pool."""
        self.io_semaphore.acquire()
        try:
            self.io_executor.submit(self._async_write_chunk, sid, file_data)
        except Exception:
            self.io_semaphore.release()  # the worker will never run to release it
            raise
def _on_sync_control(self, sid, ctrl, task_id):
"""Handles sync control actions like watching, locking, or directory listing."""
action = ctrl.action
if action == agent_pb2.SyncControl.START_WATCHING:
self.watcher.start_watching(sid, ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path))
elif action == agent_pb2.SyncControl.STOP_WATCHING: self.watcher.stop_watching(sid)
elif action == agent_pb2.SyncControl.LOCK: self.watcher.set_lock(sid, True)
elif action == agent_pb2.SyncControl.UNLOCK: self.watcher.set_lock(sid, False)
elif action in (agent_pb2.SyncControl.REFRESH_MANIFEST, agent_pb2.SyncControl.RESYNC):
if ctrl.request_paths:
for p in ctrl.request_paths: self.io_executor.submit(self._push_file, sid, p)
else: self._push_full_manifest(sid, ctrl.path)
elif action == agent_pb2.SyncControl.PURGE:
self.watcher.stop_watching(sid)
self.sync_mgr.purge(sid)
elif action == agent_pb2.SyncControl.LIST: self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True)
elif action == agent_pb2.SyncControl.READ: self._push_file(sid, ctrl.path, task_id=task_id)
elif action == agent_pb2.SyncControl.WRITE: self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id)
elif action == agent_pb2.SyncControl.DELETE: self._handle_fs_delete(sid, ctrl.path, task_id=task_id)
def _get_base_dir(self, session_id, create=False):
"""Resolves the session's effective root directory."""
if session_id == "__fs_explorer__": return config.FS_ROOT
watched = self.watcher.get_watch_path(session_id)
return watched if watched else self.sync_mgr.get_session_dir(session_id, create=create)
def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False):
"""Generates and pushes a local file manifest to the server."""
base_dir = self._get_base_dir(session_id, create=True)
safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path
watch_path = os.path.normpath(os.path.join(base_dir, safe_rel))
        if not os.path.exists(watch_path):
            if session_id == "__fs_explorer__":
                return self._send_sync_error(session_id, task_id, f"Path {rel_path} not found")
            os.makedirs(watch_path, exist_ok=True)
files = []
try:
if shallow:
with os.scandir(watch_path) as it:
for entry in it:
if entry.name == ".cortex_sync": continue
is_dir = entry.is_dir() if not entry.is_symlink() else os.path.isdir(entry.path)
item_rel = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/")
files.append(agent_pb2.FileInfo(path=item_rel, size=entry.stat().st_size if not is_dir else 0, hash="", is_dir=is_dir))
else:
for root, dirs, filenames in os.walk(watch_path):
for name in filenames:
abs_p = os.path.join(root, name)
h = self.sync_mgr.get_file_hash(abs_p)
if h: files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=os.path.getsize(abs_p), hash=h, is_dir=False))
for d in dirs:
files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True))
except Exception as e:
return self._send_sync_error(session_id, task_id, str(e))
self._send_manifest_chunks(session_id, task_id, rel_path, files)
def _send_manifest_chunks(self, sid, tid, root, files):
"""Splits large manifests into chunks for gRPC streaming."""
chunk_size = 1000
if not files:
self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=agent_pb2.DirectoryManifest(root_path=root, is_final=True))))
return
for i in range(0, len(files), chunk_size):
chunk = files[i:i+chunk_size]
self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid,
manifest=agent_pb2.DirectoryManifest(root_path=root, files=chunk, chunk_index=i//chunk_size, is_final=(i+chunk_size >= len(files))))))
def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""):
"""Handles single file or directory creation."""
try:
base = os.path.normpath(self._get_base_dir(session_id, create=True))
target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
            # Security: enforce the jail unless in global file-explorer mode.
            # Match on the directory prefix so a sibling such as "/base_evil"
            # cannot slip past a bare startswith("/base") check.
            if session_id != "__fs_explorer__":
                root = os.path.normpath(config.FS_ROOT) if config.FS_ROOT else None
                if not (target == base or target.startswith(base + os.sep)) and \
                   not (root and (target == root or target.startswith(root + os.sep))):
                    raise Exception("Path traversal blocked.")
if is_dir: os.makedirs(target, exist_ok=True)
else:
os.makedirs(os.path.dirname(target), exist_ok=True)
with open(target, "wb") as f: f.write(content)
self._send_sync_ok(session_id, task_id, "Resource written")
self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
except Exception as e: self._send_sync_error(session_id, task_id, str(e))
def _handle_fs_delete(self, session_id, rel_path, task_id=""):
"""Removes a file or directory from the node."""
try:
base = os.path.normpath(self._get_base_dir(session_id))
target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
if session_id != "__fs_explorer__":
if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
raise Exception("Path traversal blocked.")
self.watcher.acknowledge_remote_delete(session_id, rel_path)
if os.path.isdir(target): shutil.rmtree(target)
else: os.remove(target)
self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}")
self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
except Exception as e: self._send_sync_error(session_id, task_id, str(e))
def _async_write_chunk(self, sid, payload):
"""Worker for segmented file writing with path-level locking."""
path = payload.path
with self.lock_map_mutex:
if (sid, path) not in self.write_locks: self.write_locks[(sid, path)] = threading.Lock()
lock = self.write_locks[(sid, path)]
with lock:
try:
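                # First chunk: mute watcher events for this path so our own
                # write is not echoed back to the Hub as a local delta.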
if payload.chunk_index == 0: self.watcher.suppress_path(sid, path)
success = self.sync_mgr.write_chunk(sid, payload)
if payload.is_final:
if success and payload.hash: self.watcher.acknowledge_remote_write(sid, path, payload.hash)
self.watcher.unsuppress_path(sid, path)
with self.lock_map_mutex: self.write_locks.pop((sid, path), None)
self._send_sync_ok(sid, "", f"File {path} synced")
finally: self.io_semaphore.release()
def _push_file(self, session_id, rel_path, task_id=""):
"""Streams a file from node to server using 4MB chunks."""
base = os.path.normpath(self._get_base_dir(session_id, create=False))
target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
if session_id != "__fs_explorer__":
if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): return
if not os.path.exists(target):
if task_id: self._send_sync_error(session_id, task_id, "File not found")
return
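        # Hash is computed incrementally as we stream; the digest rides on the
        # final chunk so the receiver can verify the reassembled file.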
hasher = hashlib.sha256()
size = os.path.getsize(target)
try:
with open(target, "rb") as f:
for idx, chunk in enumerate(DataChunker.chunk_file(f)):
hasher.update(chunk)
is_final = f.tell() >= size
self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(
session_id=session_id,
task_id=task_id,
file_data=agent_pb2.FilePayload(
path=rel_path.replace("\\", "/"),
chunk=zlib.compress(chunk),
chunk_index=idx,
is_final=is_final,
hash=hasher.hexdigest() if is_final else "",
compressed=True
)
)))
except Exception as e: logger.error(f"Push Error: {e}")
def _handle_task(self, task):
"""Verifies and submits a skill task for execution."""
print(f" [📥] Received Task: {task.task_id} (Type: {task.task_type})")
if not verify_task_signature(task):
print(f"[Warn] Task signature mismatch for {task.task_id}. Proceeding anyway (DEBUG).")
success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
if not success:
self._send_response(task.task_id, agent_pb2.TaskResponse(task_id=task.task_id, status=agent_pb2.TaskResponse.ERROR, stderr=f"Rejection: {reason}"))
def _on_event(self, event):
"""Forwards skill events to the gRPC stream."""
self.send_message(event if isinstance(event, agent_pb2.ClientTaskMessage) else agent_pb2.ClientTaskMessage(skill_event=event))
def _on_finish(self, task_id, result, trace_id):
"""Finalizes a task and sends the response back to the Hub."""
self._send_response(task_id, agent_pb2.TaskResponse(task_id=task_id, stdout=result.get("stdout", ""), stderr=result.get("stderr", ""), status=result.get("status", 0), trace_id=trace_id))
def _send_response(self, task_id, response, status_override=None):
"""Helper to queue a standard task response with memory safety caps."""
res = response or agent_pb2.TaskResponse(task_id=task_id)
        if status_override is not None: res.status = status_override
# Redundant safety cap (5MB) to protect gRPC/Queue memory
SAFETY_CAP = 5 * 1024 * 1024
for field in ['stdout', 'stderr']:
val = getattr(res, field, "")
if len(val) > SAFETY_CAP:
logger.warning(f"Truncating excessive {field} ({len(val):,} bytes) for task {task_id}")
setattr(res, field, val[:SAFETY_CAP // 2] + "\n... [TRUNCATED DUE TO SIZE] ...\n" + val[-(SAFETY_CAP // 2):])
self.send_message(agent_pb2.ClientTaskMessage(task_response=res))
def _send_sync_ok(self, sid, tid, msg):
self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg))))
def _send_sync_error(self, sid, tid, msg):
self.send_message(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg))))
def _on_sync_delta(self, session_id, payload):
self.send_message(agent_pb2.ClientTaskMessage(file_sync=payload if isinstance(payload, agent_pb2.FileSyncMessage) else agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload)))
def shutdown(self):
"""Gracefully shuts down the node."""
self._stop_event.set()
self.skills.shutdown()
self.stop() # From MeshNodeCore
print("[*] Node shutdown complete.")