import threading
import queue
import time
import os
import hashlib
import logging
import json
import zlib
import shutil
import socket
import traceback
from concurrent.futures import ThreadPoolExecutor
try:
import psutil
except ImportError:
psutil = None
from protos import agent_pb2, agent_pb2_grpc
from agent_node.skills.manager import SkillManager
from agent_node.core.sandbox import SandboxEngine
from agent_node.core.sync import NodeSyncManager
from agent_node.core.watcher import WorkspaceWatcher
from agent_node.utils.auth import verify_task_signature, verify_server_message_signature
from agent_node.utils.network import get_secure_stub
import agent_node.config as config
from agent_node.utils.watchdog import watchdog
from agent_node.core.regex_patterns import ANSI_ESCAPE
logger = logging.getLogger(__name__)
class AgentNode:
    """
    Agent Core: Orchestrates local skills and maintains gRPC connectivity.
    Refactored for structural clarity and modular message handling.
    """
    def __init__(self):
        # Identity comes from local config; re-read on every handshake.
        self.node_id = config.NODE_ID
        self.sandbox = SandboxEngine()      # policy-enforced execution environment
        self.sync_mgr = NodeSyncManager()   # per-session file sync state
        self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr)
        # Local filesystem changes are forwarded to the server via _on_sync_delta.
        self.watcher = WorkspaceWatcher(self._on_sync_delta)
        # Bounded outbound queue so a dead stream cannot buffer messages unboundedly.
        self.task_queue = queue.Queue(maxsize=250)
        self.stub = None
        self.channel = None
        self._stop_event = threading.Event()
        self._refresh_stub()  # establishes self.stub / self.channel
        # Small pool for disk I/O; the semaphore bounds in-flight chunk writes.
        self.io_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="NodeIO")
        self.io_semaphore = threading.Semaphore(50)
        self.write_locks = {}                   # (session_id, path) -> threading.Lock
        self.lock_map_mutex = threading.Lock()  # guards write_locks
def _refresh_stub(self):
"""Force refreshes the gRPC channel and stub."""
if self.channel:
try: self.channel.close()
except: pass
self.stub, self.channel = get_secure_stub()
self._setup_connectivity_watcher()
    def _setup_connectivity_watcher(self):
        """Monitors gRPC channel state.

        Registers a callback on the current channel that reports every state
        transition until shutdown is requested.
        """
        import grpc  # local import: only needed once a channel exists
        self._last_grpc_state = None  # last observed connectivity value
        def _on_state_change(state):
            # Report only transitions; duplicate notifications are ignored.
            if not self._stop_event.is_set() and state != self._last_grpc_state:
                print(f"[*] [gRPC-State] {state}", flush=True)
                self._last_grpc_state = state
        # try_to_connect=True nudges the channel out of IDLE immediately.
        self.channel.subscribe(_on_state_change, try_to_connect=True)
def sync_configuration(self):
"""Handshake with the Orchestrator to sync policy and metadata."""
while True:
config.reload()
self.node_id = config.NODE_ID
if not self.stub: self._refresh_stub()
print(f"[*] Handshake with Orchestrator: {self.node_id}")
caps = self._collect_capabilities()
reg_req = agent_pb2.RegistrationRequest(
node_id=self.node_id, auth_token=config.AUTH_TOKEN,
node_description=config.NODE_DESC,
capabilities={k: str(v).lower() if isinstance(v, bool) else str(v) for k, v in caps.items()}
)
try:
res = self.stub.SyncConfiguration(reg_req, timeout=10)
if res.success:
self.sandbox.sync(res.policy)
self._apply_skill_config(res.policy.skill_config_json)
print("[OK] Handshake successful. Policy Synced.")
break
else:
print(f"[!] Rejection: {res.error_message}. Retrying in 5s...")
time.sleep(5)
except Exception as e:
print(f"[!] Connection Fail: {str(e)}. Retrying in 5s...")
time.sleep(5)
def _apply_skill_config(self, config_json):
"""Applies dynamic skill configurations from the server."""
if not config_json: return
try:
cfg = json.loads(config_json)
for skill in self.skills.skills.values():
if hasattr(skill, "apply_config"): skill.apply_config(cfg)
except Exception as e:
logger.error(f"Error applying skill config: {e}")
def _collect_capabilities(self) -> dict:
"""Collects hardware and OS metadata."""
from agent_node.utils.platform_metrics import get_platform_metrics
caps = get_platform_metrics().collect_capabilities()
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
s.connect(('10.254.254.254', 1))
caps["local_ip"] = s.getsockname()[0]
s.close()
except: caps["local_ip"] = "unknown"
return caps
    def start_health_reporting(self):
        """Launches the background health reporting stream.

        Spawns a daemon thread that opens a client-streaming ``ReportHealth``
        RPC yielding one Heartbeat per interval; each server acknowledgement
        feeds the watchdog. On any failure the stream is rebuilt after 5s.
        """
        from agent_node.utils.platform_metrics import get_platform_metrics
        metrics_tool = get_platform_metrics()
        def _report():
            # Outer loop: re-establish the RPC stream after errors.
            while not self._stop_event.is_set():
                try:
                    def _gen():
                        # Inner generator: one heartbeat per reporting interval.
                        while not self._stop_event.is_set():
                            ids = self.skills.get_active_ids()
                            # psutil is optional; report zeros when unavailable.
                            vmem = psutil.virtual_memory() if psutil else None
                            yield agent_pb2.Heartbeat(
                                node_id=self.node_id,
                                cpu_usage_percent=psutil.cpu_percent() if psutil else 0,
                                memory_usage_percent=vmem.percent if vmem else 0,
                                active_worker_count=len(ids),
                                max_worker_capacity=config.MAX_SKILL_WORKERS,
                                running_task_ids=ids,
                                cpu_count=psutil.cpu_count() if psutil else 0,
                                memory_used_gb=vmem.used/(1024**3) if vmem else 0,
                                memory_total_gb=vmem.total/(1024**3) if vmem else 0,
                                load_avg=metrics_tool.get_load_avg()
                            )
                            # Sleep slightly under the interval; presumably the RPC
                            # round-trip absorbs the remaining second — TODO confirm.
                            time.sleep(max(0, config.HEALTH_REPORT_INTERVAL - 1.0))
                    # Each server response confirms liveness to the watchdog.
                    for _ in self.stub.ReportHealth(_gen()): watchdog.tick()
                except Exception as e:
                    time.sleep(5)  # back off before rebuilding the stream
        threading.Thread(target=_report, daemon=True, name="HealthReporter").start()
def run_task_stream(self):
"""Main bi-directional task stream with auto-reconnection."""
while True:
try:
def _gen():
yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id))
while True: yield self.task_queue.get()
responses = self.stub.TaskStream(_gen())
print(f"[*] Task stream connected ({self.node_id}).")
for msg in responses:
watchdog.tick()
self._process_server_message(msg)
except Exception as e:
print(f"[!] Task stream error: {e}")
self._refresh_stub()
time.sleep(5)
def _process_server_message(self, msg):
"""Routes inbound server messages to their respective handlers."""
try:
kind = msg.WhichOneof('payload')
if not verify_server_message_signature(msg):
print(f"[!] Signature mismatch for {kind}. Proceeding anyway (DEBUG).")
if kind == 'task_request': self._handle_task(msg.task_request)
elif kind == 'task_cancel': self._handle_cancel(msg.task_cancel)
elif kind == 'work_pool_update': self._handle_work_pool(msg.work_pool_update)
elif kind == 'file_sync': self._handle_file_sync(msg.file_sync)
elif kind == 'policy_update':
self.sandbox.sync(msg.policy_update)
self._apply_skill_config(msg.policy_update.skill_config_json)
except Exception as e:
print(f"[!] Error processing server message '{kind}': {e}")
traceback.print_exc()
def _handle_cancel(self, cancel_req):
"""Cancels an active task."""
if self.skills.cancel(cancel_req.task_id):
self._send_response(cancel_req.task_id, None, agent_pb2.TaskResponse.CANCELLED)
def _handle_work_pool(self, update):
"""Claims tasks from the global work pool with randomized backoff."""
if len(self.skills.get_active_ids()) < config.MAX_SKILL_WORKERS:
for tid in update.available_task_ids:
import random
time.sleep(random.uniform(0.1, 0.5))
self.task_queue.put(agent_pb2.ClientTaskMessage(
task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id)
))
def _handle_file_sync(self, fs):
"""Dispatches file sync messages to specialized sub-handlers."""
sid = fs.session_id
if fs.HasField("manifest"): self._on_sync_manifest(sid, fs.manifest)
elif fs.HasField("file_data"): self._on_sync_data(sid, fs.file_data)
elif fs.HasField("control"): self._on_sync_control(sid, fs.control, fs.task_id)
def _on_sync_manifest(self, sid, manifest):
"""Reconciles local state with a remote manifest."""
drift = self.sync_mgr.handle_manifest(sid, manifest, on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p))
status = agent_pb2.SyncStatus(
code=agent_pb2.SyncStatus.RECONCILE_REQUIRED if drift else agent_pb2.SyncStatus.OK,
message=f"Drift in {len(drift)} files" if drift else "Synchronized",
reconcile_paths=drift
)
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, status=status)))
def _on_sync_data(self, sid, file_data):
"""Offloads disk I/O to a background worker pool."""
self.io_semaphore.acquire()
try: self.io_executor.submit(self._async_write_chunk, sid, file_data)
except: self.io_semaphore.release()
    def _on_sync_control(self, sid, ctrl, task_id):
        """Handles sync control actions like watching, locking, or directory listing.

        Dispatches on ``ctrl.action``; file-producing actions (LIST/READ/
        WRITE/DELETE) carry ``task_id`` so the server can correlate replies.
        """
        action = ctrl.action
        if action == agent_pb2.SyncControl.START_WATCHING:
            # Relative watch paths are anchored at the session directory.
            self.watcher.start_watching(sid, ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path))
        elif action == agent_pb2.SyncControl.STOP_WATCHING: self.watcher.stop_watching(sid)
        elif action == agent_pb2.SyncControl.LOCK: self.watcher.set_lock(sid, True)
        elif action == agent_pb2.SyncControl.UNLOCK: self.watcher.set_lock(sid, False)
        elif action in (agent_pb2.SyncControl.REFRESH_MANIFEST, agent_pb2.SyncControl.RESYNC):
            if ctrl.request_paths:
                # Targeted resync: push only the requested files, off-thread.
                for p in ctrl.request_paths: self.io_executor.submit(self._push_file, sid, p)
            else: self._push_full_manifest(sid, ctrl.path)
        elif action == agent_pb2.SyncControl.PURGE:
            # Stop watching before deleting local state to avoid echo events.
            self.watcher.stop_watching(sid)
            self.sync_mgr.purge(sid)
        elif action == agent_pb2.SyncControl.LIST: self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True)
        elif action == agent_pb2.SyncControl.READ: self._push_file(sid, ctrl.path, task_id=task_id)
        elif action == agent_pb2.SyncControl.WRITE: self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id)
        elif action == agent_pb2.SyncControl.DELETE: self._handle_fs_delete(sid, ctrl.path, task_id=task_id)
def _get_base_dir(self, session_id, create=False):
"""Resolves the session's effective root directory."""
if session_id == "__fs_explorer__": return config.FS_ROOT
watched = self.watcher.get_watch_path(session_id)
return watched if watched else self.sync_mgr.get_session_dir(session_id, create=create)
    def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False):
        """Generates and pushes a local file manifest to the server.

        ``shallow`` lists only the immediate children of ``rel_path`` without
        hashes; otherwise the whole tree is walked and hashed. Results are
        streamed back in chunks via ``_send_manifest_chunks``.
        """
        base_dir = self._get_base_dir(session_id, create=True)
        # Explorer mode may use absolute paths; session paths are made relative.
        safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path
        watch_path = os.path.normpath(os.path.join(base_dir, safe_rel))
        if not os.path.exists(watch_path):
            # Sessions lazily create missing directories; the explorer must not
            # touch the filesystem, so it reports an error and stops instead.
            if session_id != "__fs_explorer__": os.makedirs(watch_path, exist_ok=True)
            else: self._send_sync_error(session_id, task_id, f"Path {rel_path} not found")
            if session_id == "__fs_explorer__": return
        files = []
        try:
            if shallow:
                with os.scandir(watch_path) as it:
                    for entry in it:
                        if entry.name == ".cortex_sync": continue  # internal sync metadata
                        # Resolve symlinked dirs via the target, not the link itself.
                        is_dir = entry.is_dir() if not entry.is_symlink() else os.path.isdir(entry.path)
                        item_rel = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/")
                        files.append(agent_pb2.FileInfo(path=item_rel, size=entry.stat().st_size if not is_dir else 0, hash="", is_dir=is_dir))
            else:
                for root, dirs, filenames in os.walk(watch_path):
                    for name in filenames:
                        abs_p = os.path.join(root, name)
                        h = self.sync_mgr.get_file_hash(abs_p)
                        # Files that vanish or fail to hash mid-walk are skipped.
                        if h: files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=os.path.getsize(abs_p), hash=h, is_dir=False))
                    for d in dirs:
                        files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True))
        except Exception as e:
            return self._send_sync_error(session_id, task_id, str(e))
        self._send_manifest_chunks(session_id, task_id, rel_path, files)
def _send_manifest_chunks(self, sid, tid, root, files):
"""Splits large manifests into chunks for gRPC streaming."""
chunk_size = 1000
if not files:
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=agent_pb2.DirectoryManifest(root_path=root, is_final=True))))
return
for i in range(0, len(files), chunk_size):
chunk = files[i:i+chunk_size]
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid,
manifest=agent_pb2.DirectoryManifest(root_path=root, files=chunk, chunk_index=i//chunk_size, is_final=(i+chunk_size >= len(files))))))
def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""):
"""Handles single file or directory creation."""
try:
base = os.path.normpath(self._get_base_dir(session_id, create=True))
target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
# Security: Only enforce jail if not in the global file explorer mode
if session_id != "__fs_explorer__":
if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
raise Exception("Path traversal blocked.")
if is_dir: os.makedirs(target, exist_ok=True)
else:
os.makedirs(os.path.dirname(target), exist_ok=True)
with open(target, "wb") as f: f.write(content)
self._send_sync_ok(session_id, task_id, "Resource written")
self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
except Exception as e: self._send_sync_error(session_id, task_id, str(e))
def _handle_fs_delete(self, session_id, rel_path, task_id=""):
"""Removes a file or directory from the node."""
try:
base = os.path.normpath(self._get_base_dir(session_id))
target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
if session_id != "__fs_explorer__":
if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))):
raise Exception("Path traversal blocked.")
self.watcher.acknowledge_remote_delete(session_id, rel_path)
if os.path.isdir(target): shutil.rmtree(target)
else: os.remove(target)
self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}")
self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True)
except Exception as e: self._send_sync_error(session_id, task_id, str(e))
    def _async_write_chunk(self, sid, payload):
        """Worker for segmented file writing with path-level locking.

        Runs on the NodeIO pool. A per-(session, path) lock serializes chunks
        of the same file; the I/O semaphore slot taken by the scheduler is
        always released on exit.
        """
        path = payload.path
        # Lazily create the per-file lock under the map mutex.
        with self.lock_map_mutex:
            if (sid, path) not in self.write_locks: self.write_locks[(sid, path)] = threading.Lock()
            lock = self.write_locks[(sid, path)]
        with lock:
            try:
                # First chunk: mute the watcher so our own write is not echoed
                # back to the server as a local modification.
                if payload.chunk_index == 0: self.watcher.suppress_path(sid, path)
                success = self.sync_mgr.write_chunk(sid, payload)
                if payload.is_final:
                    if success and payload.hash: self.watcher.acknowledge_remote_write(sid, path, payload.hash)
                    self.watcher.unsuppress_path(sid, path)
                    with self.lock_map_mutex: self.write_locks.pop((sid, path), None)
                    self._send_sync_ok(sid, "", f"File {path} synced")
                # NOTE(review): if write_chunk raises mid-file, the path stays
                # suppressed and its lock entry is never popped — confirm the
                # sync manager recovers from this.
            finally: self.io_semaphore.release()
def _push_file(self, session_id, rel_path, task_id=""):
"""Streams a file from node to server using 4MB chunks."""
base = os.path.normpath(self._get_base_dir(session_id, create=False))
target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path))
if session_id != "__fs_explorer__":
if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): return
if not os.path.exists(target):
if task_id: self._send_sync_error(session_id, task_id, "File not found")
return
hasher = hashlib.sha256()
size, chunk_size = os.path.getsize(target), 4 * 1024 * 1024
try:
with open(target, "rb") as f:
idx = 0
while True:
chunk = f.read(chunk_size)
if not chunk and idx > 0: break
hasher.update(chunk)
is_final = f.tell() >= size
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=session_id, task_id=task_id,
file_data=agent_pb2.FilePayload(path=rel_path.replace("\\", "/"), chunk=zlib.compress(chunk), chunk_index=idx, is_final=is_final, hash=hasher.hexdigest() if is_final else "", compressed=True))))
if is_final: break
idx += 1
except Exception as e: logger.error(f"Push Error: {e}")
def _handle_task(self, task):
"""Verifies and submits a skill task for execution."""
if not verify_task_signature(task):
print(f"[!] Task signature mismatch for {task.task_id}. Proceeding anyway (DEBUG).")
success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event)
if not success:
self._send_response(task.task_id, agent_pb2.TaskResponse(task_id=task.task_id, status=agent_pb2.TaskResponse.ERROR, stderr=f"Rejection: {reason}"))
def _on_event(self, event):
"""Forwards skill events to the gRPC stream."""
self.task_queue.put(event if isinstance(event, agent_pb2.ClientTaskMessage) else agent_pb2.ClientTaskMessage(skill_event=event))
def _on_finish(self, task_id, result, trace_id):
"""Finalizes a task and sends the response back to the Hub."""
self._send_response(task_id, agent_pb2.TaskResponse(task_id=task_id, stdout=result.get("stdout", ""), stderr=result.get("stderr", ""), status=result.get("status", 0), trace_id=trace_id))
def _send_response(self, task_id, response, status_override=None):
"""Helper to queue a standard task response."""
res = response or agent_pb2.TaskResponse(task_id=task_id)
if status_override: res.status = status_override
self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=res))
def _send_sync_ok(self, sid, tid, msg):
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg))))
def _send_sync_error(self, sid, tid, msg):
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg))))
def _on_sync_delta(self, session_id, payload):
self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=payload if isinstance(payload, agent_pb2.FileSyncMessage) else agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload)))
def shutdown(self):
"""Gracefully shuts down the node."""
self._stop_event.set()
self.skills.shutdown()
if self.channel: self.channel.close()
print("[*] Node shutdown complete.")