diff --git a/.dockerignore b/.dockerignore index ba004fa..9011797 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,3 +5,7 @@ **/.pytest_cache **/data ai-hub/native_hub.log +**/.DS_Store +**/._* +.env +.env.* diff --git a/agent-node/README.md b/agent-node/README.md index ac8e6fe..2f7fc11 100644 --- a/agent-node/README.md +++ b/agent-node/README.md @@ -27,3 +27,16 @@ ### Foreground Usage Run `python3 src/agent_node/main.py` directly if you want to watch the logs in your terminal. + +## Maintenance & Uninstallation + +### Self-Update +The node automatically checks for updates from the connected Hub. You can force an update check by restarting the service. + +### Clean Uninstallation (Purge) +To completely remove the agent, stop all background services, and delete all local configuration and code: +```bash +python3 purge.py +``` +> [!WARNING] +> This will permanently delete the current directory and deregister the node from the Hub. diff --git a/agent-node/bootstrap_windows.ps1 b/agent-node/bootstrap_windows.ps1 index 52547c6..f9eabaf 100644 --- a/agent-node/bootstrap_windows.ps1 +++ b/agent-node/bootstrap_windows.ps1 @@ -2,7 +2,10 @@ [string]$NodeId = "", [string]$AuthToken = "", [string]$HubUrl = "", - [string]$GrpcUrl = "" + [string]$GrpcUrl = "", + [switch]$RegisterService = $false, + [switch]$ForceFirewall = $false, + [switch]$AutoRun = $true ) $ErrorActionPreference = "Stop" @@ -14,6 +17,7 @@ # 1. Check Python installation (defensively avoid Microsoft Store alias) $pythonValid = $false try { + # Check if python is in path and not the 0-byte executable from store $out = python --version 2>&1 if ($out -like "*Python *") { $pythonValid = $true } } catch { } @@ -21,7 +25,7 @@ if (!$pythonValid) { Write-Host "[!] Python not found or invalid. Installing via winget..." 
-ForegroundColor Yellow winget install -e --id Python.Python.3.12 --accept-package-agreements --accept-source-agreements - # Refresh PATH + # Refresh PATH for the current session $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") } @@ -36,17 +40,12 @@ } Set-Location $workDir -# 4. Download agent code from Hub (Matching Linux pattern) +# 4. Download agent code from Hub if ((-not $HubUrl) -or (-not $AuthToken)) { Write-Host "[!] Hub details missing. Will prompt for them later after basic setup." -ForegroundColor Yellow } else { Write-Host "[*] Fetching agent source from Hub..." -ForegroundColor Cyan - $baseUrl = $HubUrl.Split(":")[0] - if ($HubUrl.Contains("http")) { - $downloadUrl = "$HubUrl/api/v1/agent/download" - } else { - $downloadUrl = "http://$baseUrl:8002/api/v1/agent/download" - } + $downloadUrl = "$HubUrl/api/v1/agent/download" $tarPath = Join-Path $workDir "agent.tar.gz" $headers = @{"X-Agent-Token" = $AuthToken} @@ -55,13 +54,12 @@ Invoke-WebRequest -Uri $downloadUrl -Headers $headers -OutFile $tarPath Write-Host "[+] Download complete. Extracting..." -ForegroundColor Green - # Windows 10+ has tar.exe built-in. Fallback to Expand-Archive if needed. + # Windows 10+ has tar.exe built-in. if (Get-Command tar -ErrorAction SilentlyContinue) { tar -xzf $tarPath --strip-components=1 } else { - Write-Warning "tar.exe not found. Attempting Expand-Archive (may not support tar.gz natively without 7-Zip/etc)." - # Note: PowerShell's Expand-Archive usually only likes .zip. - # We recommend users have tar or we provide a zip endpoint. + Write-Warning "tar.exe not found. Attempting Expand-Archive (may only support .zip)." + # Minimal fallback logic } Remove-Item $tarPath } catch { @@ -72,14 +70,16 @@ # 5. Setup Virtual Environment Write-Host "[*] Creating Virtual Environment..." 
-python -m venv venv -.\venv\Scripts\Activate.ps1 +if (!(Test-Path "venv")) { + python -m venv venv +} +$pythonExe = Join-Path $workDir "venv\Scripts\python.exe" # 6. Install Dependencies Write-Host "[*] Installing Dependencies..." -python -m pip install --upgrade pip -python -m pip install -r requirements.txt -python -m pip install pywinpty pypiwin32 # Windows-specific requirements +& $pythonExe -m pip install --upgrade pip +& $pythonExe -m pip install -r requirements.txt +& $pythonExe -m pip install pywinpty pypiwin32 # Windows-specific requirements # 7. Environment Setup Write-Host "------------------------------------------" @@ -101,14 +101,24 @@ $envFile | Out-File -FilePath ".env" -Encoding ascii # 8. Test Execution -Write-Host "[*] Bootstrap complete. You can now run the agent with:" -ForegroundColor Green -Write-Host " .\venv\Scripts\python.exe src\agent_node\main.py" -ForegroundColor Yellow +Write-Host "[*] Bootstrap complete." -ForegroundColor Green +if ($AutoRun) { + Write-Host "[โšก] Auto-starting Agent..." -ForegroundColor Cyan + Start-Process -FilePath $pythonExe -ArgumentList "src\agent_node\main.py" -WorkingDirectory $workDir +} else { + Write-Host " To run manually: .\venv\Scripts\python.exe src\agent_node\main.py" -ForegroundColor Yellow +} # 9. Optional Service Registration -$installService = Read-Host "Would you like to register this as a startup task? (y/n)" -if ($installService -eq "y") { +if ($RegisterService) { Write-Host "[*] Registering Scheduled Task..." - python install_service.py --name "CortexAgent" --run + & $pythonExe install_service.py --name "CortexAgent" --run +} elseif ($NodeId -eq "" -or !$NodeId) { + # Only ask if we are in interactive mode (no NodeId provided early) + $ans = Read-Host "Would you like to register this as a startup task? (y/n)" + if ($ans -eq "y") { + & $pythonExe install_service.py --name "CortexAgent" --run + } } # 10. 
Security & Firewall Check @@ -118,22 +128,25 @@ $disabled = $profiles | Where-Object { $_.Enabled -eq "False" } if ($disabled) { - Write-Host "[!] Warning: One or more Firewall profiles are DISABLED." -ForegroundColor Yellow - $enable = Read-Host "Would you like to ENABLE the Windows Firewall now? (y/n)" - if ($enable -eq "y") { + if ($ForceFirewall) { Set-NetFirewallProfile -Profile Domain,Public,Private -Enabled True Write-Host "[+] Windows Firewall enabled." -ForegroundColor Green + } elseif ($NodeId -eq "" -or !$NodeId) { + $ans = Read-Host "[!] Warning: Firewall disabled. Enable it now? (y/n)" + if ($ans -eq "y") { + Set-NetFirewallProfile -Profile Domain,Public,Private -Enabled True + } } } # Add rule for Python communication Write-Host "[*] Adding firewall exception for agent communication..." -$pythonPath = (Get-Command python).Source -if ($pythonPath) { - New-NetFirewallRule -DisplayName "Cortex Agent Communication" -Direction Outbound -Program $pythonPath -Action Allow -Description "Allows Cortex Agent to reach the Hub" -ErrorAction SilentlyContinue - New-NetFirewallRule -DisplayName "Cortex Agent Communication" -Direction Inbound -Program $pythonPath -Action Allow -Description "Allows Cortex Agent Mesh Communication" -Profile Any -ErrorAction SilentlyContinue +if ($pythonExe -and (Test-Path $pythonExe)) { + New-NetFirewallRule -DisplayName "Cortex Agent Communication" -Direction Outbound -Program $pythonExe -Action Allow -Description "Allows Cortex Agent to reach the Hub" -ErrorAction SilentlyContinue + New-NetFirewallRule -DisplayName "Cortex Agent Communication" -Direction Inbound -Program $pythonExe -Action Allow -Description "Allows Cortex Agent Mesh Communication" -Profile Any -ErrorAction SilentlyContinue } Write-Host "==========================================" Write-Host " DONE! Check the Hub for node status. 
" -ForegroundColor Cyan Write-Host "==========================================" + diff --git a/agent-node/purge.py b/agent-node/purge.py new file mode 100644 index 0000000..e36e0bb --- /dev/null +++ b/agent-node/purge.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +""" +Cortex Agent Node - Clean Purge Tool +==================================== +Stops services, deregisters from the Hub, and deletes all agent files. + +Usage: + python3 purge.py [--force] +""" + +import os +import sys +import time +import shutil +import platform +import subprocess +import requests +from agent_node.utils.service_manager import get_service_manager + +def load_env(): + env_path = os.path.join(os.path.dirname(__file__), ".env") + env = {} + if os.path.exists(env_path): + with open(env_path, "r") as f: + for line in f: + if "=" in line: + key, val = line.strip().split("=", 1) + env[key] = val + return env + +def main(): + print("๐Ÿงน Starting Cortex Agent Clean Purge...") + + env = load_env() + node_id = env.get("AGENT_NODE_ID") + token = env.get("AGENT_AUTH_TOKEN") + hub_url = env.get("AGENT_HUB_URL") + + # 1. Uninstall Service + print("[*] Removing background daemon/service...") + manager = get_service_manager() + try: + manager.stop() + if manager.uninstall(): + print("โœ… Service uninstalled.") + else: + print("โš ๏ธ Service was not installed or failed to uninstall.") + except Exception as e: + print(f"โš ๏ธ Error during service removal: {e}") + + # 2. Deregister from Hub + if node_id and token and hub_url: + print(f"[*] Deregistering node '{node_id}' from Hub...") + purge_url = f"{hub_url}/api/v1/nodes/purge?node_id={node_id}&token={token}" + try: + resp = requests.post(purge_url, timeout=10) + if resp.status_code == 200: + print("โœ… Hub successfully notified. Node record deleted.") + else: + print(f"โš ๏ธ Hub rejected purge request: {resp.text}") + except Exception as e: + print(f"โš ๏ธ Could not reach Hub for deregistration: {e}") + else: + print("โš ๏ธ Missing .env details. 
Skipping Hub deregistration.") + + # 3. Final Wipe + working_dir = os.path.abspath(os.path.dirname(__file__)) + print(f"[*] Preparing to delete directory: {working_dir}") + + # Platform-specific self-deletion + if platform.system() == "Windows": + # Create a temporary batch file to delete the folder after this process exits + temp_dir = os.environ.get("TEMP", "C:\\Temp") + bat_path = os.path.join(temp_dir, f"cortex_purge_{int(time.time())}.bat") + + # Batch script: Wait for Python to exit, remove dir, delete itself. + bat_content = f"""@echo off +timeout /t 2 /nobreak > nul +rd /s /q "{working_dir}" +del "%~f0" +""" + with open(bat_path, "w") as f: + f.write(bat_content) + + print(f"๐Ÿš€ Launching cleanup wrapper: {bat_path}") + subprocess.Popen(["cmd.exe", "/c", bat_path], shell=True) + else: + # Linux/Mac can usually 'rm' themselves while running, + # but to be safe we'll use a shell wrapper + print("๐Ÿš€ Executing final recursive deletion...") + # We run it in the background with a small delay + cmd = f"sleep 1 && rm -rf {working_dir} &" + os.system(cmd) + + print("\nโœ… Purge complete. 
Goodbye!") + sys.exit(0) + +if __name__ == "__main__": + # Ensure requests is installed for this script (minimal dependency check) + try: + import requests + except ImportError: + print("[*] Installing requests for deregistration...") + subprocess.run([sys.executable, "-m", "pip", "install", "requests"], capture_output=True) + import requests + + main() diff --git a/agent-node/src/agent_node/core/regex_patterns.py b/agent-node/src/agent_node/core/regex_patterns.py new file mode 100644 index 0000000..5ce7dc8 --- /dev/null +++ b/agent-node/src/agent_node/core/regex_patterns.py @@ -0,0 +1,31 @@ +import re + +# ANSI Escape Sequences (Standard and Xterm) +ANSI_ESCAPE = re.compile(r'\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + +# Shell Prompt Detection patterns +PROMPT_PATTERNS = [ + r"[\r\n].*[@\w\.\-]+:.*[#$]\s*$", # bash/zsh: user@host:~$ + r">>>\s*$", # python + r"\.\.\.\s*$", # python multi-line + r">\s*$", # node/js + r"PS\s+.*>\s*$", # powershell +] + +# Compiled prompt patterns for performance +COMPILED_PROMPT_PATTERNS = [re.compile(p) for p in PROMPT_PATTERNS] + +# Protocol Extraction (OSC 1337 / Bracketed) +TASK_PROTOCOL_FENCE = re.compile(r'1337;Task(Start|End);id=([a-zA-Z0-9-]*)') +EXIT_CODE_PATTERN = re.compile(r'exit=(\d+)') + +# Echo Suppression Patterns (M9/M7) +ECHO_START_PATTERN = re.compile(r'echo \s*\[\[1337;Task\^Start;id=[a-zA-Z0-9-]*\]\]\s*&\s*') +ECHO_END_PATTERN = re.compile(r'\s*&\s*echo \s*\[\[1337;Task\^End;id=[a-zA-Z0-9-]*;exit=%errorlevel%\]\]') +ECHO_CLEANUP_ANSI = re.compile(r'echo \x1b]1337;TaskEnd;.*') +ECHO_CLEANUP_BRACKET = re.compile(r'echo \[\[1337;TaskEnd;.*') + +# UI Stealth Filtering +PROTOCOL_HINT_PATTERN = re.compile(r"1337;Task|`e]|\\033]") +STRIP_START_FENCE = re.compile(r'\x1b]1337;Task(Start|End);id=.*?\x07') +STRIP_BRACKET_FENCE = re.compile(r'\[\[1337;Task(Start|End);id=.*?\]\]') diff --git a/agent-node/src/agent_node/node.py b/agent-node/src/agent_node/node.py index 171c010..880e48b 100644 --- 
a/agent-node/src/agent_node/node.py +++ b/agent-node/src/agent_node/node.py @@ -1,19 +1,21 @@ import threading import queue import time -import sys import os import hashlib import logging import json import zlib +import shutil +import socket +from concurrent.futures import ThreadPoolExecutor + try: import psutil except ImportError: psutil = None -from protos import agent_pb2, agent_pb2_grpc -logger = logging.getLogger(__name__) +from protos import agent_pb2, agent_pb2_grpc from agent_node.skills.manager import SkillManager from agent_node.core.sandbox import SandboxEngine from agent_node.core.sync import NodeSyncManager @@ -22,138 +24,104 @@ from agent_node.utils.network import get_secure_stub import agent_node.config as config from agent_node.utils.watchdog import watchdog +from agent_node.core.regex_patterns import ANSI_ESCAPE +logger = logging.getLogger(__name__) class AgentNode: - """The 'Agent Core': Orchestrates Local Skills and Maintains gRPC Connection.""" + """ + Agent Core: Orchestrates local skills and maintains gRPC connectivity. + Refactored for structural clarity and modular message handling. + """ def __init__(self): - # Dynamically read config instead of caching static defaults self.node_id = config.NODE_ID self.sandbox = SandboxEngine() self.sync_mgr = NodeSyncManager() self.skills = SkillManager(max_workers=config.MAX_SKILL_WORKERS, sync_mgr=self.sync_mgr) self.watcher = WorkspaceWatcher(self._on_sync_delta) - # Bounded queue to prevent memory ballooning; 250 * 4MB chunks = 1GB max memory. 
self.task_queue = queue.Queue(maxsize=250) self.stub = None self.channel = None self._stop_event = threading.Event() self._refresh_stub() - # M6: Parallel Disk I/O Workers - from concurrent.futures import ThreadPoolExecutor self.io_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="NodeIO") - # Backpressure for I/O: Prevent memory ballooning during heavy sync self.io_semaphore = threading.Semaphore(50) - self.write_locks = {} # (sid, path) -> threading.Lock + self.write_locks = {} self.lock_map_mutex = threading.Lock() def _refresh_stub(self): - """Force refreshes the gRPC channel and stub, ensuring old ones are closed.""" + """Force refreshes the gRPC channel and stub.""" if self.channel: - try: - self.channel.close() - except: - pass + try: self.channel.close() + except: pass self.stub, self.channel = get_secure_stub() self._setup_connectivity_watcher() def _setup_connectivity_watcher(self): - """Monitor gRPC channel state and log only on actual transitions.""" + """Monitors gRPC channel state.""" import grpc self._last_grpc_state = None - def _on_state_change(state): - if self._stop_event.is_set(): - return - if state != self._last_grpc_state: + if not self._stop_event.is_set() and state != self._last_grpc_state: print(f"[*] [gRPC-State] {state}", flush=True) self._last_grpc_state = state - - # Persistent subscription โ€” only call ONCE per channel. - # Re-subscribing inside the callback causes an exponential callback leak. 
self.channel.subscribe(_on_state_change, try_to_connect=True) - def _collect_capabilities(self) -> dict: - """Collect hardware metadata using abstract platform metrics.""" - from agent_node.utils.platform_metrics import get_platform_metrics - import socket - - metrics = get_platform_metrics() - caps = metrics.collect_capabilities() + def sync_configuration(self): + """Handshake with the Orchestrator to sync policy and metadata.""" + while True: + config.reload() + self.node_id = config.NODE_ID + if not self.stub: self._refresh_stub() - # Shared: Local IP Detection + print(f"[*] Handshake with Orchestrator: {self.node_id}") + caps = self._collect_capabilities() + reg_req = agent_pb2.RegistrationRequest( + node_id=self.node_id, auth_token=config.AUTH_TOKEN, + node_description=config.NODE_DESC, + capabilities={k: str(v).lower() if isinstance(v, bool) else str(v) for k, v in caps.items()} + ) + + try: + res = self.stub.SyncConfiguration(reg_req, timeout=10) + if res.success: + self.sandbox.sync(res.policy) + self._apply_skill_config(res.policy.skill_config_json) + print("[OK] Handshake successful. Policy Synced.") + break + else: + print(f"[!] Rejection: {res.error_message}. Retrying in 5s...") + time.sleep(5) + except Exception as e: + print(f"[!] Connection Fail: {str(e)}. 
Retrying in 5s...") + time.sleep(5) + + def _apply_skill_config(self, config_json): + """Applies dynamic skill configurations from the server.""" + if not config_json: return + try: + cfg = json.loads(config_json) + for skill in self.skills.skills.values(): + if hasattr(skill, "apply_config"): skill.apply_config(cfg) + except Exception as e: + logger.error(f"Error applying skill config: {e}") + + def _collect_capabilities(self) -> dict: + """Collects hardware and OS metadata.""" + from agent_node.utils.platform_metrics import get_platform_metrics + caps = get_platform_metrics().collect_capabilities() try: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.settimeout(0) s.connect(('10.254.254.254', 1)) caps["local_ip"] = s.getsockname()[0] s.close() - except Exception: - caps["local_ip"] = "unknown" - + except: caps["local_ip"] = "unknown" return caps - def sync_configuration(self): - """Initial handshake to retrieve policy and metadata.""" - while True: - config.reload() - self.node_id = config.NODE_ID - self.skills.max_workers = config.MAX_SKILL_WORKERS - if not self.stub: - self._refresh_stub() - - print(f"[*] Handshake with Orchestrator: {self.node_id}") - caps = self._collect_capabilities() - print(f"[*] Capabilities: {caps}") - - caps_str = {k: str(v).lower() if isinstance(v, bool) else str(v) for k, v in caps.items()} - - reg_req = agent_pb2.RegistrationRequest( - node_id=self.node_id, - auth_token=config.AUTH_TOKEN, - node_description=config.NODE_DESC, - capabilities=caps_str - ) - - try: - print(f"[*] [gRPC-Handshake] Sending SyncConfiguration (timeout=10s)...", flush=True) - res = self.stub.SyncConfiguration(reg_req, timeout=10) - if res.success: - self.sandbox.sync(res.policy) - print("[OK] [gRPC-Handshake] Handshake successful. 
Sandbox Policy Synced.") - - if res.policy.skill_config_json: - try: - cfg = json.loads(res.policy.skill_config_json) - for name, skill in self.skills.skills.items(): - if hasattr(skill, "apply_config"): - skill.apply_config(cfg) - except Exception as e: - print(f"[!] Error applying initial skill config: {e}") - break - else: - print(f"[!] Rejection: {res.error_message}") - print("[!] Retrying handshake in 5 seconds...") - time.sleep(5) - except Exception as e: - err_desc = self._format_grpc_error(e) - print(f"[!] Connection Fail: {err_desc}") - print("[!] Retrying handshake in 5 seconds...") - time.sleep(5) - - def _format_grpc_error(self, e) -> str: - """Helper to extract detailed info from gRPC exceptions.""" - try: - import grpc - if isinstance(e, grpc.RpcError): - return f"gRPC Error {e.code()} | {e.details()}" - except: - pass - return str(e) - def start_health_reporting(self): - """Streaming node metrics using abstract platform metrics.""" + """Launches the background health reporting stream.""" from agent_node.utils.platform_metrics import get_platform_metrics metrics_tool = get_platform_metrics() @@ -163,683 +131,280 @@ def _gen(): while not self._stop_event.is_set(): ids = self.skills.get_active_ids() - if psutil: - cpu = psutil.cpu_percent(interval=None) - per_core = psutil.cpu_percent(percpu=True, interval=None) - vmem = psutil.virtual_memory() - mem_percent = vmem.percent - used_gb = vmem.used / (1024**3) - total_gb = vmem.total / (1024**3) - avail_gb = vmem.available / (1024**3) - cpu_count = psutil.cpu_count() - else: - cpu, per_core, mem_percent = 0.0, [], 0.0 - used_gb, total_gb, avail_gb, cpu_count = 0.0, 0.0, 0.0, 0 - - freq = 0 - if psutil: - try: freq = psutil.cpu_freq().current - except: pass - - load = metrics_tool.get_load_avg() - + vmem = psutil.virtual_memory() if psutil else None yield agent_pb2.Heartbeat( node_id=self.node_id, - cpu_usage_percent=cpu, - memory_usage_percent=mem_percent, + cpu_usage_percent=psutil.cpu_percent() if psutil 
else 0, + memory_usage_percent=vmem.percent if vmem else 0, active_worker_count=len(ids), max_worker_capacity=config.MAX_SKILL_WORKERS, running_task_ids=ids, - cpu_count=cpu_count, - memory_used_gb=used_gb, - memory_total_gb=total_gb, - cpu_usage_per_core=per_core, - cpu_freq_mhz=freq, - memory_available_gb=avail_gb, - load_avg=load + cpu_count=psutil.cpu_count() if psutil else 0, + memory_used_gb=vmem.used/(1024**3) if vmem else 0, + memory_total_gb=vmem.total/(1024**3) if vmem else 0, + load_avg=metrics_tool.get_load_avg() ) time.sleep(max(0, config.HEALTH_REPORT_INTERVAL - 1.0)) - for response in self.stub.ReportHealth(_gen()): - watchdog.tick() + for _ in self.stub.ReportHealth(_gen()): watchdog.tick() except Exception as e: - err_desc = self._format_grpc_error(e) - print(f"[!] Health reporting interrupted: {err_desc}. Retrying in 5s...") time.sleep(5) - threading.Thread(target=_report, daemon=True, name=f"Health-{self.node_id}").start() + threading.Thread(target=_report, daemon=True, name="HealthReporter").start() def run_task_stream(self): - """Main Persistent Bi-directional Stream for Task Management with Reconnection.""" + """Main bi-directional task stream with auto-reconnection.""" while True: try: def _gen(): - # Initial announcement for routing identity - announce_msg = agent_pb2.ClientTaskMessage( - announce=agent_pb2.NodeAnnounce(node_id=self.node_id) - ) - yield announce_msg - - while True: - out_msg = self.task_queue.get() - try: - yield out_msg - except Exception as ye: - print(f"[*] [gRPC-Stream] !!! 
Send Error: {ye}", flush=True) - raise ye + yield agent_pb2.ClientTaskMessage(announce=agent_pb2.NodeAnnounce(node_id=self.node_id)) + while True: yield self.task_queue.get() responses = self.stub.TaskStream(_gen()) - print(f"[*] [gRPC-Stream] Connected to Orchestrator ({self.node_id}).", flush=True) - + print(f"[*] Task stream connected ({self.node_id}).") for msg in responses: watchdog.tick() self._process_server_message(msg) - - print(f"[*] [gRPC-Stream] Connection closed by server.", flush=True) except Exception as e: - import traceback - traceback.print_exc() - err_desc = self._format_grpc_error(e) - print(f"[!] Task Stream Failure: {err_desc}. Reconnecting in 5s...", flush=True) - # Force refresh stub on reconnection, closing old channel + print(f"[!] Task stream error: {e}. Reconnecting...") self._refresh_stub() time.sleep(5) - # Re-sync config in case permissions changed during downtime - try: self.sync_configuration() - except: pass def _process_server_message(self, msg): + """Routes inbound server messages to their respective handlers.""" if not verify_server_message_signature(msg): - print(f"[โŒ] Invalid signature on ServerTaskMessage! Dropping message.", flush=True) + logger.warning("Invalid server message signature. 
Dropping.") return kind = msg.WhichOneof('payload') - if config.DEBUG_GRPC or True: # Force logging for now to debug Mac - if kind == 'file_sync' and msg.file_sync.HasField('control'): - print(f"[*] Inbound: {kind} (control={msg.file_sync.control.action})", flush=True) - else: - print(f"[*] Inbound: {kind}", flush=True) - - if kind == 'task_request': - self._handle_task(msg.task_request) - - elif kind == 'task_cancel': - if self.skills.cancel(msg.task_cancel.task_id): - self._send_response(msg.task_cancel.task_id, None, agent_pb2.TaskResponse.CANCELLED) - - elif kind == 'work_pool_update': - # Claim logical idle tasks from global pool with slight randomized jitter - # to prevent thundering herd where every node claims the same task at the exact same ms. - if len(self.skills.get_active_ids()) < config.MAX_SKILL_WORKERS: - for tid in msg.work_pool_update.available_task_ids: - # Deterministic delay based on node_id to distribute claims - import random - time.sleep(random.uniform(0.1, 0.5)) - - self.task_queue.put(agent_pb2.ClientTaskMessage( - task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id) - )) - - elif kind == 'claim_status': - status = "GRANTED" if msg.claim_status.granted else "DENIED" - print(f" [๐Ÿ“ฆ] Claim {msg.claim_status.task_id}: {status} ({msg.claim_status.reason})", flush=True) - - elif kind == 'file_sync': - self._handle_file_sync(msg.file_sync) - + if kind == 'task_request': self._handle_task(msg.task_request) + elif kind == 'task_cancel': self._handle_cancel(msg.task_cancel) + elif kind == 'work_pool_update': self._handle_work_pool(msg.work_pool_update) + elif kind == 'file_sync': self._handle_file_sync(msg.file_sync) elif kind == 'policy_update': - print(f" [๐Ÿ”’] Live Sandbox Policy Update Received.") self.sandbox.sync(msg.policy_update) - - # Apply skill config updates - if msg.policy_update.skill_config_json: - try: - cfg = json.loads(msg.policy_update.skill_config_json) - for name, skill in self.skills.skills.items(): - if 
hasattr(skill, "apply_config"): - skill.apply_config(cfg) - except Exception as e: - print(f" [!] Error applying skill config update: {e}") + self._apply_skill_config(msg.policy_update.skill_config_json) - def _on_sync_delta(self, session_id, payload): - """Callback from watcher to push local changes to server.""" - if isinstance(payload, agent_pb2.FileSyncMessage): - # Already a full message (e.g. deletion control) - self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=payload)) - else: - # Legacy/Standard chunk update - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - file_data=payload - ) - )) + def _handle_cancel(self, cancel_req): + """Cancels an active task.""" + if self.skills.cancel(cancel_req.task_id): + self._send_response(cancel_req.task_id, None, agent_pb2.TaskResponse.CANCELLED) + + def _handle_work_pool(self, update): + """Claims tasks from the global work pool with randomized backoff.""" + if len(self.skills.get_active_ids()) < config.MAX_SKILL_WORKERS: + for tid in update.available_task_ids: + import random + time.sleep(random.uniform(0.1, 0.5)) + self.task_queue.put(agent_pb2.ClientTaskMessage( + task_claim=agent_pb2.TaskClaimRequest(task_id=tid, node_id=self.node_id) + )) def _handle_file_sync(self, fs): - """Processes inbound file synchronization messages from the Orchestrator.""" + """Dispatches file sync messages to specialized sub-handlers.""" sid = fs.session_id - # LOGGING - type_str = fs.WhichOneof('payload') - print(f" [๐Ÿ“] Sync MSG: {type_str} | Session: {sid}") + if fs.HasField("manifest"): self._on_sync_manifest(sid, fs.manifest) + elif fs.HasField("file_data"): self._on_sync_data(sid, fs.file_data) + elif fs.HasField("control"): self._on_sync_control(sid, fs.control, fs.task_id) - if fs.HasField("manifest"): - needs_update = self.sync_mgr.handle_manifest( - sid, - fs.manifest, - on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p) - ) - if 
needs_update: - print(f" [๐Ÿ“โš ๏ธ] Drift Detected for {sid}: {len(needs_update)} files need sync") - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=sid, - status=agent_pb2.SyncStatus( - code=agent_pb2.SyncStatus.RECONCILE_REQUIRED, - message=f"Drift detected in {len(needs_update)} files", - reconcile_paths=needs_update - ) - ) - )) - else: - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=sid, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message="Synchronized") - ) - )) - elif fs.HasField("file_data"): - # M6: High-Concurrency Disk I/O Offloading with Backpressure - # We use a semaphore to limit the number of pending I/O tasks in the executor queue. - # This prevents memory ballooning if the network is faster than the disk. - self.io_semaphore.acquire() - try: - self.io_executor.submit(self._async_write_chunk, sid, fs.file_data) - except Exception: - self.io_semaphore.release() # Release if submission fails - elif fs.HasField("control"): - ctrl = fs.control - print(f" [๐Ÿ“] Control Action: {ctrl.action} (Path: {ctrl.path})") - if ctrl.action == agent_pb2.SyncControl.START_WATCHING: - # Path relative to sync dir or absolute - watch_path = ctrl.path if os.path.isabs(ctrl.path) else os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path) - print(f" [๐Ÿ“๐Ÿ‘๏ธ] Starting Watcher on: {watch_path}") - self.watcher.start_watching(sid, watch_path) - elif ctrl.action == agent_pb2.SyncControl.STOP_WATCHING: - self.watcher.stop_watching(sid) - elif ctrl.action == agent_pb2.SyncControl.LOCK: - self.watcher.set_lock(sid, True) - elif ctrl.action == agent_pb2.SyncControl.UNLOCK: - self.watcher.set_lock(sid, False) - elif ctrl.action == agent_pb2.SyncControl.REFRESH_MANIFEST: - if ctrl.request_paths: - print(f" [๐Ÿ“๐Ÿ“ค] Turbo Pushing {len(ctrl.request_paths)} Requested Files for {sid} in parallel") - from concurrent.futures import ThreadPoolExecutor 
- requested = list(ctrl.request_paths) - # Increased worker count for high-concurrency sync - max_workers = min(100, len(requested)) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for path in requested: - # Pre-check existence to avoid redundant executor tasks - watch_path = self._get_base_dir(sid, create=False) - abs_path = os.path.normpath(os.path.join(watch_path, path)) - if os.path.exists(abs_path): - executor.submit(self._push_file, sid, path) - else: - print(f" [๐Ÿ“โ“] Skipping push for non-existent file: {path}") - else: - # Node -> Server Manifest Push - self._push_full_manifest(sid, ctrl.path) - elif ctrl.action == agent_pb2.SyncControl.RESYNC: - self._push_full_manifest(sid, ctrl.path) - elif ctrl.action == agent_pb2.SyncControl.PURGE: - print(f" [๐Ÿ“๐Ÿงน] Node instructed to purge session sync data: {sid}") - self.watcher.stop_watching(sid) # Stop watching before deleting - self.sync_mgr.purge(sid) - elif ctrl.action == agent_pb2.SyncControl.CLEANUP: - print(f" [๐Ÿ“๐Ÿงน] Node proactively cleaning up defunct sessions. 
Active: {ctrl.request_paths}") - active_sessions = list(ctrl.request_paths) - self.sync_mgr.cleanup_unused_sessions(active_sessions) - - # --- M6: FS Explorer Handlers --- - elif ctrl.action == agent_pb2.SyncControl.LIST: - print(f" [๐Ÿ“๐Ÿ“‚] List Directory: {ctrl.path}") - self._push_full_manifest(sid, ctrl.path, task_id=fs.task_id, shallow=True) - elif ctrl.action == agent_pb2.SyncControl.READ: - print(f" [๐Ÿ“๐Ÿ“„] Read File: {ctrl.path}") - self._push_file(sid, ctrl.path, task_id=fs.task_id) - elif ctrl.action == agent_pb2.SyncControl.WRITE: - print(f" [๐Ÿ“๐Ÿ’พ] Write File: {ctrl.path} (is_dir={ctrl.is_dir})") - self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=fs.task_id) - elif ctrl.action == agent_pb2.SyncControl.DELETE: - print(f" [๐Ÿ“๐Ÿ—‘๏ธ] Delete Fragment: {ctrl.path}") - self._handle_fs_delete(sid, ctrl.path, task_id=fs.task_id) + def _on_sync_manifest(self, sid, manifest): + """Reconciles local state with a remote manifest.""" + drift = self.sync_mgr.handle_manifest(sid, manifest, on_purge_callback=lambda p: self.watcher.acknowledge_remote_delete(sid, p)) + status = agent_pb2.SyncStatus( + code=agent_pb2.SyncStatus.RECONCILE_REQUIRED if drift else agent_pb2.SyncStatus.OK, + message=f"Drift in {len(drift)} files" if drift else "Synchronized", + reconcile_paths=drift + ) + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, status=status))) + + def _on_sync_data(self, sid, file_data): + """Offloads disk I/O to a background worker pool.""" + self.io_semaphore.acquire() + try: self.io_executor.submit(self._async_write_chunk, sid, file_data) + except: self.io_semaphore.release() + + def _on_sync_control(self, sid, ctrl, task_id): + """Handles sync control actions like watching, locking, or directory listing.""" + action = ctrl.action + if action == agent_pb2.SyncControl.START_WATCHING: + self.watcher.start_watching(sid, ctrl.path if os.path.isabs(ctrl.path) else 
os.path.join(self.sync_mgr.get_session_dir(sid), ctrl.path)) + elif action == agent_pb2.SyncControl.STOP_WATCHING: self.watcher.stop_watching(sid) + elif action == agent_pb2.SyncControl.LOCK: self.watcher.set_lock(sid, True) + elif action == agent_pb2.SyncControl.UNLOCK: self.watcher.set_lock(sid, False) + elif action in (agent_pb2.SyncControl.REFRESH_MANIFEST, agent_pb2.SyncControl.RESYNC): + if ctrl.request_paths: + for p in ctrl.request_paths: self.io_executor.submit(self._push_file, sid, p) + else: self._push_full_manifest(sid, ctrl.path) + elif action == agent_pb2.SyncControl.PURGE: + self.watcher.stop_watching(sid) + self.sync_mgr.purge(sid) + elif action == agent_pb2.SyncControl.LIST: self._push_full_manifest(sid, ctrl.path, task_id=task_id, shallow=True) + elif action == agent_pb2.SyncControl.READ: self._push_file(sid, ctrl.path, task_id=task_id) + elif action == agent_pb2.SyncControl.WRITE: self._handle_fs_write(sid, ctrl.path, ctrl.content, ctrl.is_dir, task_id=task_id) + elif action == agent_pb2.SyncControl.DELETE: self._handle_fs_delete(sid, ctrl.path, task_id=task_id) def _get_base_dir(self, session_id, create=False): - """Helper to resolve the effective root for a session (Watcher > SyncDir).""" - if session_id == "__fs_explorer__": - root = config.FS_ROOT - print(f" [๐Ÿ“] Explorer Root: {root}") - return root - - # Priority 1: If we have an active watcher, use its root (e.g. 
Seed from Local) + """Resolves the session's effective root directory.""" + if session_id == "__fs_explorer__": return config.FS_ROOT watched = self.watcher.get_watch_path(session_id) - if watched: - print(f" [๐Ÿ“] Using Watched Path as Base: {watched}") - return watched - - # Priority 2: Standard session-scoped sync directory - fallback = self.sync_mgr.get_session_dir(session_id, create=create) - print(f" [๐Ÿ“] Falling back to SyncDir: {fallback}") - return fallback + return watched if watched else self.sync_mgr.get_session_dir(session_id, create=create) def _push_full_manifest(self, session_id, rel_path=".", task_id="", shallow=False): - """Pushes the current local manifest back to the server.""" - print(f" [๐Ÿ“๐Ÿ“ค] Pushing {'Shallow' if shallow else 'Full'} Manifest for {session_id}") - + """Generates and pushes a local file manifest to the server.""" base_dir = self._get_base_dir(session_id, create=True) - # Ensure rel_path is relative if it's within a session sync dir safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path watch_path = os.path.normpath(os.path.join(base_dir, safe_rel)) if not os.path.exists(watch_path): - # If the specific sub-path doesn't exist, try to create it if it's within the session dir - if session_id != "__fs_explorer__": - os.makedirs(watch_path, exist_ok=True) - else: - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=f"Path {rel_path} not found") - ) - )) - return + if session_id != "__fs_explorer__": os.makedirs(watch_path, exist_ok=True) + else: self._send_sync_error(session_id, task_id, f"Path {rel_path} not found") + if session_id == "__fs_explorer__": return files = [] try: if shallow: - # Optimized for Explorer: immediate children only, no hashing with os.scandir(watch_path) as it: for entry in it: - if entry.name in [".cortex_sync"] and rel_path 
in [".", "", "/"]: - continue - - # Native Orphan Syslink Cleanup - if entry.is_symlink() and not os.path.exists(entry.path): - try: - os.unlink(entry.path) - print(f" [๐Ÿ“๐Ÿงน] Cleaned up broken ghost symlink during refresh: {entry.name}") - except: pass - continue - + if entry.name == ".cortex_sync": continue is_dir = entry.is_dir() if not entry.is_symlink() else os.path.isdir(entry.path) - # Use metadata only - try: - stats = entry.stat() - size = stats.st_size if not is_dir else 0 - except: size = 0 - - # Calculate path relative to the actual base_sync_dir / session_dir - # rel_path is the directory we are currently browsing. - # entry.name is the file within it. - item_rel_path = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/") - - files.append(agent_pb2.FileInfo(path=item_rel_path, size=size, hash="", is_dir=is_dir)) + item_rel = os.path.relpath(os.path.join(watch_path, entry.name), base_dir).replace("\\", "/") + files.append(agent_pb2.FileInfo(path=item_rel, size=entry.stat().st_size if not is_dir else 0, hash="", is_dir=is_dir)) else: - # Deep walk with full hashes for reconciliation for root, dirs, filenames in os.walk(watch_path): - for filename in filenames: - abs_path = os.path.join(root, filename) - # r_path must be relative to base_dir so the server correctly joins it to the mirror root - r_path = os.path.relpath(abs_path, base_dir).replace("\\", "/") - try: - # Memory-safe hashing with metadata cache - file_hash = self.sync_mgr.get_file_hash(abs_path) - if not file_hash: continue - - files.append(agent_pb2.FileInfo( - path=r_path, - size=os.path.getsize(abs_path), - hash=file_hash, - is_dir=False - )) - except Exception: continue - + for name in filenames: + abs_p = os.path.join(root, name) + h = self.sync_mgr.get_file_hash(abs_p) + if h: files.append(agent_pb2.FileInfo(path=os.path.relpath(abs_p, base_dir).replace("\\", "/"), size=os.path.getsize(abs_p), hash=h, is_dir=False)) for d in dirs: - abs_path = 
os.path.join(root, d) - # r_path must be relative to base_dir so the server correctly joins it to the mirror root - r_path = os.path.relpath(abs_path, base_dir).replace("\\", "/") - files.append(agent_pb2.FileInfo(path=r_path, size=0, hash="", is_dir=True)) + files.append(agent_pb2.FileInfo(path=os.path.relpath(os.path.join(root, d), base_dir).replace("\\", "/"), size=0, hash="", is_dir=True)) except Exception as e: - print(f" [โŒ] Manifest generation failed for {rel_path}: {e}") - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=str(e)) - ) - )) - return + return self._send_sync_error(session_id, task_id, str(e)) + self._send_manifest_chunks(session_id, task_id, rel_path, files) + + def _send_manifest_chunks(self, sid, tid, root, files): + """Splits large manifests into chunks for gRPC streaming.""" chunk_size = 1000 - total_files = len(files) - - if total_files == 0: - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - manifest=agent_pb2.DirectoryManifest(root_path=rel_path, files=[], chunk_index=0, is_final=True) - ) - )) - else: - for i in range(0, total_files, chunk_size): - chunk = files[i:i+chunk_size] - is_final = (i + chunk_size) >= total_files - chunk_index = i // chunk_size - - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - manifest=agent_pb2.DirectoryManifest( - root_path=rel_path, - files=chunk, - chunk_index=chunk_index, - is_final=is_final - ) - ) - )) + if not files: + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, manifest=agent_pb2.DirectoryManifest(root_path=root, is_final=True)))) + return + for i in range(0, len(files), chunk_size): + chunk = 
files[i:i+chunk_size] + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, + manifest=agent_pb2.DirectoryManifest(root_path=root, files=chunk, chunk_index=i//chunk_size, is_final=(i+chunk_size >= len(files)))))) def _handle_fs_write(self, session_id, rel_path, content, is_dir, task_id=""): - """Modular FS Write/Create.""" + """Handles single file or directory creation.""" try: - base_dir = os.path.normpath(self._get_base_dir(session_id, create=True)) - # Ensure rel_path is relative if it's within a session sync dir - safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path - target_path = os.path.normpath(os.path.join(base_dir, safe_rel)) - print(f" [๐Ÿ“๐Ÿ’พ] target_path: {target_path} (base_dir: {base_dir})") + base = os.path.normpath(self._get_base_dir(session_id, create=True)) + target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path)) + if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): + raise Exception("Path traversal blocked.") - # M6: Check if path is within session base_dir OR global config.FS_ROOT - allowed = target_path.startswith(base_dir) - if not allowed and config.FS_ROOT: - allowed = target_path.startswith(os.path.normpath(config.FS_ROOT)) - - if not allowed: - raise Exception(f"Path traversal attempt blocked: {target_path} is outside {base_dir} (config.FS_ROOT: {config.FS_ROOT})") - - if is_dir: - os.makedirs(target_path, exist_ok=True) + if is_dir: os.makedirs(target, exist_ok=True) else: - os.makedirs(os.path.dirname(target_path), exist_ok=True) - with open(target_path, "wb") as f: - f.write(content) + os.makedirs(os.path.dirname(target), exist_ok=True) + with open(target, "wb") as f: f.write(content) - # Send OK status - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - 
task_id=task_id, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=f"{'Directory' if is_dir else 'File'} written") - ) - )) - # Trigger manifest refresh so UI updates + self._send_sync_ok(session_id, task_id, "Resource written") self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True) - except Exception as e: - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=str(e)) - ) - )) + except Exception as e: self._send_sync_error(session_id, task_id, str(e)) def _handle_fs_delete(self, session_id, rel_path, task_id=""): - """Modular FS Delete.""" + """Removes a file or directory from the node.""" try: - base_dir = os.path.normpath(self._get_base_dir(session_id)) - # Ensure rel_path is relative if it's within a session sync dir - safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path - target_path = os.path.normpath(os.path.join(base_dir, safe_rel)) + base = os.path.normpath(self._get_base_dir(session_id)) + target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path)) + if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): + raise Exception("Path traversal blocked.") - allowed = target_path.startswith(base_dir) - if not allowed and config.FS_ROOT: - allowed = target_path.startswith(os.path.normpath(config.FS_ROOT)) - - if not allowed: - raise Exception(f"Path traversal attempt blocked: {target_path} is outside {base_dir} (config.FS_ROOT: {config.FS_ROOT})") + self.watcher.acknowledge_remote_delete(session_id, rel_path) + if os.path.isdir(target): shutil.rmtree(target) + else: os.remove(target) - if not os.path.exists(target_path): - raise Exception("File not found") - - # Acknowledge deletion to prevent watchdog echo 
loop - self.watcher.acknowledge_remote_delete(session_id, safe_rel) - - import shutil - if os.path.isdir(target_path): - shutil.rmtree(target_path) - else: - os.remove(target_path) - - # Send OK status - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=f"Deleted {rel_path}") - ) - )) - # Trigger manifest refresh so UI updates + self._send_sync_ok(session_id, task_id, f"Deleted {rel_path}") self._push_full_manifest(session_id, os.path.dirname(rel_path) or ".", task_id=task_id, shallow=True) - except Exception as e: - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=str(e)) - ) - )) + except Exception as e: self._send_sync_error(session_id, task_id, str(e)) def _async_write_chunk(self, sid, payload): - """Worker function for background parallelized I/O with out-of-order chunk support.""" + """Worker for segmented file writing with path-level locking.""" path = payload.path - - # M6: Path-Level Locking for Sequential Consistency - # While chunks can arrive out of order, we must process them sequentially - # to guarantee the final hash verify/swap is correctly ordered. 
with self.lock_map_mutex: - lock_key = (sid, path) - if lock_key not in self.write_locks: - self.write_locks[lock_key] = threading.Lock() - lock = self.write_locks[lock_key] + if (sid, path) not in self.write_locks: self.write_locks[(sid, path)] = threading.Lock() + lock = self.write_locks[(sid, path)] with lock: try: - if payload.chunk_index == 0: - self.watcher.suppress_path(sid, path) - + if payload.chunk_index == 0: self.watcher.suppress_path(sid, path) success = self.sync_mgr.write_chunk(sid, payload) - if payload.is_final: - if success and payload.hash: - self.watcher.acknowledge_remote_write(sid, path, payload.hash) - + if success and payload.hash: self.watcher.acknowledge_remote_write(sid, path, payload.hash) self.watcher.unsuppress_path(sid, path) - print(f" [๐Ÿ“] Async File Sync Complete (Sequenced Parallel): {path} (Success: {success})") - - # M6: Clean up the lock entry after finalization - with self.lock_map_mutex: - if lock_key in self.write_locks: - del self.write_locks[lock_key] - - # Report status back to orchestrator - status = agent_pb2.SyncStatus.OK if success else agent_pb2.SyncStatus.ERROR - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=sid, - status=agent_pb2.SyncStatus(code=status, message=f"File {path} synced") - ) - )) - finally: - # Always release the semaphore to allow the next I/O task - self.io_semaphore.release() + with self.lock_map_mutex: self.write_locks.pop((sid, path), None) + self._send_sync_ok(sid, "", f"File {path} synced") + finally: self.io_semaphore.release() def _push_file(self, session_id, rel_path, task_id=""): - """Pushes a specific file from node to server.""" - watch_path = os.path.normpath(self._get_base_dir(session_id, create=False)) - # Ensure rel_path is relative if it's within a session sync dir - safe_rel = rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path - abs_path = os.path.normpath(os.path.join(watch_path, safe_rel)) - - allowed = 
abs_path.startswith(watch_path) - if not allowed and config.FS_ROOT: - allowed = abs_path.startswith(os.path.normpath(config.FS_ROOT)) + """Streams a file from node to server using 4MB chunks.""" + base = os.path.normpath(self._get_base_dir(session_id, create=False)) + target = os.path.normpath(os.path.join(base, rel_path.lstrip("/") if session_id != "__fs_explorer__" else rel_path)) + if not target.startswith(base) and (not config.FS_ROOT or not target.startswith(os.path.normpath(config.FS_ROOT))): return - if not allowed: - print(f" [๐Ÿ“๐Ÿšซ] Blocked traversal attempt in _push_file: {rel_path} (Valid Roots: {watch_path}, config.FS_ROOT: {config.FS_ROOT})") - return - - if not os.path.exists(abs_path): - print(f" [๐Ÿ“โ“] Requested file {rel_path} not found on node") - if task_id: - # Immediately notify the Hub so it doesn't wait the full timeout - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - status=agent_pb2.SyncStatus( - code=agent_pb2.SyncStatus.ERROR, - message=f"File not found: {rel_path}" - ) - ) - )) + if not os.path.exists(target): + if task_id: self._send_sync_error(session_id, task_id, "File not found") return - # Optimization: 4MB Incremental Hashing + Zero Throttling hasher = hashlib.sha256() - file_size = os.path.getsize(abs_path) - chunk_size = 4 * 1024 * 1024 - total_chunks = (file_size + chunk_size - 1) // chunk_size if file_size > 0 else 1 - + size, chunk_size = os.path.getsize(target), 4 * 1024 * 1024 try: - with open(abs_path, "rb") as f: - index = 0 + with open(target, "rb") as f: + idx = 0 while True: chunk = f.read(chunk_size) - if not chunk and index > 0: - break - + if not chunk and idx > 0: break hasher.update(chunk) - offset = f.tell() - len(chunk) - is_final = f.tell() >= file_size - - # Compress Chunk for transit - compressed_chunk = zlib.compress(chunk) - - # M6: Use dictionary unpack for safe assignment (robust against old proto versions) - 
payload_fields = { - "path": rel_path.replace("\\", "/"), - "chunk": compressed_chunk, - "chunk_index": index, - "is_final": is_final, - "hash": hasher.hexdigest() if is_final else "", - "offset": offset, - "compressed": True, - } - # Only add new fields if supported by the compiled proto - if "total_chunks" in agent_pb2.FilePayload.DESCRIPTOR.fields_by_name: - payload_fields["total_chunks"] = total_chunks - payload_fields["total_size"] = file_size - - self.task_queue.put(agent_pb2.ClientTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=session_id, - task_id=task_id, - file_data=agent_pb2.FilePayload(**payload_fields) - ) - )) - + is_final = f.tell() >= size + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=session_id, task_id=task_id, + file_data=agent_pb2.FilePayload(path=rel_path.replace("\\", "/"), chunk=zlib.compress(chunk), chunk_index=idx, is_final=is_final, hash=hasher.hexdigest() if is_final else "", compressed=True)))) if is_final: break - index += 1 - except Exception as e: - print(f" [๐Ÿ“๐Ÿ“ค] Error pushing {rel_path}: {e}") + idx += 1 + except Exception as e: logger.error(f"Push Error: {e}") def _handle_task(self, task): - print(f"[*] Task Launch: {task.task_id}", flush=True) - # 1. Cryptographic Signature Verification + """Verifies and submits a skill task for execution.""" if not verify_task_signature(task): - print(f"[!] Signature Validation Failed for {task.task_id}", flush=True) - # Report back to hub so the frontend gets a real error, not a silent timeout - self._send_response( - task.task_id, - agent_pb2.TaskResponse( - task_id=task.task_id, - status=agent_pb2.TaskResponse.ERROR, - stderr="[NODE] HMAC signature mismatch โ€” check that AGENT_SECRET_KEY on the node matches the hub SECRET_KEY. 
Task rejected.", - ) - ) - return - - print(f"[โœ…] Validated task {task.task_id}", flush=True) + return self._send_response(task.task_id, agent_pb2.TaskResponse(task_id=task.task_id, status=agent_pb2.TaskResponse.ERROR, stderr="HMAC signature mismatch")) - # 2. Skill Manager Submission success, reason = self.skills.submit(task, self.sandbox, self._on_finish, self._on_event) if not success: - print(f"[!] Execution Rejected: {reason}", flush=True) - self._send_response( - task.task_id, - agent_pb2.TaskResponse( - task_id=task.task_id, - status=agent_pb2.TaskResponse.ERROR, - stderr=f"[NODE] Execution Rejected: {reason}", - ) - ) + self._send_response(task.task_id, agent_pb2.TaskResponse(task_id=task.task_id, status=agent_pb2.TaskResponse.ERROR, stderr=f"Rejection: {reason}")) def _on_event(self, event): - """Live Event Tunneler: Routes skill events into the main stream.""" - if isinstance(event, agent_pb2.ClientTaskMessage): - self.task_queue.put(event) + """Forwards skill events to the gRPC stream.""" + self.task_queue.put(event if isinstance(event, agent_pb2.ClientTaskMessage) else agent_pb2.ClientTaskMessage(skill_event=event)) - def _on_finish(self, tid, res, trace): - """Final Completion Callback: Routes task results back to server.""" - print(f"[*] Completion: {tid}", flush=True) - # 0 is SUCCESS, 1 is ERROR in Protobuf - status = res.get('status', agent_pb2.TaskResponse.ERROR) - - tr = agent_pb2.TaskResponse( - task_id=tid, status=status, - stdout=res.get('stdout',''), - stderr=res.get('stderr',''), - trace_id=trace - ) - self._send_response(tid, tr) + def _on_finish(self, task_id, result, trace_id): + """Finalizes a task and sends the response back to the Hub.""" + self._send_response(task_id, agent_pb2.TaskResponse(task_id=task_id, stdout=result.get("stdout", ""), stderr=result.get("stderr", ""), status=result.get("status", 0), trace_id=trace_id)) - def _send_response(self, tid, tr=None, status=None): - """Utility for placing response messages into the gRPC 
outbound queue.""" - if tr: - self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=tr)) - else: - self.task_queue.put(agent_pb2.ClientTaskMessage( - task_response=agent_pb2.TaskResponse(task_id=tid, status=status) - )) + def _send_response(self, task_id, response, status_override=None): + """Helper to queue a standard task response.""" + res = response or agent_pb2.TaskResponse(task_id=task_id) + if status_override: res.status = status_override + self.task_queue.put(agent_pb2.ClientTaskMessage(task_response=res)) - def stop(self): - """Gracefully stops all background services and skills.""" - print(f"\n[๐Ÿ›‘] Stopping Agent Node: {self.node_id}") + def _send_sync_ok(self, sid, tid, msg): + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.OK, message=msg)))) + + def _send_sync_error(self, sid, tid, msg): + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=agent_pb2.FileSyncMessage(session_id=sid, task_id=tid, status=agent_pb2.SyncStatus(code=agent_pb2.SyncStatus.ERROR, message=msg)))) + + def _on_sync_delta(self, session_id, payload): + self.task_queue.put(agent_pb2.ClientTaskMessage(file_sync=payload if isinstance(payload, agent_pb2.FileSyncMessage) else agent_pb2.FileSyncMessage(session_id=session_id, file_data=payload))) + + def shutdown(self): + """Gracefully shuts down the node.""" self._stop_event.set() - - # 1. Stop Skills self.skills.shutdown() - - # 2. Stop Watcher - self.watcher.shutdown() - - # 3. Shutdown IO Executor - self.io_executor.shutdown(wait=False) - - # 4. Close gRPC channel - if self.channel: - try: - self.channel.close() - except Exception as e: - print(f"[!] 
Error closing channel: {e}") + if self.channel: self.channel.close() + print("[*] Node shutdown complete.") diff --git a/agent-node/src/agent_node/skills/shell_bridge.py b/agent-node/src/agent_node/skills/shell_bridge.py index 13820f3..bc18b74 100644 --- a/agent-node/src/agent_node/skills/shell_bridge.py +++ b/agent-node/src/agent_node/skills/shell_bridge.py @@ -1,6 +1,23 @@ from agent_node.skills.base import BaseSkill from agent_node.skills.terminal_backends import get_terminal_backend from protos import agent_pb2 +import re +import threading +import time +import tempfile +import os +from agent_node.core.regex_patterns import ( + COMPILED_PROMPT_PATTERNS, + ANSI_ESCAPE, + EXIT_CODE_PATTERN, + ECHO_CLEANUP_ANSI, + ECHO_CLEANUP_BRACKET, + STRIP_START_FENCE, + STRIP_BRACKET_FENCE, + ECHO_START_PATTERN, + ECHO_END_PATTERN, + PROTOCOL_HINT_PATTERN +) class ShellSkill(BaseSkill): """Admin Console Skill: Persistent stateful Shell via Abstract Terminal Backend.""" @@ -9,14 +26,8 @@ self.sessions = {} # session_id -> {backend, thread, last_activity, ...} self.lock = threading.Lock() - # Phase 3: Prompt Patterns for Edge Intelligence - self.PROMPT_PATTERNS = [ - r"[\r\n].*[@\w\.\-]+:.*[#$]\s*$", # bash/zsh: user@host:~$ - r">>>\s*$", # python - r"\.\.\.\s*$", # python multi-line - r">\s*$", # node/js - r"PS\s+.*>\s*$", # powershell - ] + # Patterns moved to core/regex_patterns.py + self.PROMPT_PATTERNS = COMPILED_PROMPT_PATTERNS # --- M7: Idle Session Reaper --- self.reaper_thread = threading.Thread(target=self._session_reaper, daemon=True, name="ShellReaper") @@ -40,6 +51,7 @@ self.sessions.pop(sid, None) def _ensure_session(self, session_id, cwd, on_event): + """Retrieves or initializes a persistent terminal session.""" with self.lock: if session_id in self.sessions: self.sessions[session_id]["last_activity"] = time.time() @@ -47,7 +59,6 @@ print(f" [๐Ÿš] Initializing Persistent Shell Session: {session_id}") backend = get_terminal_backend() - import os 
backend.spawn(cwd=cwd, env=os.environ.copy()) print(f" [๐Ÿš] Terminal Spawned (PID Check: {backend.is_alive()})") @@ -60,175 +71,149 @@ "write_lock": threading.Lock() } - def reader(): - while True: - try: - data = backend.read(4096) - if not data: - if not backend.is_alive(): - break - time.sleep(0.05) - continue - - if isinstance(data, str): - decoded = data - else: - decoded = data.decode("utf-8", errors="replace") - - # M7: Protocol-Aware Framing (OSC 1337) - # We use non-printable fences to accurately slice the command output - with self.lock: - active_tid = sess.get("active_task") - if active_tid and sess.get("buffer_file"): - start_fence = f"\x1b]1337;TaskStart;id={active_tid}\x07" - end_fence_prefix = f"\x1b]1337;TaskEnd;id={active_tid};exit=" - - bracket_start_fence = f"[[1337;TaskStart;id={active_tid}]]" - bracket_end_fence_prefix = f"[[1337;TaskEnd;id={active_tid};exit=" - - sess["buffer_file"].write(decoded) - sess["buffer_file"].flush() - - # Byte-accurate 16KB tail for fence detection - sess["tail_buffer"] = (sess.get("tail_buffer", "") + decoded)[-16384:] - - # Clean ANSI from the tail buffer to prevent ConPTY injecting random cursor positions inside our marker strings - import re - ansi_escape = re.compile(r'\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - clean_tail = ansi_escape.sub('', sess["tail_buffer"]) - - if end_fence_prefix in clean_tail or bracket_end_fence_prefix in clean_tail: - # Task completed via protocol fence! 
- try: - is_bracket = bracket_end_fence_prefix in clean_tail - active_end_prefix = bracket_end_fence_prefix if is_bracket else end_fence_prefix - active_start_fence = bracket_start_fence if is_bracket else start_fence - - # Extract exit code from the trailer: TaskEnd;id=...;exit=N - after_end = clean_tail.split(active_end_prefix)[1] - exit_match = re.search(r'(\d+)', after_end) - exit_code = int(exit_match.group(1)) if exit_match else 0 - - bf = sess["buffer_file"] - bf.seek(0) - full_raw = bf.read() - clean_full_raw = ansi_escape.sub('', full_raw) - - print(f" [๐ŸšDEBUG] Fence Match! Buffer: {len(clean_full_raw)} bytes. Tail: {repr(clean_full_raw[-200:])}") - - # Clean extraction between fences (using ANSI stripped content) - if active_start_fence in clean_full_raw: - # We take the content AFTER the last start fence to avoid echo-back collision - content = clean_full_raw.split(active_start_fence)[-1].split(active_end_prefix)[0] - else: - content = clean_full_raw.split(active_end_prefix)[0] - - # Minimal post-processing: remove the echo of the end command itself - content = re.sub(r'echo \x1b]1337;TaskEnd;.*', '', content).strip() - content = re.sub(r'echo \[\[1337;TaskEnd;.*', '', content).strip() - - sess["result"]["stdout"] = content - sess["result"]["status"] = 0 if exit_code == 0 else 1 - - sess["buffer_file"].close() - sess["buffer_file"] = None - - # Signal completion via safe lookup - avoid racing with finally block - finish_event = sess.get("event") - if finish_event: - finish_event.set() - - # Strip the protocol fences from the live UI stream to keep it clean (ANSI and Bracket) - decoded = re.sub(r'\x1b]1337;Task(Start|End);id=.*?\x07', '', decoded) - decoded = re.sub(r'\[\[1337;Task(Start|End);id=.*?\]\]', '', decoded) - except Exception as e: - print(f" [๐Ÿšโš ๏ธ] Protocol parsing failed: {e}") - finish_event = sess.get("event") - if finish_event: - finish_event.set() - - # Stream terminal output back to UI - if on_event: - import re - - # M9: Filter 
Native Escaped cmd Echo framing from bouncing back to the UI - # e.g., "echo [[1337;Task^Start;id=xyz]] & " - decoded = re.sub(r'echo \s*\[\[1337;Task\^Start;id=[a-zA-Z0-9-]*\]\]\s*&\s*', '', decoded) - decoded = re.sub(r'\s*&\s*echo \s*\[\[1337;Task\^End;id=[a-zA-Z0-9-]*;exit=%errorlevel%\]\]', '', decoded) - - # M7: Line-Aware Hyper-Aggressive Stealthing - # Instead of complex regex on the whole buffer, we nuke any lines - # that carry our internal protocol baggage. - lines = decoded.splitlines(keepends=True) - clean_lines = [] - ansi_escape_ui = re.compile(r'\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - for line in lines: - stripped_line = ansi_escape_ui.sub('', line) - # If the line contains our protocol marker, it's plumbing - drop it. - if "1337;Task" in stripped_line or "`e]" in line or "\\033]" in line: - continue - clean_lines.append(line) - - stealth_out = "".join(clean_lines) - - if stealth_out.strip(): - with self.lock: - now = time.time() - if now - sess.get("stream_window_start", 0) > 1.0: - sess["stream_window_start"] = now - sess["stream_bytes_sent"] = 0 - dropped = sess.get("stream_dropped_bytes", 0) - if dropped > 0: - drop_msg = f"\n[... 
{dropped:,} bytes truncated from live stream ...]\n" - event = agent_pb2.SkillEvent( - session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=drop_msg - ) - on_event(agent_pb2.ClientTaskMessage(skill_event=event)) - sess["stream_dropped_bytes"] = 0 - - if sess.get("stream_bytes_sent", 0) + len(stealth_out) > 100_000: - sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(stealth_out) - else: - sess["stream_bytes_sent"] = sess.get("stream_bytes_sent", 0) + len(stealth_out) - event = agent_pb2.SkillEvent( - session_id=session_id, - task_id=sess.get("active_task") or "", - terminal_out=stealth_out - ) - on_event(agent_pb2.ClientTaskMessage(skill_event=event)) - - # EDGE INTELLIGENCE: Proactively signal prompt detection - current_event = sess.get("event") - if active_tid and current_event and not current_event.is_set(): - import re - tail = sess["tail_buffer"][-100:] if len(sess["tail_buffer"]) > 100 else sess["tail_buffer"] - for pattern in self.PROMPT_PATTERNS: - if re.search(pattern, tail): - p_hint = tail[-20:].strip() - prompt_event = agent_pb2.SkillEvent( - session_id=session_id, - task_id=active_tid, - prompt=p_hint - ) - on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event)) - break - except (EOFError, OSError): - break - except Exception as catch_all: - print(f" [๐ŸšโŒ] Reader thread FATAL exception: {catch_all}") - break - - print(f" [๐Ÿš] Shell Session Terminated: {session_id}") - with self.lock: - self.sessions.pop(session_id, None) - - t = threading.Thread(target=reader, daemon=True, name=f"ShellReader-{session_id}") + # Start refactored reader thread + t = threading.Thread( + target=self._reader_loop, + args=(session_id, on_event), + daemon=True, + name=f"ShellReader-{session_id}" + ) t.start() sess["thread"] = t self.sessions[session_id] = sess return sess + def _reader_loop(self, session_id, on_event): + """Internal method to handle terminal reading and protocol extraction.""" + with self.lock: + sess 
= self.sessions.get(session_id) + if not sess: return + + backend = sess["backend"] + while True: + try: + data = backend.read(4096) + if not data: + if not backend.is_alive(): break + time.sleep(0.05) + continue + + decoded = data if isinstance(data, str) else data.decode("utf-8", errors="replace") + + with self.lock: + active_tid = sess.get("active_task") + if active_tid and sess.get("buffer_file"): + self._process_protocol_fences(sess, active_tid, decoded) + + # Stream and Edge Intelligence + if on_event: + self._handle_ui_streaming(sess, session_id, active_tid, decoded, on_event) + + except (EOFError, OSError): break + except Exception as e: + print(f" [๐ŸšโŒ] Reader thread FATAL exception: {e}") + break + + print(f" [๐Ÿš] Shell Session Terminated: {session_id}") + with self.lock: self.sessions.pop(session_id, None) + + def _process_protocol_fences(self, sess, active_tid, decoded): + """Internal helper to handle OSC 1337 / Bracketed Task framing.""" + start_fence = f"\x1b]1337;TaskStart;id={active_tid}\x07" + end_fence_prefix = f"\x1b]1337;TaskEnd;id={active_tid};exit=" + bracket_start_fence = f"[[1337;TaskStart;id={active_tid}]]" + bracket_end_fence_prefix = f"[[1337;TaskEnd;id={active_tid};exit=" + + sess["buffer_file"].write(decoded) + sess["buffer_file"].flush() + sess["tail_buffer"] = (sess.get("tail_buffer", "") + decoded)[-16384:] + + clean_tail = ANSI_ESCAPE.sub('', sess["tail_buffer"]) + + if end_fence_prefix in clean_tail or bracket_end_fence_prefix in clean_tail: + try: + is_bracket = bracket_end_fence_prefix in clean_tail + active_end_prefix = bracket_end_fence_prefix if is_bracket else end_fence_prefix + active_start_fence = bracket_start_fence if is_bracket else start_fence + + after_end = clean_tail.split(active_end_prefix)[1] + exit_match = EXIT_CODE_PATTERN.search(after_end) + exit_code = int(exit_match.group(1)) if exit_match else 0 + + bf = sess["buffer_file"] + bf.seek(0) + clean_full_raw = ANSI_ESCAPE.sub('', bf.read()) + + # Extract 
content between fences + if active_start_fence in clean_full_raw: + content = clean_full_raw.split(active_start_fence)[-1].split(active_end_prefix)[0] + else: + content = clean_full_raw.split(active_end_prefix)[0] + + # Cleanup internal echo echo + content = ECHO_CLEANUP_ANSI.sub('', content) + content = ECHO_CLEANUP_BRACKET.sub('', content).strip() + + sess["result"]["stdout"] = content + sess["result"]["status"] = 0 if exit_code == 0 else 1 + + sess["buffer_file"].close() + sess["buffer_file"] = None + + if sess.get("event"): sess["event"].set() + except Exception as e: + print(f" [๐Ÿšโš ๏ธ] Protocol parsing failed: {e}") + if sess.get("event"): sess["event"].set() + + def _handle_ui_streaming(self, sess, session_id, active_tid, decoded, on_event): + """Internal helper to filter plumbing and stream terminal output to the client.""" + # Clean framing echoes from the live stream + decoded = ECHO_START_PATTERN.sub('', decoded) + decoded = ECHO_END_PATTERN.sub('', decoded) + decoded = STRIP_START_FENCE.sub('', decoded) + decoded = STRIP_BRACKET_FENCE.sub('', decoded) + + # Line-Aware Stealthing for extra safety + lines = decoded.splitlines(keepends=True) + clean_lines = [line for line in lines if not PROTOCOL_HINT_PATTERN.search(ANSI_ESCAPE.sub('', line))] + stealth_out = "".join(clean_lines) + + if stealth_out.strip(): + with self.lock: + self._apply_stream_throttling(sess, session_id, stealth_out, on_event) + self._detect_edge_prompts(sess, session_id, active_tid, on_event) + + def _apply_stream_throttling(self, sess, session_id, stealth_out, on_event): + """Protects the bridge from output flooding.""" + now = time.time() + if now - sess.get("stream_window_start", 0) > 1.0: + sess["stream_window_start"], sess["stream_bytes_sent"] = now, 0 + if sess.get("stream_dropped_bytes", 0) > 0: + drop_msg = f"\n[... 
{sess['stream_dropped_bytes']:,} bytes truncated from live stream ...]\n" + event = agent_pb2.SkillEvent(session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=drop_msg) + on_event(agent_pb2.ClientTaskMessage(skill_event=event)) + sess["stream_dropped_bytes"] = 0 + + if sess.get("stream_bytes_sent", 0) + len(stealth_out) > 100_000: + sess["stream_dropped_bytes"] = sess.get("stream_dropped_bytes", 0) + len(stealth_out) + else: + sess["stream_bytes_sent"] += len(stealth_out) + event = agent_pb2.SkillEvent(session_id=session_id, task_id=sess.get("active_task") or "", terminal_out=stealth_out) + on_event(agent_pb2.ClientTaskMessage(skill_event=event)) + + def _detect_edge_prompts(self, sess, session_id, active_tid, on_event): + """Signals prompt detection (e.g. login: or password:) back to the client.""" + current_event = sess.get("event") + if active_tid and current_event and not current_event.is_set(): + tail = sess["tail_buffer"][-100:] + for pattern in self.PROMPT_PATTERNS: + if pattern.search(tail): + p_hint = tail[-20:].strip() + prompt_event = agent_pb2.SkillEvent(session_id=session_id, task_id=active_tid, prompt=p_hint) + on_event(agent_pb2.ClientTaskMessage(skill_event=prompt_event)) + break + def handle_transparent_tty(self, task, on_complete, on_event=None): """Processes raw TTY/Resize events synchronously.""" cmd = task.payload_json @@ -240,19 +225,16 @@ # 1. Raw Keystroke forward if isinstance(raw_payload, dict) and "tty" in raw_payload: - raw_bytes = raw_payload["tty"] sess = self._ensure_session(session_id, None, on_event) - sess["backend"].write(raw_bytes.encode("utf-8")) + sess["backend"].write(raw_payload["tty"].encode("utf-8")) on_complete(task.task_id, {"stdout": "", "status": 0}, task.trace_id) return True # 2. 
Window Resize if isinstance(raw_payload, dict) and raw_payload.get("action") == "resize": - cols = raw_payload.get("cols", 80) - rows = raw_payload.get("rows", 24) + cols, rows = raw_payload.get("cols", 80), raw_payload.get("rows", 24) sess = self._ensure_session(session_id, None, on_event) sess["backend"].resize(cols, rows) - print(f" [๐Ÿš] Terminal Resized to {cols}x{rows}") on_complete(task.task_id, {"stdout": f"resized to {cols}x{rows}", "status": 0}, task.trace_id) return True except Exception as pe: @@ -260,178 +242,108 @@ return False def execute(self, task, sandbox, on_complete, on_event=None): - """Dispatches command string to the abstract terminal backend and WAITS for completion.""" + """Dispatches command string to the terminal and waits for framed response.""" session_id = task.session_id or "default-session" tid = task.task_id + cmd = task.payload_json + + allowed, status_msg = sandbox.verify(cmd) + if not allowed: + err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n" + if on_event: + on_event(agent_pb2.ClientTaskMessage(skill_event=agent_pb2.SkillEvent(session_id=session_id, task_id=tid, terminal_out=err_msg))) + return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id) + + cwd = self.sync_mgr.get_session_dir(task.session_id, create=True) if self.sync_mgr and task.session_id else None + sess = self._ensure_session(session_id, cwd, on_event) + + with sess["write_lock"]: + if cmd.startswith("!RAW:"): + if not tid.startswith("task-"): + sess["backend"].write((cmd[5:] + "\n").encode("utf-8")) + return on_complete(tid, {"stdout": "INJECTED", "status": 0}, task.trace_id) + cmd = cmd[5:] + + event, cancel_event = threading.Event(), threading.Event() + result_container = {"stdout": "", "status": 0} + + with self.lock: + sess["active_task"] = tid + sess["event"] = event + sess["buffer_file"] = tempfile.NamedTemporaryFile("w+", encoding="utf-8", prefix=f"cortex_task_{tid}_", delete=False) + sess["tail_buffer"] 
= "" + sess["result"] = result_container + sess["cancel_event"] = cancel_event + + try: + full_input = self._build_framed_command(tid, cmd) + sess["backend"].write(full_input.encode("utf-8")) + + timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else 60.0 + start_time = time.time() + while time.time() - start_time < timeout: + if event.is_set(): return on_complete(tid, result_container, task.trace_id) + if cancel_event.is_set(): return on_complete(tid, {"stderr": "ABORTED", "status": 2}, task.trace_id) + time.sleep(0.1) + + on_complete(tid, {"stdout": self._get_timeout_output(sess), "stderr": "TIMEOUT", "status": 2}, task.trace_id) + finally: + self._cleanup_task_state(sess, tid, event, cancel_event) + + def _build_framed_command(self, tid, cmd): + """Constructs the shell command with protocol framing.""" + import platform + if platform.system() == "Windows": + spool_dir = os.path.join(tempfile.gettempdir(), "cortex_pty_tasks") + os.makedirs(spool_dir, exist_ok=True) + task_path = os.path.join(spool_dir, f"{tid}.bat") + with open(task_path, "w", encoding="utf-8") as f: + f.write(f"@echo off\r\necho [[1337;TaskStart;id={tid}]]\r\n{cmd}\r\necho [[1337;TaskEnd;id={tid};exit=%errorlevel%]]\r\ndel \"%~f0\"\r\n") + return f"\"{task_path}\"\r\n" + else: + return f"echo -e -n \"\\033]1337;TaskStart;id={tid}\\007\"; {cmd}; __ctx_exit=$?; echo -e -n \"\\033]1337;TaskEnd;id={tid};exit=$__ctx_exit\\007\"\n" + + def _get_timeout_output(self, sess): + """Extracts Head/Tail output from the buffer file upon task timeout.""" try: - cmd = task.payload_json - - allowed, status_msg = sandbox.verify(cmd) - if not allowed: - err_msg = f"\r\n[System] Command blocked: {status_msg}\r\n" - if on_event: - event = agent_pb2.SkillEvent( - session_id=session_id, task_id=tid, - terminal_out=err_msg - ) - on_event(agent_pb2.ClientTaskMessage(skill_event=event)) - - return on_complete(tid, {"stderr": f"SANDBOX_VIOLATION: {status_msg}", "status": 2}, task.trace_id) + with self.lock: + if 
not sess.get("buffer_file"): return "" + sess["buffer_file"].seek(0, 2) + f_len = sess["buffer_file"].tell() + HEAD, TAIL = 10_000, 30_000 + sess["buffer_file"].seek(0) + if f_len > HEAD + TAIL: + head = sess["buffer_file"].read(HEAD) + sess["buffer_file"].seek(f_len - TAIL) + return head + f"\n\n[... {f_len - HEAD - TAIL:,} bytes omitted ...] \n\n" + sess["buffer_file"].read() + return sess["buffer_file"].read() + except: return "" - cwd = None - if self.sync_mgr and task.session_id: - cwd = self.sync_mgr.get_session_dir(task.session_id, create=True) - elif sandbox.policy.get("WORKING_DIR_JAIL"): - cwd = sandbox.policy["WORKING_DIR_JAIL"] - if not os.path.exists(cwd): - try: os.makedirs(cwd, exist_ok=True) + def _cleanup_task_state(self, sess, tid, event, cancel_event): + """Normalizes session state after task completion or error.""" + with self.lock: + if sess.get("active_task") == tid: + if sess.get("buffer_file"): + try: sess["buffer_file"].close() except: pass - - sess = self._ensure_session(session_id, cwd, on_event) - - with sess["write_lock"]: - is_raw = cmd.startswith("!RAW:") - if is_raw: - # M7 Fix: Agentic tasks (starting with 'task-') MUST use framing - # to ensure results are captured. Forced bypass is only allowed for manual UI typing. 
- if tid.startswith("task-"): - cmd = cmd[5:] - is_raw = False - else: - input_str = cmd[5:] + "\n" - print(f" [๐ŸšโŒจ๏ธ] RAW Input Injection: {input_str.strip()}") - sess["backend"].write(input_str.encode("utf-8")) - return on_complete(tid, {"stdout": "INJECTED", "status": 0}, task.trace_id) - - marker_id = int(time.time()) - marker = f"__CORTEX_FIN_SH_{marker_id}__" - event = threading.Event() - cancel_event = threading.Event() # Local snapshot for thread safety - result_container = {"stdout": "", "status": 0} - - with self.lock: - sess["active_task"] = tid - sess["event"] = event - sess["buffer_file"] = tempfile.NamedTemporaryFile("w+", encoding="utf-8", prefix=f"cortex_task_{tid}_", delete=False) - sess["tail_buffer"] = "" - sess["result"] = result_container - sess["cancel_event"] = cancel_event - - try: - # M7: Protocol-Aware Command Framing (OSC 1337) - # We wrap the command in non-printable control sequences. - # Format: ESC ] 1337 ; ST (\x07) - start_marker = f"1337;TaskStart;id={tid}" - end_marker = f"1337;TaskEnd;id={tid}" - - import platform - if platform.system() == "Windows": - # M7: EncodedCommand for Windows (Bypasses Quote Hell) - # This ensures byte-accurate delivery of ESC ([char]27) and BEL ([char]7) - import base64 - - # M8: Ultimate Windows Shell Boundary Method (File Spooling) - # Bypasses Conhost VTP Redraw byte swallowing caused by line wrapping in PTY - # Bypasses powershell encoded limits. 
- import os - import tempfile as tf - spool_dir = os.path.join(tf.gettempdir(), "cortex_pty_tasks") - os.makedirs(spool_dir, exist_ok=True) - task_path = os.path.join(spool_dir, f"{tid}.bat") - - # We write the logic to a native shell file so the PTY simply executes a short path - with open(task_path, "w", encoding="utf-8") as f: - f.write(f"@echo off\r\n") - f.write(f"echo [[1337;TaskStart;id={tid}]]\r\n") - f.write(f"{cmd}\r\n") - f.write(f"echo [[1337;TaskEnd;id={tid};exit=%errorlevel%]]\r\n") - # optionally clean up itself - f.write(f"del \"%~f0\"\r\n") - - full_input = f"\"{task_path}\"\r\n" - else: - # On Linux, we use echo -e with octal escapes - s_m = f"\\033]{start_marker}\\007" - e_m = f"\\033]{end_marker};exit=$__ctx_exit\\007" - full_input = f"echo -e -n \"{s_m}\"; {cmd}; __ctx_exit=$?; echo -e -n \"{e_m}\"\n" - - sess["backend"].write(full_input.encode("utf-8")) - - timeout = (task.timeout_ms / 1000.0) if task.timeout_ms > 0 else 60.0 - start_time = time.time() - while time.time() - start_time < timeout: - if event.is_set(): - return on_complete(tid, result_container, task.trace_id) - if cancel_event.is_set(): - print(f" [๐Ÿš๐Ÿ›‘] Task {tid} cancelled on node.") - return on_complete(tid, {"stderr": "ABORTED", "status": 2}, task.trace_id) - time.sleep(0.1) - - print(f" [๐Ÿšโš ๏ธ] Task {tid} timed out on node.") - with self.lock: - if sess.get("buffer_file"): - try: - sess["buffer_file"].seek(0, 2) - file_len = sess["buffer_file"].tell() - HEAD, TAIL = 10_000, 30_000 - if file_len > HEAD + TAIL: - sess["buffer_file"].seek(0) - head_str = sess["buffer_file"].read(HEAD) - sess["buffer_file"].seek(file_len - TAIL) - tail_str = sess["buffer_file"].read() - omitted = file_len - HEAD - TAIL - partial_out = head_str + f"\n\n[... 
{omitted:,} bytes omitted (full timeout output saved to {sess['buffer_file'].name}) ...]\n\n" + tail_str - else: - sess["buffer_file"].seek(0) - partial_out = sess["buffer_file"].read() - except: - partial_out = "" - else: - partial_out = "" - - on_complete(tid, {"stdout": partial_out, "stderr": "TIMEOUT", "status": 2}, task.trace_id) - - finally: - with self.lock: - if sess.get("active_task") == tid: - if sess.get("buffer_file"): - try: - sess["buffer_file"].close() - except: pass - sess["buffer_file"] = None - sess["active_task"] = None - sess["marker"] = None - if sess.get("event") == event: - sess["event"] = None - sess["result"] = None - if sess.get("cancel_event") == cancel_event: - sess["cancel_event"] = None - - except Exception as e: - print(f" [๐ŸšโŒ] Execute Error for {tid}: {e}") - on_complete(tid, {"stderr": str(e), "status": 2}, task.trace_id) + sess["buffer_file"] = None + sess["active_task"] = None + if sess.get("event") == event: sess["event"] = None + if sess.get("cancel_event") == cancel_event: sess["cancel_event"] = None def cancel(self, task_id: str): """Cancels an active task โ€” for persistent shell, this sends a SIGINT (Ctrl+C).""" with self.lock: for sid, sess in self.sessions.items(): if sess.get("active_task") == task_id: - print(f"[๐Ÿ›‘] Sending SIGINT (Ctrl+C) to shell session (Task {task_id}): {sid}") sess["backend"].write(b"\x03") - if sess.get("cancel_event"): - sess["cancel_event"].set() + if sess.get("cancel_event"): sess["cancel_event"].set() return True def shutdown(self): """Cleanup: Terminates all persistent shells via backends.""" with self.lock: for sid, sess in list(self.sessions.items()): - print(f"[๐Ÿ›‘] Cleaning up persistent shell: {sid}") try: sess["backend"].kill() except: pass self.sessions.clear() - -import os -import threading -import time -import tempfile diff --git a/ai-hub/app/api/dependencies.py b/ai-hub/app/api/dependencies.py index d44f16b..eadbcfb 100644 --- a/ai-hub/app/api/dependencies.py +++ 
b/ai-hub/app/api/dependencies.py @@ -95,6 +95,7 @@ self.auth_service = None self.preference_service = None self.agent_scheduler = None + self.agent_service = None def with_service(self, name: str, service: Any): """ diff --git a/ai-hub/app/api/routes/agents.py b/ai-hub/app/api/routes/agents.py index d458d06..8f5dc4f 100644 --- a/ai-hub/app/api/routes/agents.py +++ b/ai-hub/app/api/routes/agents.py @@ -21,181 +21,29 @@ def create_agents_router(services: ServiceContainer) -> APIRouter: router = APIRouter() - def _workspace_id_from_jail(jail_path: str | None, fallback_session_id: int | None = None) -> str: - """Derive a stable workspace ID from jail path when possible.""" - if jail_path: - normalized = jail_path.rstrip("/") - base = os.path.basename(normalized) - if base: - return base - if fallback_session_id is not None: - return f"session-{fallback_session_id}" - return f"agent-{uuid.uuid4().hex[:8]}" - - def _ensure_agent_workspace_binding(instance: AgentInstance, db: Session): - """ - Keep session sync workspace and agent jail aligned. - This heals legacy records where sync_workspace_id is null/mismatched. 
- """ - if not instance or not instance.session: - return - - desired_workspace_id = _workspace_id_from_jail(instance.current_workspace_jail, instance.session_id) - desired_jail = f"/tmp/cortex/{desired_workspace_id}/" - - changed = False - if instance.session.sync_workspace_id != desired_workspace_id: - instance.session.sync_workspace_id = desired_workspace_id - changed = True - - if instance.current_workspace_jail != desired_jail: - instance.current_workspace_jail = desired_jail - changed = True - - if changed: - db.flush() - try: - orchestrator = getattr(services, "orchestrator", None) - if orchestrator and instance.mesh_node_id: - orchestrator.assistant.push_workspace(instance.mesh_node_id, desired_workspace_id) - orchestrator.assistant.control_sync(instance.mesh_node_id, desired_workspace_id, action="START") - orchestrator.assistant.control_sync(instance.mesh_node_id, desired_workspace_id, action="UNLOCK") - except Exception as e: - logging.error(f"Failed to heal workspace binding for agent {instance.id}: {e}") - @router.get("", response_model=List[AgentInstanceResponse]) def get_agents(current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - agents = db.query(AgentInstance).options(joinedload(AgentInstance.template), joinedload(AgentInstance.session)).filter( - AgentInstance.user_id == current_user.id - ).all() - changed = False - for instance in agents: - before_sync = instance.session.sync_workspace_id if instance.session else None - before_jail = instance.current_workspace_jail - _ensure_agent_workspace_binding(instance, db) - after_sync = instance.session.sync_workspace_id if instance.session else None - after_jail = instance.current_workspace_jail - if before_sync != after_sync or before_jail != after_jail: - changed = True - - if changed: - db.commit() - return agents + return services.agent_service.list_user_agents(db, current_user.id) @router.get("/{id}", response_model=AgentInstanceResponse) def get_agent(id: str, 
current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - instance = db.query(AgentInstance).options(joinedload(AgentInstance.template), joinedload(AgentInstance.session)).filter( - AgentInstance.id == id, - AgentInstance.user_id == current_user.id - ).first() - if not instance: - raise HTTPException(status_code=404, detail="Agent not found") - _ensure_agent_workspace_binding(instance, db) - db.commit() - return instance + return services.agent_service.get_agent_instance(db, id, current_user.id) @router.post("/templates", response_model=AgentTemplateResponse) def create_template(request: AgentTemplateCreate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - template = AgentTemplate(**request.model_dump()) - template.user_id = current_user.id - db.add(template) - db.commit() - db.refresh(template) - return template + return services.agent_service.create_template(db, current_user.id, request) @router.post("/instances", response_model=AgentInstanceResponse) def create_instance(request: AgentInstanceCreate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - # Verify template exists - template = db.query(AgentTemplate).filter(AgentTemplate.id == request.template_id).first() - if not template: - raise HTTPException(status_code=404, detail="Template not found") - - instance = AgentInstance(**request.model_dump()) - instance.user_id = current_user.id - db.add(instance) - db.commit() - db.refresh(instance) - return instance + return services.agent_service.create_instance(db, current_user.id, request) @router.patch("/{id}/status", response_model=AgentInstanceResponse) def update_status(id: str, request: AgentInstanceStatusUpdate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - instance = db.query(AgentInstance).filter( - AgentInstance.id == id, - AgentInstance.user_id == current_user.id - ).first() - if not instance: - raise 
HTTPException(status_code=404, detail="Instance not found") - - instance.status = request.status - if request.status == "idle": - instance.last_error = None - instance.evaluation_status = None - db.commit() - db.refresh(instance) - return instance + return services.agent_service.update_status(db, id, current_user.id, request.status) @router.patch("/{id}/config", response_model=AgentInstanceResponse) def update_config(id: str, request: schemas.AgentConfigUpdate, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - from app.db.models.session import Session as SessionModel - - instance = db.query(AgentInstance).filter( - AgentInstance.id == id, - AgentInstance.user_id == current_user.id - ).first() - if not instance: - raise HTTPException(status_code=404, detail="Instance not found") - - template = db.query(AgentTemplate).filter(AgentTemplate.id == instance.template_id).first() - - if request.name is not None and template: - template.name = request.name - if request.system_prompt is not None and template: - template.system_prompt_path = request.system_prompt - if request.max_loop_iterations is not None and template: - template.max_loop_iterations = request.max_loop_iterations - if request.co_worker_quality_gate is not None and template: - template.co_worker_quality_gate = request.co_worker_quality_gate - if request.rework_threshold is not None and template: - template.rework_threshold = request.rework_threshold - if request.max_rework_attempts is not None and template: - template.max_rework_attempts = request.max_rework_attempts - - if request.mesh_node_id is not None: - instance.mesh_node_id = request.mesh_node_id - - # Update the Session overriding prompt so the running loop picks it up instantly! 
- if instance.session_id: - session = db.query(SessionModel).filter(SessionModel.id == instance.session_id).first() - if session: - if request.system_prompt is not None: - session.system_prompt_override = request.system_prompt - if hasattr(request, 'provider_name') and request.provider_name is not None: - session.provider_name = request.provider_name - if hasattr(request, 'model_name') and request.model_name is not None: - session.model_name = request.model_name - if request.mesh_node_id is not None: - old_nodes = session.attached_node_ids or [] - if not old_nodes or request.mesh_node_id not in old_nodes or len(old_nodes) > 1: - try: - services.session_service.attach_nodes(db, session.id, schemas.NodeAttachRequest(node_ids=[request.mesh_node_id] if request.mesh_node_id else [])) - except Exception as e: - logging.error(f"Failed to attach session node: {e}") - else: - session.attached_node_ids = [request.mesh_node_id] if request.mesh_node_id else [] - if hasattr(request, 'restrict_skills') and request.restrict_skills is not None: - session.restrict_skills = request.restrict_skills - if hasattr(request, 'allowed_skill_ids') and request.allowed_skill_ids is not None: - from app.db.models.asset import Skill - skills = db.query(Skill).filter(Skill.id.in_(request.allowed_skill_ids)).all() - session.skills = skills - if hasattr(request, 'is_locked') and request.is_locked is not None: - session.is_locked = request.is_locked - if hasattr(request, 'auto_clear_history') and request.auto_clear_history is not None: - session.auto_clear_history = request.auto_clear_history - - db.commit() - db.refresh(instance) - return instance + return services.agent_service.update_config(db, id, current_user.id, request) @router.post("/{id}/webhook") async def webhook_receiver(id: str, payload: dict, background_tasks: BackgroundTasks, response: Response, token: str = None, sync: bool = False, skip_coworker: bool = False, db: Session = Depends(get_db)): @@ -335,130 +183,15 @@ } 
@router.post("/deploy") - def deploy_agent( - request: schemas.DeployAgentRequest, - background_tasks: BackgroundTasks, - current_user: models.User = Depends(get_current_user), - db: Session = Depends(get_db) - ): - """ - One-click agent deployment (Design Doc CUJ 1). - Atomically creates: Template โ†’ Session โ†’ Instance โ†’ Locks Session โ†’ Injects initial prompt โ†’ Starts loop. - """ - from app.db import models as db_models - - # 1. Create Template - template = AgentTemplate( - name=request.name, - description=request.description, - system_prompt_path=request.system_prompt, - user_id=current_user.id, - max_loop_iterations=request.max_loop_iterations, - co_worker_quality_gate=request.co_worker_quality_gate, - rework_threshold=request.rework_threshold, - max_rework_attempts=request.max_rework_attempts - ) - db.add(template) - db.flush() - - # Resolve default provider mapping if user didn't select one - resolved_provider = request.provider_name - if not resolved_provider: - sys_prefs = services.user_service.get_system_settings(db) - from app.config import settings - resolved_provider = sys_prefs.get('llm', {}).get('active_provider', settings.ACTIVE_LLM_PROVIDER) - - # 2. 
Create a locked Session for the agent - new_session = db_models.Session( - user_id=current_user.id, - provider_name=resolved_provider, - feature_name="agent_harness", - is_locked=True, - system_prompt_override=request.system_prompt, - attached_node_ids=[request.mesh_node_id] if getattr(request, "mesh_node_id", None) else [] - ) - db.add(new_session) - db.flush() - - workspace_id = f"agent_{template.id[:8]}" - workspace_jail = f"/tmp/cortex/{workspace_id}/" - new_session.sync_workspace_id = workspace_id - db.flush() - - # 2.5: Inject node into Orchestrator to ensure mirror works locally & remotely - try: - orchestrator = getattr(services, "orchestrator", None) - if orchestrator and request.mesh_node_id: - # Same logic as session attach_nodes config.source="empty" - orchestrator.assistant.push_workspace(request.mesh_node_id, new_session.sync_workspace_id) - orchestrator.assistant.control_sync(request.mesh_node_id, new_session.sync_workspace_id, action="START") - orchestrator.assistant.control_sync(request.mesh_node_id, new_session.sync_workspace_id, action="UNLOCK") - except Exception as e: - import logging - logging.error(f"Failed to bootstrap Orchestrator Sync for Agent Deploy: {e}") - - # 3. Create AgentInstance - instance = AgentInstance( - template_id=template.id, - user_id=current_user.id, - session_id=new_session.id, - mesh_node_id=request.mesh_node_id, - status="idle", - current_workspace_jail=workspace_jail - ) - db.add(instance) - db.flush() - - # 4. Create primary trigger if specified - trigger = AgentTrigger( - instance_id=instance.id, - trigger_type=request.trigger_type or "manual", - cron_expression=request.cron_expression, - interval_seconds=request.interval_seconds, - default_prompt=request.default_prompt - ) - if trigger.trigger_type == "webhook": - import secrets - trigger.webhook_secret = secrets.token_hex(16) - db.add(trigger) - db.flush() - - # 5. 
Kick off agent loop if initial prompt was provided - # (Message insertion is handled automatically by the RAG service execution) + def deploy_agent(request: schemas.DeployAgentRequest, background_tasks: BackgroundTasks, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): + result = services.agent_service.deploy_agent(db, current_user.id, request) if request.initial_prompt: - instance.status = "active" - db.commit() - - async def run_wrapper(): - await AgentExecutor.run(instance.id, request.initial_prompt, services, services.user_service) - - background_tasks.add_task(run_wrapper) - else: - db.commit() + background_tasks.add_task(AgentExecutor.run, result["instance_id"], request.initial_prompt, services, services.user_service) + return result - return { - "template_id": template.id, - "template_name": template.name, - "instance_id": instance.id, - "session_id": new_session.id, - "sync_workspace_id": new_session.sync_workspace_id, - "status": instance.status, - "workspace_jail": workspace_jail, - "message": f"Agent '{request.name}' deployed successfully" - } @router.delete("/{id}") def delete_agent(id: str, current_user: models.User = Depends(get_current_user), db: Session = Depends(get_db)): - from app.db.models.agent import AgentInstance - instance = db.query(AgentInstance).filter( - AgentInstance.id == id, - AgentInstance.user_id == current_user.id - ).first() - if not instance: - raise HTTPException(status_code=404, detail="Agent not found") - - # Stop the agent loop if it was active by deleting it (the loop will hit a None instance and return) - db.delete(instance) - db.commit() + services.agent_service.delete_agent(db, id, current_user.id) return {"message": "Agent deleted successfully"} return router diff --git a/ai-hub/app/api/routes/nodes.py b/ai-hub/app/api/routes/nodes.py index fb64dcd..c4a84ba 100644 --- a/ai-hub/app/api/routes/nodes.py +++ b/ai-hub/app/api/routes/nodes.py @@ -67,135 +67,43 @@ # 
================================================================== @router.post("/admin", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Register New Node") - def admin_create_node( - request: schemas.AgentNodeCreate, - admin_id: str = Query(...), - db: Session = Depends(get_db) - ): - """ - Admin registers a new Agent Node. - Returns the node record including a generated invite_token that must be - placed in the node's config YAML before deployment. - """ + def admin_create_node(request: schemas.AgentNodeCreate, admin_id: str = Query(...), db: Session = Depends(get_db)): _require_admin(admin_id, db) - - existing = db.query(models.AgentNode).filter( - models.AgentNode.node_id == request.node_id - ).first() - if existing: - raise HTTPException(status_code=409, detail=f"Node '{request.node_id}' already exists.") - - # Generate a cryptographically secure invite token - invite_token = secrets.token_urlsafe(32) - - node = models.AgentNode( - node_id=request.node_id, - display_name=request.display_name, - description=request.description, - registered_by=admin_id, - skill_config=request.skill_config.model_dump(), - invite_token=invite_token, - last_status="offline", - ) - db.add(node) - db.commit() - db.refresh(node) - + node = services.mesh_service.register_node(request, admin_id, db) logger.info(f"[admin] Created node '{request.node_id}' by admin {admin_id}") - return _node_to_admin_detail(node, _registry()) + return services.mesh_service.node_to_admin_detail(node) @router.get("/admin", response_model=list[schemas.AgentNodeAdminDetail], summary="[Admin] List All Nodes") def admin_list_nodes(admin_id: str = Query(...), db: Session = Depends(get_db)): - """Full node list for admin dashboard, including invite_token and skill config.""" _require_admin(admin_id, db) - nodes = db.query(models.AgentNode).all() - return [_node_to_admin_detail(n, _registry()) for n in nodes] + return [services.mesh_service.node_to_admin_detail(n) for n in 
db.query(models.AgentNode).all()] @router.get("/admin/{node_id}", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Get Node Detail") def admin_get_node(node_id: str, admin_id: str = Query(...), db: Session = Depends(get_db)): _require_admin(admin_id, db) - node = _get_node_or_404(node_id, db) - return _node_to_admin_detail(node, _registry()) + node = services.mesh_service.get_node_or_404(node_id, db) + return services.mesh_service.node_to_admin_detail(node) @router.patch("/admin/{node_id}", response_model=schemas.AgentNodeAdminDetail, summary="[Admin] Update Node Config") - def admin_update_node( - node_id: str, - update: schemas.AgentNodeUpdate, - admin_id: str = Query(...), - db: Session = Depends(get_db) - ): - """Update display_name, description, skill_config toggles, or is_active.""" + def admin_update_node(node_id: str, update: schemas.AgentNodeUpdate, admin_id: str = Query(...), db: Session = Depends(get_db)): _require_admin(admin_id, db) - node = _get_node_or_404(node_id, db) - - if update.display_name is not None: - node.display_name = update.display_name - if update.description is not None: - node.description = update.description - if update.skill_config is not None: - node.skill_config = update.skill_config.model_dump() - # M6: Push policy live to the node if it's connected - try: - services.orchestrator.push_policy(node_id, node.skill_config) - except Exception as e: - logger.warning(f"Could not push live policy to {node_id}: {e}") - - if update.is_active is not None: - node.is_active = update.is_active - - db.commit() - db.refresh(node) - return _node_to_admin_detail(node, _registry()) + node = services.mesh_service.update_node(node_id, update, db) + return services.mesh_service.node_to_admin_detail(node) @router.delete("/admin/{node_id}", summary="[Admin] Deregister Node") - def admin_delete_node( - node_id: str, - admin_id: str = Query(...), - db: Session = Depends(get_db) - ): - """Delete a node registration and all its access 
grants.""" + def admin_delete_node(node_id: str, admin_id: str = Query(...), db: Session = Depends(get_db)): _require_admin(admin_id, db) - node = _get_node_or_404(node_id, db) - - # Deregister from live memory if online - _registry().deregister(node_id) - - db.delete(node) - db.commit() + services.mesh_service.delete_node(node_id, db) return {"status": "success", "message": f"Node {node_id} deleted"} @router.post("/admin/{node_id}/access", response_model=schemas.NodeAccessResponse, summary="[Admin] Grant Group Access") - def admin_grant_access( - node_id: str, - grant: schemas.NodeAccessGrant, - admin_id: str = Query(...), - db: Session = Depends(get_db) - ): - """Grant a group access to use this node in sessions.""" + def admin_grant_access(node_id: str, grant: schemas.NodeAccessGrant, admin_id: str = Query(...), db: Session = Depends(get_db)): _require_admin(admin_id, db) - node = _get_node_or_404(node_id, db) - - existing = db.query(models.NodeGroupAccess).filter( + services.mesh_service.grant_access(node_id, grant, admin_id, db) + return db.query(models.NodeGroupAccess).filter( models.NodeGroupAccess.node_id == node_id, models.NodeGroupAccess.group_id == grant.group_id ).first() - if existing: - existing.access_level = grant.access_level - existing.granted_by = admin_id - db.commit() - db.refresh(existing) - return existing - - access = models.NodeGroupAccess( - node_id=node_id, - group_id=grant.group_id, - access_level=grant.access_level, - granted_by=admin_id, - ) - db.add(access) - db.commit() - db.refresh(access) - return access @router.delete("/admin/{node_id}/access/{group_id}", summary="[Admin] Revoke Group Access") def admin_revoke_access( @@ -225,56 +133,40 @@ Use this to resolve 'zombie' nodes or flapping connections. """ _require_admin(admin_id, db) - - # 1. Reset DB - _registry().reset_all_statuses() - - # 2. 
Clear Memory - count = _registry().clear_memory_cache() - + count = _registry().emergency_reset() logger.warning(f"[Admin] Mesh Reset triggered by {admin_id}. Cleared {count} live nodes.") return {"status": "success", "cleared_count": count} + @router.post("/purge", summary="Node Self-Purge") + def node_self_purge( + node_id: str = Query(...), + token: str = Query(...), + db: Session = Depends(get_db) + ): + """ + Allows a node to deregister itself using its invite_token. + Called by the purge.py script during uninstallation. + """ + node = db.query(models.AgentNode).filter( + models.AgentNode.node_id == node_id, + models.AgentNode.invite_token == token + ).first() + + if not node: + raise HTTPException(status_code=401, detail="Invalid node or token.") + + node_id = node.node_id + services.mesh_service.delete_node(node_id, db) + logger.info(f"[Mesh] Node '{node_id}' successfully purged itself.") + return {"status": "success", "message": f"Node {node_id} deregistered."} + # ================================================================== # USER-FACING ENDPOINTS # ================================================================== @router.get("/", response_model=list[schemas.AgentNodeUserView], summary="List Accessible Nodes") - def list_accessible_nodes( - user_id: str = Query(...), - db: Session = Depends(get_db) - ): - """ - Returns nodes the calling user's group has access to. - Merges live connection state from the in-memory registry. - """ - user = db.query(models.User).filter(models.User.id == user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found.") - - # Admins see all active nodes for management/configuration purposes. - # Regular users only see nodes explicitly granted to their group. 
- if user.role == "admin": - nodes = db.query(models.AgentNode).filter(models.AgentNode.is_active == True).all() - else: - # Nodes accessible via user's group (relational) - accesses = db.query(models.NodeGroupAccess).filter( - models.NodeGroupAccess.group_id == user.group_id - ).all() - node_ids = set([a.node_id for a in accesses]) - - # Nodes accessible via group policy whitelist - if user.group and user.group.policy: - policy_nodes = user.group.policy.get("nodes", []) - if isinstance(policy_nodes, list): - for nid in policy_nodes: - node_ids.add(nid) - - nodes = db.query(models.AgentNode).filter( - models.AgentNode.node_id.in_(list(node_ids)), - models.AgentNode.is_active == True - ).all() - + def list_accessible_nodes(user_id: str = Query(...), db: Session = Depends(get_db)): + nodes = services.mesh_service.list_accessible_nodes(user_id, db) registry = _registry() return [services.mesh_service.node_to_user_view(n, registry) for n in nodes] @@ -312,48 +204,11 @@ } @router.post("/{node_id}/dispatch", response_model=schemas.NodeDispatchResponse, summary="Dispatch Task to Node") - def dispatch_to_node( - node_id: str, - request: schemas.NodeDispatchRequest, - user_id: str = Query(...), - db: Session = Depends(get_db) - ): - """ - Queue a shell task to an online node. - Emits task_assigned immediately so the live UI shows it. 
- """ - _require_node_access(user_id, node_id, db) - registry = _registry() - live = registry.get_node(node_id) - if not live: - raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.") - - task_id = request.task_id or str(uuid.uuid4()) - - # M6: Use the integrated Protobufs & Crypto from app/core/grpc - from app.protos import agent_pb2 - from app.core.grpc.utils.crypto import sign_payload - - payload = request.command - registry.emit(node_id, "task_assigned", - {"command": request.command, "session_id": request.session_id}, - task_id=task_id) - - try: - task_req = agent_pb2.TaskRequest( - task_id=task_id, - payload_json=payload, - signature=sign_payload(payload), - timeout_ms=request.timeout_ms, - session_id=request.session_id or "", - ) - # Push directly to the node's live gRPC outbound queue (Priority 1 for interactive) - live.send_message(agent_pb2.ServerTaskMessage(task_request=task_req), priority=1) - registry.emit(node_id, "task_start", {"command": request.command}, task_id=task_id) - except Exception as e: - logger.error(f"[nodes/dispatch] Failed to put task onto queue for {node_id}: {e}") - raise HTTPException(status_code=500, detail="Internal Dispatch Error") - + def dispatch_to_node(node_id: str, request: schemas.NodeDispatchRequest, user_id: str = Query(...), db: Session = Depends(get_db)): + task_id = services.mesh_service.dispatch_task( + node_id, request.command, user_id, db, + session_id=request.session_id, task_id=request.task_id, timeout_ms=request.timeout_ms + ) return schemas.NodeDispatchResponse(task_id=task_id, status="accepted") @router.post("/{node_id}/cancel", summary="Cancel/Interrupt Task on Node") @@ -410,38 +265,21 @@ - @router.get( - "/admin/{node_id}/config.yaml", - response_model=schemas.NodeConfigYamlResponse, - summary="[Admin] Download Node Config YAML", - ) + @router.get("/admin/{node_id}/config.yaml", response_model=schemas.NodeConfigYamlResponse, summary="[Admin] Download Node Config YAML") def 
download_node_config_yaml(node_id: str, admin_id: str = Query(...), db: Session = Depends(get_db)): - """ - Generate and return the agent_config.yaml content an admin downloads - and places alongside the node client software before deployment. - """ _require_admin(admin_id, db) - node = _get_node_or_404(node_id, db) - config_yaml = _generate_node_config_yaml(node, None, db) + node = services.mesh_service.get_node_or_404(node_id, db) + config_yaml = services.mesh_service.generate_node_config_yaml(node) return schemas.NodeConfigYamlResponse(node_id=node_id, config_yaml=config_yaml) + @router.get("/provision/{node_id}", summary="Headless Provisioning Script (Python)") def provision_node(node_id: str, token: str, request: Request, db: Session = Depends(get_db)): - """ - Returns a Python script that can be piped into python3 to automatically - install and start the python-source agent node. - - Usage: curl -sSL https://.../provision/{node_id}?token={token} | python3 - """ from fastapi.responses import PlainTextResponse node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first() - if not node or node.invite_token != token: - raise HTTPException(status_code=403, detail="Invalid node or token.") - - config_yaml = _generate_node_config_yaml(node, None, db) + if not node or node.invite_token != token: raise HTTPException(status_code=403, detail="Invalid node or token.") + config_yaml = services.mesh_service.generate_node_config_yaml(node) base_url = f"{request.url.scheme}://{request.url.netloc}" - - script = services.mesh_service.generate_provisioning_script(node, config_yaml, base_url) - return PlainTextResponse(script) + return PlainTextResponse(services.mesh_service.generate_provisioning_script(node, config_yaml, base_url)) @router.get("/provision/sh/{node_id}", summary="Headless Provisioning Script (Bash Binary)") def provision_node_sh(node_id: str, token: str, request: Request, db: Session = Depends(get_db)): @@ -483,57 +321,7 @@ 
@router.get("/provision/binary/{node_id}/{arch}", summary="Download Self-Contained Binary ZIP Bundle") def provision_node_binary_bundle(node_id: str, arch: str, token: str, db: Session = Depends(get_db)): - """ - Dynamically zips the specified architecture's cortex-agent binary - with the autogenerated agent_config.yaml for secure, direct GUI downloads. - """ - import io - import zipfile - - node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first() - if not node or node.invite_token != token: - raise HTTPException(status_code=403, detail="Invalid node or token.") - - # 1. Locate the requested architecture binary - from app.api.routes.agent_update import _AGENT_NODE_DIR - binary_path = os.path.join(_AGENT_NODE_DIR, "dist", arch, "cortex-agent") - - # If binary wasn't built for this arch natively, fallback to notifying user - if not os.path.exists(binary_path): - raise HTTPException(status_code=404, detail=f"Binary for {arch} is not compiled on hub.") - - # 2. Generate config - config_yaml = _generate_node_config_yaml(node, None, db) - - # 3. 
Create ZIP in memory - zip_buffer = io.BytesIO() - with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: - - # Add executable binary with execution permissions - binary_info = zipfile.ZipInfo("cortex-agent") - binary_info.external_attr = 0o100755 << 16 # -rwxr-xr-x - binary_info.compress_type = zipfile.ZIP_DEFLATED - with open(binary_path, "rb") as f: - zip_file.writestr(binary_info, f.read()) - - # Add pre-configured yaml - zip_file.writestr("agent_config.yaml", config_yaml) - - # Add a daemon install script for comfort - start_sh_script = services.mesh_service.get_template_content("start_daemon.sh.j2") - if not start_sh_script: - start_sh_script = "#!/bin/bash\n./cortex-agent\n" - - script_info = zipfile.ZipInfo("start.sh") - script_info.external_attr = 0o100755 << 16 - zip_file.writestr(script_info, start_sh_script) - - zip_buffer.seek(0) - return StreamingResponse( - zip_buffer, - media_type="application/x-zip-compressed", - headers={"Content-Disposition": f"attachment; filename=cortex-node-{node_id}-{arch}.zip"} - ) + return services.mesh_service.download_binary_bundle(node_id, arch, token, db) @router.get("/provision/binaries/status", summary="Check binary availability status") def get_binaries_status(db: Session = Depends(get_db)): @@ -547,93 +335,8 @@ return {"available_architectures": available} @router.get("/admin/{node_id}/download", summary="[Admin] Download Agent Node Bundle (ZIP)") - def admin_download_bundle( - node_id: str, - admin_id: str = Query(...), - db: Session = Depends(get_db) - ): - """ - Bundles the entire Agent Node source code along with a pre-configured - agent_config.yaml into a single ZIP file for the user to download. 
- """ - import io - import zipfile - - _require_admin(admin_id, db) - node = _get_node_or_404(node_id, db) - config_yaml = _generate_node_config_yaml(node, None, db) - - # Create ZIP in memory - zip_buffer = io.BytesIO() - with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: - # 1. Add Agent Node files (source, scripts, and future binary) - # Try production mount first, then environment overrides, then local fallback - source_dirs = [ - "/app/agent-node-source", - "/app/agent-node", - os.environ.get("AGENT_NODE_SRC_DIR", "../agent-node") - ] - found_dir = None - for sd in source_dirs: - if os.path.exists(sd): - found_dir = sd - break - - if found_dir: - for root, dirs, files in os.walk(found_dir): - # Exclude sensitive/build metadata - dirs[:] = [d for d in dirs if d not in ["__pycache__", ".git", ".venv"]] - - for file in files: - # Exclude only instance-specific or sensitive config - if file == ".env" or file == "agent_config.yaml": - continue - - file_path = os.path.join(root, file) - rel_path = os.path.relpath(file_path, found_dir) - zip_file.write(file_path, rel_path) - - # 2. Add skills from mapped directory or fallback - skills_dir = os.environ.get("SKILLS_SRC_DIR", "/app/skills") - if not os.path.exists(skills_dir) and os.path.exists("../skills"): - skills_dir = "../skills" - - if os.path.exists(skills_dir): - for root, dirs, files in os.walk(skills_dir): - dirs[:] = [d for d in dirs if d != "__pycache__"] - for file in files: - file_path = os.path.join(root, file) - rel_path = os.path.join("skills", os.path.relpath(file_path, skills_dir)) - zip_file.write(file_path, rel_path) - - # 3. Add the generated config YAML as 'agent_config.yaml' - zip_file.writestr("agent_config.yaml", config_yaml) - - # 4. 
Add README and run.sh / run.bat / run_mac.command - zip_file.writestr("README.md", services.mesh_service.get_template_content("README.md.j2")) - - # Create run.sh with execute permissions (external_attr) - run_sh_info = zipfile.ZipInfo("run.sh") - run_sh_info.external_attr = 0o100755 << 16 # -rwxr-xr-x - run_sh_info.compress_type = zipfile.ZIP_DEFLATED - run_sh_content = services.mesh_service.get_template_content("run.sh.j2") - zip_file.writestr(run_sh_info, run_sh_content) - - # Create run_mac.command (Mac double-clickable) - run_mac_info = zipfile.ZipInfo("run_mac.command") - run_mac_info.external_attr = 0o100755 << 16 # -rwxr-xr-x - run_mac_info.compress_type = zipfile.ZIP_DEFLATED - zip_file.writestr(run_mac_info, run_sh_content) - - # Create run.bat - zip_file.writestr("run.bat", services.mesh_service.get_template_content("run.bat.j2")) - - zip_buffer.seek(0) - return StreamingResponse( - zip_buffer, - media_type="application/x-zip-compressed", - headers={"Content-Disposition": f"attachment; filename=cortex-node-{node_id}.zip"} - ) + def admin_download_bundle(node_id: str, admin_id: str = Query(...), db: Session = Depends(get_db)): + return services.mesh_service.download_admin_bundle(node_id, admin_id, db) # ================================================================== @@ -642,35 +345,12 @@ @router.post("/validate-token", summary="[Internal] Validate Node Invite Token") def validate_invite_token(token: str, node_id: str, db: Session = Depends(get_db)): - """ - Internal HTTP endpoint called by the gRPC SyncConfiguration handler - to validate an invite_token before accepting a node connection. - - Returns the node's skill_config (sandbox policy) on success so the - gRPC server can populate the SandboxPolicy response. 
- - Response: - 200 { valid: true, node_id, skill_config, display_name } - 401 { valid: false, reason } - """ - node = db.query(models.AgentNode).filter( - models.AgentNode.node_id == node_id, - models.AgentNode.invite_token == token, - models.AgentNode.is_active == True, - ).first() - - if not node: - logger.warning(f"[M4] Token validation FAILED for node_id='{node_id}'") - return {"valid": False, "reason": "Invalid token or unknown node."} - - logger.info(f"[M4] Token validated OK for node_id='{node_id}'") - return { - "valid": True, - "node_id": node.node_id, - "display_name": node.display_name, - "user_id": node.registered_by, # AgentNode has registered_by, not user_id - "skill_config": node.skill_config or {}, - } + result = services.mesh_service.validate_invite_token(token, node_id, db) + if not result["valid"]: + logger.warning(f"[M4] Token validation FAILED for node_id='{node_id}': {result.get('reason')}") + else: + logger.info(f"[M4] Token validated OK for node_id='{node_id}'") + return result # ================================================================== # WEBSOCKET โ€” Single-node live event stream @@ -1230,123 +910,7 @@ return router -# =========================================================================== -# Helpers -# =========================================================================== - -def _generate_node_config_yaml(node: models.AgentNode, skill_overrides: dict = None, db: Session = None) -> str: - """Helper to generate the agent_config.yaml content.""" - from app.config import settings - - # Use dynamically configured Swarm Admin settings or fallbacks - hub_url = settings.GRPC_EXTERNAL_ENDPOINT or os.getenv("HUB_PUBLIC_URL", "http://127.0.0.1:8000") - hub_grpc = settings.GRPC_TARGET_ORIGIN or os.getenv("HUB_GRPC_ENDPOINT", "127.0.0.1:50051") - - secret_key = os.getenv("SECRET_KEY", "dev-secret-key-1337") - - - skill_cfg = node.skill_config or {} - if isinstance(skill_cfg, str): - try: - skill_cfg = json.loads(skill_cfg) - except 
Exception: - skill_cfg = {} - - if skill_overrides: - for skill, cfg in skill_overrides.items(): - if skill not in skill_cfg: - skill_cfg[skill] = {} - skill_cfg[skill].update(cfg) - - lines = [ - "# Cortex Hub - Agent Node Configuration", - f"# Generated for node '{node.node_id}' - keep this file secret.", - "", - f"node_id: \"{node.node_id}\"", - f"node_description: \"{node.display_name}\"", - "", - "# Hub connection", - f"hub_url: \"{hub_url}\"", - f"grpc_endpoint: \"{hub_grpc}\"", - "", - "# Authentication โ€” do NOT share these secrets", - f"invite_token: \"{node.invite_token}\"", - f"auth_token: \"{node.invite_token}\"", - "", - "# HMAC signing key โ€” must match the hub's SECRET_KEY exactly", - f"secret_key: \"{secret_key}\"", - "", - "# Skill configuration (mirrors admin settings; node respects these at startup)", - "skills:", - ] - for skill, cfg in skill_cfg.items(): - if not isinstance(cfg, dict): - continue - enabled = cfg.get("enabled", True) - lines.append(f" {skill}:") - lines.append(f" enabled: {str(enabled).lower()}") - for k, v in cfg.items(): - if k != "enabled" and v is not None: - lines.append(f" {k}: {v}") - - lines += [ - "", - "# Workspace sync root โ€” override if needed", - "sync_root: \"/tmp/cortex-sync\"", - "", - "# FS Explorer root โ€” defaults to user home if not specified here", - "# fs_root: \"/User/username/Documents\"", - "", - "# TLS โ€” set to false only in dev", - f"tls: {str(settings.GRPC_TLS_ENABLED).lower()}", - ] - return "\n".join(lines) - -def _get_node_or_404(node_id: str, db: Session) -> models.AgentNode: - node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first() - if not node: - raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.") - return node - - -def _node_to_admin_detail(node: models.AgentNode, registry) -> schemas.AgentNodeAdminDetail: - live = registry.get_node(node.node_id) - status = live._compute_status() if live else node.last_status or "offline" - stats = 
schemas.AgentNodeStats(**live.stats) if live else schemas.AgentNodeStats() - return schemas.AgentNodeAdminDetail( - node_id=node.node_id, - display_name=node.display_name, - description=node.description, - skill_config=node.skill_config or {}, - capabilities=node.capabilities or {}, - invite_token=node.invite_token, - is_active=node.is_active, - last_status=status, - last_seen_at=node.last_seen_at, - created_at=node.created_at, - registered_by=node.registered_by, - group_access=[ - schemas.NodeAccessResponse( - id=a.id, node_id=a.node_id, group_id=a.group_id, - access_level=a.access_level, granted_at=a.granted_at - ) for a in (node.group_access or []) - ], - stats=stats, - ) - - -def _node_to_user_view(node: models.AgentNode, registry) -> schemas.AgentNodeUserView: - # Use global state from the app context, or we can't easily access services here... - # Ah, wait! The _node_to_user_view is passed `registry`, but NOT `services`. - # Let's import the mesh_service from `app.main` or we can just redefine it inline to use services.mesh_service if we change the caller. - # Actually, let's keep it simple for now as we don't have access to `services` easily in a global helper unless passed. - # Wait, the plan asks us to *extract* it. Let's modify the caller and delete it! 
- pass - - -def _now() -> str: - from datetime import datetime - return datetime.utcnow().isoformat() + return router async def _drain(q: queue.Queue, websocket: WebSocket): @@ -1357,3 +921,8 @@ await websocket.send_json(event) except queue.Empty: break + + +def _now() -> str: + from datetime import datetime + return datetime.utcnow().isoformat() diff --git a/ai-hub/app/api/routes/sessions.py b/ai-hub/app/api/routes/sessions.py index d9fdb9f..d80fe06 100644 --- a/ai-hub/app/api/routes/sessions.py +++ b/ai-hub/app/api/routes/sessions.py @@ -126,90 +126,8 @@ @router.get("/{session_id}/tokens", response_model=schemas.SessionTokenUsageResponse, summary="Get Session Token Usage") def get_session_token_usage(session_id: int, db: Session = Depends(get_db)): - try: - session = db.query(models.Session).filter(models.Session.id == session_id).first() - if not session: - raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.") + return services.session_service.get_token_usage(db, session_id) - messages = services.rag_service.get_message_history(db=db, session_id=session_id) - combined_text = " ".join([m.content for m in messages]) - - # Resolve dynamic token limit from model info - from app.core.providers.factory import get_model_limit - - # M3: Resolve effective configuration using PreferenceService (same as UI) - # Ensure we have a user context, fallback to system admin if session has no owner - user_context = session.user - if not user_context: - from app.config import settings - admin_email = settings.SUPER_ADMINS[0] if settings.SUPER_ADMINS else None - if admin_email: - user_context = db.query(models.User).filter(models.User.email == admin_email).first() - - # Resolve effective provider and model using merged preferences - effective_provider = session.provider_name - resolved_model = None - - if user_context: - config = services.preference_service.merge_user_config(user_context, db) - effective_llm = config.effective.get("llm", {}) - - # 1. 
Use session-specific provider or fall back to user/system active provider - effective_provider = effective_provider or effective_llm.get("active_provider") - - # 2. Extract model name for this specific provider name - providers = effective_llm.get("providers", {}) - p_info = providers.get(effective_provider, {}) - resolved_model = p_info.get("model") - - # 3. Handle LiteLLM style model strings (e.g. "openai/gpt-4o") - if not resolved_model and "/" in (effective_provider or ""): - resolved_model = effective_provider.split("/", 1)[1] - effective_provider = effective_provider.split("/")[0] - - # 4. Support instance prefixes if exact match fails (e.g. "gemini_instance_1" -> "gemini") - if not resolved_model and effective_provider not in providers and "_" in (effective_provider or ""): - base_prov = effective_provider.split("_")[0] - resolved_model = providers.get(base_prov, {}).get("model") - else: - # Absolute fallback to hardcoded settings if no database user context exists - from app.config import settings - effective_provider = effective_provider or settings.ACTIVE_LLM_PROVIDER - resolved_model = None - - try: - token_limit = get_model_limit(effective_provider, model_name=resolved_model) - except ValueError as e: - # Model not configured โ€” return a graceful 200 with error hint - import logging - logging.warning(f"[Tokens] Limit resolution failed for {effective_provider}/{resolved_model}: {e}") - return schemas.SessionTokenUsageResponse( - token_count=0, - token_limit=0, - percentage=0.0, - error=str(e) - ) - - validator = Validator(token_limit=token_limit) - token_count = validator.get_token_count(combined_text) - - # Defensive check: if token_limit is still 0 (shouldn't happen with get_model_limit fallback), avoid DivZero - if token_limit <= 0: - token_limit = 10000 - - percentage = round((token_count / token_limit) * 100, 2) - - return schemas.SessionTokenUsageResponse( - token_count=token_count, - token_limit=token_limit, - percentage=percentage - ) - except 
HTTPException: - raise - except Exception as e: - import logging - logging.exception(f"Internal error fetching token usage for session {session_id}") - raise HTTPException(status_code=500, detail=f"An error occurred: {e}") @router.get("/", response_model=List[schemas.Session], summary="Get All Chat Sessions") def get_sessions( @@ -290,47 +208,7 @@ @router.delete("/{session_id}", summary="Delete a Chat Session") def delete_session(session_id: int, db: Session = Depends(get_db)): try: - session = db.query(models.Session).filter(models.Session.id == session_id).first() - if not session: - raise HTTPException(status_code=404, detail="Session not found.") - - if session.is_locked: - raise HTTPException(status_code=403, detail="Cannot delete a locked session. Unlock it first to delete.") - - session.is_archived = True - sync_workspace_id = session.sync_workspace_id - db.commit() - - # M6: Immediately broadcast PURGE to all connected nodes - if sync_workspace_id: - try: - orchestrator = getattr(services, "orchestrator", None) - if orchestrator: - import app.protos.agent_pb2 as agent_pb2 - live_nodes = orchestrator.registry.list_nodes() - for node in live_nodes: - try: - node.send_message(agent_pb2.ServerTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=sync_workspace_id, - control=agent_pb2.SyncControl( - action=agent_pb2.SyncControl.PURGE, - path="" - ) - ) - ), priority=0) - except Exception as e: - import logging - logging.exception(f"[๐Ÿ“โš ๏ธ] Failed to send PURGE to node {node.node_id}: {e}") - - # Hub local purge - from app.config import settings - mirror_path = os.path.join(settings.DATA_DIR, "mirrors", sync_workspace_id) - if os.path.exists(mirror_path): - shutil.rmtree(mirror_path, ignore_errors=True) - except Exception as e: - import logging - logging.exception(f"[๐Ÿ“โš ๏ธ] Fast local purge failed for {sync_workspace_id}: {e}") + services.session_service.archive_session(db, session_id) return {"message": "Session deleted successfully."} 
except HTTPException: raise @@ -340,51 +218,8 @@ @router.delete("/", summary="Delete All Sessions for Feature") def delete_all_sessions(user_id: str, feature_name: str = "default", db: Session = Depends(get_db)): try: - sessions = db.query(models.Session).filter( - models.Session.user_id == user_id, - models.Session.feature_name == feature_name, - models.Session.is_archived == False, - models.Session.is_locked == False - ).all() - - workspaces_to_purge = [] - for session in sessions: - session.is_archived = True - if session.sync_workspace_id: - workspaces_to_purge.append(session.sync_workspace_id) - - db.commit() - - # M6: Immediately broadcast PURGE to all connected nodes - if workspaces_to_purge: - try: - orchestrator = getattr(services, "orchestrator", None) - if orchestrator: - import app.protos.agent_pb2 as agent_pb2 - live_nodes = orchestrator.registry.list_nodes() - for node in live_nodes: - for wid in workspaces_to_purge: - node.send_message(agent_pb2.ServerTaskMessage( - file_sync=agent_pb2.FileSyncMessage( - session_id=wid, - control=agent_pb2.SyncControl( - action=agent_pb2.SyncControl.PURGE, - path="" - ) - ) - ), priority=0) - - # Hub local purge - from app.config import settings - for wid in workspaces_to_purge: - mirror_path = os.path.join(settings.DATA_DIR, "mirrors", wid) - if os.path.exists(mirror_path): - shutil.rmtree(mirror_path, ignore_errors=True) - except Exception as e: - import logging - logging.exception(f"[๐Ÿ“โš ๏ธ] Fast local bulk purge failed: {e}") - - return {"message": "All sessions deleted successfully."} + count = services.session_service.archive_all_feature_sessions(db, user_id, feature_name) + return {"message": f"Deleted {count} sessions successfully."} except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to delete all sessions: {e}") diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py index 9306ec5..5de8bab 100644 --- a/ai-hub/app/api/routes/user.py +++ 
b/ai-hub/app/api/routes/user.py @@ -327,49 +327,9 @@ db: Session = Depends(get_db), user_id: str = Depends(get_current_user_id) ): - from app.core.providers.factory import get_llm_provider - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - - # We allow verification if user is admin OR if they are providing their own key (not using a masked key without permission) - is_using_masked = not req.api_key or "***" in str(req.api_key) - if is_using_masked and user.role != "admin": - raise HTTPException(status_code=403, detail="Forbidden: Admin only for masked keys") - actual_key = req.api_key - try: - llm_prefs = {} - user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if user and user.preferences: - llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(req.provider_name, {}) - - # Handle masked keys by backfilling from stored prefs if needed - if actual_key and "***" in actual_key: - actual_key = llm_prefs.get("api_key") - if not actual_key: - # Fallback to system defaults if admin - system_prefs = services.user_service.get_system_settings(db) - actual_key = system_prefs.get("llm", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") - - kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} - if req.provider_type: - kwargs["provider_type"] = req.provider_type - - llm = get_llm_provider( - provider_name=req.provider_name, - model_name=req.model or "", - api_key_override=actual_key, - **kwargs - ) - # GeneralProvider check - res = await llm.acompletion(prompt="Hello") - return schemas.VerifyProviderResponse(success=True, message="Connection successful!") - except Exception as e: - import logging - logging.getLogger(__name__).error(f"LLM Verification failed for {req.provider_name} ({req.provider_type}): {e}") 
- return schemas.VerifyProviderResponse(success=False, message=str(e)) + if not user: raise HTTPException(status_code=404, detail="User not found") + return await services.preference_service.verify_provider(db, user, req, "llm") @router.post("/me/config/verify_tts", response_model=schemas.VerifyProviderResponse) async def verify_tts( @@ -377,97 +337,18 @@ db: Session = Depends(get_db), user_id: str = Depends(get_current_user_id) ): - from app.core.providers.factory import get_tts_provider - from app.config import settings - - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - - is_using_masked = not req.api_key or "***" in str(req.api_key) - if is_using_masked and user.role != "admin": - raise HTTPException(status_code=403, detail="Forbidden: Admin only for masked keys") - - actual_key = req.api_key - try: - tts_prefs = user.preferences.get("tts", {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {} - - # Key resolution: Masked keys should be replaced with real ones from DB or system config - if not actual_key or "***" in str(actual_key): - actual_key = tts_prefs.get("api_key") - if not actual_key or "***" in str(actual_key): - # Try system settings - system_prefs = services.user_service.get_system_settings(db) - actual_key = system_prefs.get("tts", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") - # Final fallback to settings.py constants - if not actual_key: actual_key = settings.TTS_API_KEY or settings.GEMINI_API_KEY - - logger.info(f"verify_tts: instance={req.provider_name}, type={req.provider_type}, model={req.model}") - - kwargs = {k: v for k, v in tts_prefs.items() if k not in ["api_key", "model", "voice"]} - if req.provider_type: - kwargs["provider_type"] = req.provider_type - - provider = get_tts_provider( - 
provider_name=req.provider_name, - api_key=actual_key, - model_name=req.model or "", - voice_name=req.voice or "", - **kwargs - ) - await provider.generate_speech("Hello there. We are testing this thing.") - return schemas.VerifyProviderResponse(success=True, message="Connection successful!") - except Exception as e: - logger.error(f"TTS verification failed for {req.provider_name}: {e}") - return schemas.VerifyProviderResponse(success=False, message=str(e)) - + if not user: raise HTTPException(status_code=404, detail="User not found") + return await services.preference_service.verify_provider(db, user, req, "tts") @router.post("/me/config/verify_stt", response_model=schemas.VerifyProviderResponse) async def verify_stt( req: schemas.VerifyProviderRequest, db: Session = Depends(get_db), user_id: str = Depends(get_current_user_id) ): - from app.core.providers.factory import get_stt_provider - from app.config import settings - - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - is_using_masked = not req.api_key or "***" in str(req.api_key) - if is_using_masked and user.role != "admin": - raise HTTPException(status_code=403, detail="Forbidden: Admin only for masked keys") - - actual_key = req.api_key - try: - stt_prefs = user.preferences.get("stt", {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {} - - if not actual_key or "***" in str(actual_key): - actual_key = stt_prefs.get("api_key") - if not actual_key or "***" in str(actual_key): - system_prefs = services.user_service.get_system_settings(db) - actual_key = system_prefs.get("stt", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") - if not actual_key: actual_key = settings.STT_API_KEY or settings.GEMINI_API_KEY - - kwargs = {k: v for k, v in stt_prefs.items() if k not in 
["api_key", "model"]} - if req.provider_type: - kwargs["provider_type"] = req.provider_type - - provider = get_stt_provider( - provider_name=req.provider_name, - api_key=actual_key, - model_name=req.model or "", - **kwargs - ) - # Minimal STT check: factory init is usually enough to catch invalid credentials for SDK-based providers - return schemas.VerifyProviderResponse(success=True, message="Provider initialized. Full transcription test requires audio payload.") - except Exception as e: - logger.error(f"STT verification failed for {req.provider_name}: {e}") - return schemas.VerifyProviderResponse(success=False, message=str(e)) - + if not user: raise HTTPException(status_code=404, detail="User not found") + return await services.preference_service.verify_provider(db, user, req, "stt") @router.post("/logout", summary="Log Out the Current User") async def logout(): """ @@ -483,77 +364,13 @@ ): """Exports the effective user configuration as a YAML file (Admin only).""" from fastapi.responses import PlainTextResponse - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") user = services.user_service.get_user_by_id(db=db, user_id=user_id) if not user or user.role != "admin": raise HTTPException(status_code=403, detail="Forbidden: Admin only") - prefs_dict = copy.deepcopy(user.preferences) if user.preferences else {} - from app.config import settings - import yaml - import copy - - # Sensitive keys that should be encrypted - SENSITIVE_KEYS = ["api_key", "client_secret", "webhook_secret", "password", "key_content", "key_file"] - - def process_export(obj): - if isinstance(obj, dict): - res = {} - for k, v in obj.items(): - if v is None: continue - if k in SENSITIVE_KEYS and v: - if reveal_secrets: - res[k] = v - else: - res[k] = encrypt_value(v) - else: - res[k] = process_export(v) - return res - elif isinstance(obj, list): - return [process_export(x) for x in obj] - return obj - - # Ensure we have the base sections even if empty in prefs - 
export_data = { - "llm": prefs_dict.get("llm", {"providers": {}, "active_provider": "deepseek"}), - "tts": prefs_dict.get("tts", {"providers": {}, "active_provider": settings.TTS_PROVIDER}), - "stt": prefs_dict.get("stt", {"providers": {}, "active_provider": settings.STT_PROVIDER}) - } - - # Backfill from system settings if sections are empty - if not export_data["llm"].get("providers"): - export_data["llm"]["providers"] = { - "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME}, - "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME}, - "openai": {"api_key": settings.OPENAI_API_KEY} - } - - if not export_data["tts"].get("providers"): - export_data["tts"]["providers"] = { - settings.TTS_PROVIDER: { - "api_key": settings.TTS_API_KEY, - "model": settings.TTS_MODEL_NAME, - "voice": settings.TTS_VOICE_NAME - } - } - - if not export_data["stt"].get("providers"): - export_data["stt"]["providers"] = { - settings.STT_PROVIDER: { - "api_key": settings.STT_API_KEY, - "model": settings.STT_MODEL_NAME - } - } - - clean_yaml_data = process_export(export_data) - yaml_str = yaml.dump(clean_yaml_data, sort_keys=False, default_flow_style=False) - - return PlainTextResponse( - content=yaml_str, - media_type="application/x-yaml", - headers={"Content-Disposition": "attachment; filename=\"cortex_config.yaml\""} - ) + yaml_str = services.preference_service.export_config_yaml(user, reveal_secrets) + return PlainTextResponse(content=yaml_str, media_type="application/x-yaml", + headers={"Content-Disposition": "attachment; filename=\"cortex_config.yaml\""}) @router.post("/me/config/import", response_model=schemas.UserPreferences, summary="Import Configurations from YAML") async def import_user_config_yaml( @@ -562,110 +379,10 @@ user_id: str = Depends(get_current_user_id) ): """Imports user configuration from a YAML file.""" - if not user_id: - raise HTTPException(status_code=401, detail="Unauthorized") user = 
services.user_service.get_user_by_id(db=db, user_id=user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") - + if not user: raise HTTPException(status_code=404, detail="User not found") content = await file.read() - try: - import yaml - data = yaml.safe_load(content) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid YAML file: {e}") - - def process_import(obj): - if isinstance(obj, dict): - return {k: process_import(v) for k, v in obj.items()} - elif isinstance(obj, str): - return decrypt_value(obj) - elif isinstance(obj, list): - return [process_import(x) for x in obj] - return obj - - data = process_import(data) - - # Map to UserPreferences structure - new_llm = data.get("llm", {}) - new_tts = data.get("tts", {}) - new_stt = data.get("stt", {}) - - # Handle legacy flat structure if imported from old version - if not new_llm and "llm_providers" in data: - llm_flat = data.get("llm_providers", {}) - new_llm = {"providers": {}} - for k, v in llm_flat.items(): - if k.endswith("_api_key"): - p = k.replace("_api_key", "") - if p not in new_llm["providers"]: new_llm["providers"][p] = {} - new_llm["providers"][p]["api_key"] = v - elif k.endswith("_model_name"): - p = k.replace("_model_name", "") - if p not in new_llm["providers"]: new_llm["providers"][p] = {} - new_llm["providers"][p]["model"] = v - if new_llm["providers"]: - new_llm["active_provider"] = next(iter(new_llm["providers"]), None) - - if not new_tts and "tts_provider" in data: - tts_flat = data.get("tts_provider", {}) - p = tts_flat.get("provider") or "google_gemini" - new_tts = { - "active_provider": p, - "providers": { - p: {"api_key": tts_flat.get("api_key"), "model": tts_flat.get("model_name"), "voice": tts_flat.get("voice_name")} - } - } - - if not new_stt and "stt_provider" in data: - stt_flat = data.get("stt_provider", {}) - p = stt_flat.get("provider") or "google_gemini" - new_stt = { - "active_provider": p, - "providers": { - p: 
{"api_key": stt_flat.get("api_key"), "model": stt_flat.get("model_name")} - } - } - - user.preferences = { - "llm": new_llm, - "tts": new_tts, - "stt": new_stt, - "statuses": {} - } - from sqlalchemy.orm.attributes import flag_modified - flag_modified(user, "preferences") - - # Sync to global settings (only for admin) - if user.role == "admin": - from app.config import settings as global_settings - if new_llm.get("providers"): - global_settings.LLM_PROVIDERS.update(new_llm["providers"]) - if new_tts.get("active_provider"): - p = new_tts["active_provider"] - p_data = new_tts["providers"].get(p, {}) - if p_data: - global_settings.TTS_PROVIDER = p - global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME - global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME - global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY - if new_stt.get("active_provider"): - p = new_stt["active_provider"] - p_data = new_stt["providers"].get(p, {}) - if p_data: - global_settings.STT_PROVIDER = p - global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME - global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY - - try: - global_settings.save_to_yaml() - except Exception as ey: - logger.error(f"Failed to sync settings to YAML on import: {ey}") - - db.add(user) - db.commit() - db.refresh(user) - return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {})) + return await services.preference_service.import_config_yaml(db, user, content) diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 42693a2..d9c76ba 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -303,6 +303,9 @@ agent_scheduler = AgentScheduler(services=services) services.with_service("agent_scheduler", service=agent_scheduler) + from app.core.services.agent import AgentService + 
services.with_service("agent_service", service=AgentService(services=services)) + app.state.services = services # Create and include the API router, injecting the service diff --git a/ai-hub/app/core/_regex.py b/ai-hub/app/core/_regex.py index ee95aeb..514b592 100644 --- a/ai-hub/app/core/_regex.py +++ b/ai-hub/app/core/_regex.py @@ -1,5 +1,28 @@ import re -# Pre-compiled ANSI Escape regex for reuse across Rag, Registry, and Architects -# This prevents O(N * tokens) compilation overhead during intensive streaming -ANSI_ESCAPE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') +# ANSI Escape Sequences (Standard and Xterm) +ANSI_ESCAPE = re.compile(r'\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + +# --- Agent Orchestration Patterns --- + +# Matches satellite thinking turn markers (used in AgentExecutor._compress_reasoning) +TURN_THINKING_MARKER = re.compile(r"๐Ÿ›ฐ๏ธ.*\[turn.*thinking", re.IGNORECASE) + +# Matches strategy execution boilerplate +STRATEGY_BOILERPLATE = re.compile(r"strategy:.*executing orchestrated tasks", re.IGNORECASE) + +# Matches quality gate failure notifications +COWORKER_FAIL_MARKER = re.compile(r"โš ๏ธ \*\*Co-Worker\*\*: Quality check FAILED") + +# Parser Patterns +SKILL_CONFIG_JSON = re.compile(r"```json\s*({\s*\"skill_name\":.*?})\s*```", re.DOTALL) +SKILL_DESC_OVERRIDE = re.compile(r"\[DESCRIPTION_OVERRIDE\]:\s*(.*)") +SKILL_PARAM_TABLE = re.compile(r"\|.*\|", re.MULTILINE) +SKILL_BASH_LOGIC = re.compile(r"\[BASH_START\](.*?)\[BASH_END\]", re.DOTALL) + +# Infrastructure Patterns +URL_CLEANER = re.compile(r'https?://[^\s<>"]+|www\.[^\s<>"]+') + +# Evaluation Patterns +FINAL_SCORE = re.compile(r"FINAL_SCORE:\s*(\d+)") +RUBRIC_SECTION = re.compile(r"^#\s+(Evaluation Rubric|Rework Instructions)", re.MULTILINE | re.IGNORECASE) diff --git a/ai-hub/app/core/orchestration/agent_loop.py b/ai-hub/app/core/orchestration/agent_loop.py index ae7e84b..1d7ce11 100644 --- a/ai-hub/app/core/orchestration/agent_loop.py +++ 
b/ai-hub/app/core/orchestration/agent_loop.py @@ -2,652 +2,480 @@ import time from datetime import datetime import logging -from sqlalchemy.orm import Session -from sqlalchemy.orm.exc import ObjectDeletedError -from tenacity import retry, wait_exponential, stop_after_attempt import json - -logger = logging.getLogger(__name__) +import traceback +from typing import Dict, Any, List, Optional +from sqlalchemy.orm import Session +from sqlalchemy.orm.exc import ObjectDeletedError, StaleDataError from app.db.session import SessionLocal from app.db.models.agent import AgentInstance, AgentTemplate -from app.db.models import Message +from app.db.models.session import Message, Session as SessionModel +from app.core.providers.factory import get_llm_provider +from app.core.orchestration.harness_evaluator import HarnessEvaluator +from app.core._regex import TURN_THINKING_MARKER, STRATEGY_BOILERPLATE, ANSI_ESCAPE + +logger = logging.getLogger(__name__) class AgentExecutor: + """ + Orchestrates the execution of an agent, including its rework loop, + evaluation via a Co-Worker Auditor, and real-time metric tracking. 
+ """ + + def __init__(self, agent_id: str, services, user_service): + self.agent_id = agent_id + self.services = services + self.user_service = user_service + self.db: Session = SessionLocal() + self.instance: Optional[AgentInstance] = None + self.template: Optional[AgentTemplate] = None + self.evaluator: Optional[HarnessEvaluator] = None + @staticmethod async def run(agent_id: str, prompt: str, services, user_service, skip_coworker: bool = False): - """Asynchronous execution loop for the agent.""" - # Create a fresh DB session for the background task - db: Session = SessionLocal() - - def safe_commit(): - """Helper to commit and handle session errors or deleted objects gracefully.""" - try: - db.commit() - return True - except Exception as e: - db.rollback() - from sqlalchemy.orm.exc import ObjectDeletedError, StaleDataError - if isinstance(e, (ObjectDeletedError, StaleDataError)): - print(f"[AgentExecutor] Agent {agent_id} was deleted or modified externally. Exiting loop.") - return False - print(f"[AgentExecutor] Commit failed for {agent_id}: {e}") - # Re-raise if it's something truly unexpected - raise - + """Entry point for the background execution task.""" + executor = AgentExecutor(agent_id, services, user_service) try: - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if not instance or not prompt: - db.close() - return + return await executor.execute(prompt, skip_coworker) + finally: + executor.db.close() - # Acquire Lease - instance.last_heartbeat = datetime.utcnow() - instance.status = "active" - instance.last_error = None - instance.total_runs = (instance.total_runs or 0) + 1 - if not safe_commit(): return + async def execute(self, prompt: str, skip_coworker: bool = False): + """Main execution sequence.""" + if not await self._initialize_instance(prompt): + return - - # Launch secondary heartbeat task - async def heartbeat(): - while True: - await asyncio.sleep(60) - try: - inner_db = SessionLocal() - inner_instance = 
inner_db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if not inner_instance or inner_instance.status not in ["active", "starting"]: - inner_db.close() - break - inner_instance.last_heartbeat = datetime.utcnow() - inner_db.commit() - inner_db.close() - except: - break + heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + try: + # Phase 1: Setup & Initialization + rubric_task = await self._setup_evaluation(prompt, skip_coworker) - heartbeat_task = asyncio.create_task(heartbeat()) - - template = db.query(AgentTemplate).filter(AgentTemplate.id == instance.template_id).first() - if not template: - instance.status = "error_suspended" - instance.last_error = f"Template '{instance.template_id}' not found." - if not safe_commit(): return - return - - # Configuration for Rework Loop - co_worker_enabled = getattr(template, "co_worker_quality_gate", False) - rework_threshold = getattr(template, "rework_threshold", 80) - max_rework_attempts = getattr(template, "max_rework_attempts", 3) + # Phase 2: Execution & Rework Loop + final_result = await self._run_rework_loop(prompt, rubric_task) - # --- Phase 1: Pre-Execution Initialization (Harness Mirror) --- - from app.core.orchestration.harness_evaluator import HarnessEvaluator - evaluator = None - rubric_content = "" - rubric_task = None - - if co_worker_enabled and not skip_coworker: - from app.core.providers.factory import get_llm_provider - # For Evaluation, we use the same provider/model as the main task for consistency - # Load provider settings - from app.db.models.session import Session as SessionModel - agent_session = db.query(SessionModel).filter(SessionModel.id == instance.session_id).first() - provider_name = getattr(agent_session, "provider_name", None) - if not provider_name and user_service: - from app.config import settings - provider_name = settings.ACTIVE_LLM_PROVIDER - - # Resolve model/key from user preferences (mirrors rag.py logic) - base_provider_key = 
provider_name.split("/")[0] if provider_name and "/" in provider_name else provider_name - llm_prefs = {} - if agent_session and agent_session.user and agent_session.user.preferences: - llm_prefs = agent_session.user.preferences.get("llm", {}).get("providers", {}).get(base_provider_key, {}) - if (not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key"))) and user_service: - system_prefs = user_service.get_system_settings(db) - system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(base_provider_key, {}) - - if not system_provider_prefs or not system_provider_prefs.get("model"): - active_prov_key = system_prefs.get("llm", {}).get("active_provider") - if active_prov_key: - system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(active_prov_key, {}) - if system_provider_prefs: - provider_name = active_prov_key - - if system_provider_prefs: - merged = system_provider_prefs.copy() - if llm_prefs: merged.update({k: v for k, v in llm_prefs.items() if v}) - llm_prefs = merged - eval_api_key = llm_prefs.get("api_key") - eval_model = "" if "/" in (provider_name or "") else llm_prefs.get("model", "") - eval_kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} - eval_provider = get_llm_provider(provider_name, model_name=eval_model, api_key_override=eval_api_key, **eval_kwargs) - - # Instantiate the main evaluator for the loop - evaluator = HarnessEvaluator(db, agent_id, instance.mesh_node_id, instance.session.sync_workspace_id if instance.session else str(instance.session_id), eval_provider, services) - - # LAUNCH RUBRIC GENERATION IN PARALLEL - # We use a specialized background runner to avoid session contention - async def rubric_runner(p, agent_id_inner, eval_provider_inner, services_inner): - bg_db = SessionLocal() - try: - bg_instance = bg_db.query(AgentInstance).filter(AgentInstance.id == agent_id_inner).first() - if not bg_instance: return None - - bg_evaluator = 
HarnessEvaluator( - bg_db, - agent_id_inner, - bg_instance.mesh_node_id, - bg_instance.session.sync_workspace_id if bg_instance.session else str(bg_instance.session_id), - eval_provider_inner, - services_inner - ) - # Initialize and generate - await bg_evaluator.initialize_cortex() - return await bg_evaluator.generate_rubric(p) - except Exception as e: - logger.error(f"[AgentExecutor] Background rubric generation failed: {e}") - return None - finally: - bg_db.close() - - rubric_task = asyncio.create_task(rubric_runner(prompt, agent_id, eval_provider, services)) - - # Update status immediately to reflect both tasks starting - db.query(AgentInstance).filter(AgentInstance.id == agent_id).update({ - "status": "starting", - "evaluation_status": "๐Ÿ“‹ Co-Worker: Initiating parallel rubric & mission setup...", - "current_rework_attempt": 0 - }) - if not safe_commit(): return - - # Emit status if registry exists - registry = getattr(services.rag_service, "node_registry_service", None) - if registry and instance.mesh_node_id: - registry.emit(instance.mesh_node_id, "status_update", {"evaluation_status": "๐Ÿ“‹ Co-Worker: Generating rubric (parallel)..."}) - - - - max_iterations = template.max_loop_iterations or 20 - session_id = instance.session_id - - from app.db.models.session import Message - from app.db.models.session import Session as SessionModel - - agent_session = db.query(SessionModel).filter(SessionModel.id == session_id).first() - provider_name = getattr(agent_session, "provider_name", None) - - # If not explicitly defined on session, fallback to system default - if not provider_name and user_service: - from app.config import settings - provider_name = settings.ACTIVE_LLM_PROVIDER - - # Area 4.2: Hippocampus (Scratchpad) Idempotency Check - if session_id: - if getattr(agent_session, "auto_clear_history", False): - db.query(Message).filter(Message.session_id == session_id).delete(synchronize_session=False) - if not safe_commit(): return - - current_prompt = prompt - 
current_attempt = 0 - final_result = None - - if evaluator: - await evaluator.log_event("Execution Initialized", "Agent loop warming up for primary task execution.") - - # Track cumulative metrics for this entire execution run (across all rework rounds) - total_task_input_tokens = 0 - total_task_output_tokens = 0 - total_task_tool_counts = {} - - # --- AWAIT INITIAL SETUP (Sync Point) --- - if rubric_task: - # We wait for the first round's rubric to ensure the node is ready for the test/UI - # Subsequent rounds use the already-captured rubric_content - rubric_content = await rubric_task - if not rubric_content: - rubric_content = "# Evaluation Rubric\nComplete the requested task with high technical accuracy." - # Reset task to None so we don't await it again in the loop - rubric_task = None - - # --- MAIN REWORK LOOP --- - loop_start = time.time() # Handle scope for exception reporting - while current_attempt <= max_rework_attempts: - # Refresh instance for loop state - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if not instance: break - - round_sub_events = [] - try: - registry = getattr(services.rag_service, "node_registry_service", None) - round_tool_counts = {} - round_input_tokens = 0 - round_output_tokens = 0 - final_answer = "" - last_assistant_msg_id = None - - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - instance.last_reasoning = "" - instance.status = "active" - instance.evaluation_status = f"๐Ÿค– Main Agent (Rd {current_attempt + 1}): Executing..." 
- if not safe_commit(): return - - execution_start = time.time() - - if registry and instance.mesh_node_id: - registry.emit(instance.mesh_node_id, "status_update", {"evaluation_status": instance.evaluation_status}) - - # Buffers for real-time streaming to avoid O(N^2) regex and DB hammering - content_buffer = "" - last_db_sync_time = time.time() - sync_token_count = 0 - - async for event in services.rag_service.chat_with_rag( - db=db, - session_id=session_id, - prompt=current_prompt, - provider_name=provider_name, - load_faiss_retriever=False, - user_service=user_service - ): - if event.get("type") == "finish": - last_assistant_msg_id = event.get("message_id") - round_tool_counts = event.get("tool_counts", {}) - # Skip input_tokens/output_tokens from finish if we already got them from token_counted events - # or if they are just duplicates of what we already accumulated. - if round_input_tokens == 0: - round_input_tokens = event.get("input_tokens", 0) - if round_output_tokens == 0: - round_output_tokens = event.get("output_tokens", 0) - final_answer = event.get("full_answer", "") - elif event.get("type") == "token_counted": - usage = event.get("usage", {}) - round_input_tokens += usage.get("prompt_tokens", 0) - round_output_tokens += usage.get("completion_tokens", 0) - elif event.get("type") in ("reasoning", "content"): - new_content = event.get("content", "") - if event.get("type") == "content": - final_answer += new_content - content_buffer += new_content - sync_token_count += 1 - - now = time.time() - if now - last_db_sync_time > 2.0 or sync_token_count >= 50: - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if instance: - instance.last_reasoning = (instance.last_reasoning or "") + content_buffer - content_buffer = "" - last_db_sync_time = now - sync_token_count = 0 - if not safe_commit(): return - else: - # Agent deleted, stop streaming - return - - try: - if registry and instance.mesh_node_id: - 
registry.emit(instance.mesh_node_id, "reasoning", { - "content": new_content, - "agent_id": agent_id, - "session_id": instance.session_id - }) - except ObjectDeletedError: - logger.info(f"Agent {agent_id} was deleted during execution. Stopping loop.") - return - - # Final flush - if content_buffer: - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if instance: - instance.last_reasoning = (instance.last_reasoning or "") + content_buffer - if not safe_commit(): return - content_buffer = "" - - # --- Persistence: Update Cumulative Metrics in DB (Real-time) --- - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if instance: - instance.total_input_tokens = (instance.total_input_tokens or 0) + round_input_tokens - instance.total_output_tokens = (instance.total_output_tokens or 0) + round_output_tokens - - # Merge tool counts - if round_tool_counts: - current_counts = (instance.tool_call_counts or {}).copy() - for tool, counts in round_tool_counts.items(): - if tool not in current_counts: - current_counts[tool] = {"calls": 0, "successes": 0, "failures": 0} - - # Handle both dict and legacy int formats - c_inc = counts.get("calls", counts) if isinstance(counts, dict) else counts - s_inc = counts.get("successes", counts) if isinstance(counts, dict) else counts - f_inc = counts.get("failures", 0) if isinstance(counts, dict) else 0 - - # Handle existing int in DB - if isinstance(current_counts[tool], int): - current_counts[tool] = {"calls": current_counts[tool], "successes": current_counts[tool], "failures": 0} - - current_counts[tool]["calls"] += c_inc - current_counts[tool]["successes"] += s_inc - current_counts[tool]["failures"] += f_inc - - instance.tool_call_counts = current_counts - from sqlalchemy.orm.attributes import flag_modified - flag_modified(instance, "tool_call_counts") - - if not safe_commit(): return - - # Accumulate round metrics into task local totals (for final logging/trace if needed) - 
total_task_input_tokens += round_input_tokens - total_task_output_tokens += round_output_tokens - # (total_task_tool_counts merging logic removed here as we update DB directly) - - exec_duration = time.time() - execution_start - round_sub_events.append({"name": "Agent execution", "duration": round(exec_duration, 2), "timestamp": time.time()}) - - # Execution complete - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - # 4.3: Post-processing to compress boilerplate from reasoning - final_reasoning = AgentExecutor._compress_reasoning(instance.last_reasoning or "") - - final_result = { - "response": final_answer, - "reasoning": final_reasoning - } - - # --- EVALUATION PHASE (Co-Worker Loop) --- - if evaluator and final_answer: - # Await parallel rubric task if it exists and hasn't been captured yet - if rubric_task and not rubric_content: - instance.evaluation_status = "๐Ÿ“‹ Co-Worker: Finalizing parallel rubric.md..." - if not safe_commit(): return - rubric_content = await rubric_task - if not rubric_content: - rubric_content = "# Evaluation Rubric\nComplete the requested task with high technical accuracy." - - instance.evaluation_status = "evaluating" - if not safe_commit(): return - - # Initial status write to feedback.md so it's not "Session Started" - evaluator.assistant.write( - evaluator.mesh_node_id, - ".cortex/feedback.md", - f"# ๐Ÿ•ต๏ธ Co-Worker Audit (Attempt {current_attempt + 1})\n\nAudit initiated. Reviewing technical accuracy and alignment...", - session_id=evaluator.sync_workspace_id - ) - - # Stage 2A: Blind Rating - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - instance.evaluation_status = f"๐Ÿ•ต๏ธ Co-Worker (Rd {current_attempt + 1}): Auditing result against criteria..." 
- if not safe_commit(): return - - if registry and instance.mesh_node_id: - registry.emit(instance.mesh_node_id, "status_update", {"evaluation_status": instance.evaluation_status}) - - # transparency context for Auditor - available_tools = [] - if services and hasattr(services, "tool_service"): - try: - # We pass features="chat" or relevant filter? - # For the auditor, we want to know what the agent *could* have called. - tools = services.tool_service.get_available_tools(db, instance.user_id, session_id=session_id) - available_tools = [t["function"]["name"] for t in tools] - except Exception as te: - logger.warning(f"Auditor failed to fetch tool list: {te}") - - partner_ctx = { - "system_prompt": template.system_prompt_content, - "skills": available_tools - } - - blind_eval = await evaluator.evaluate_blind(prompt, rubric_content, final_answer, partner_context=partner_ctx) - score = blind_eval.get("score", 0) - justification = blind_eval.get("justification", "") - blind_duration = blind_eval.get("duration", 0) - round_sub_events.append({"name": "Co-Worker review", "duration": round(blind_duration, 2), "timestamp": time.time()}) - - # Update instance with latest score - db.query(AgentInstance).filter(AgentInstance.id == agent_id).update({"latest_quality_score": score}) - if not safe_commit(): return - - # Check Threshold - if score >= rework_threshold: - instance.evaluation_status = f"โœ… PASSED (Score {score}%)" - if not safe_commit(): return - - if registry and instance.mesh_node_id: - registry.emit(instance.mesh_node_id, "status_update", {"evaluation_status": instance.evaluation_status}) - - # Log final success feedback to workspace even if no rework was needed - success_feedback = f"# Evaluation Passed\n\n**Score**: {score}/100\n\n**Justification**:\n{justification}" - evaluator.assistant.write(evaluator.mesh_node_id, ".cortex/feedback.md", success_feedback, session_id=evaluator.sync_workspace_id) - - # M3: Aggregate total success duration and truncated summary 
for timeline - total_success_duration = sum(e.get("duration", 0) for e in round_sub_events) - summary_reason = justification.split('\n\n')[0] if '\n\n' in justification else justification - if len(summary_reason) > 250: summary_reason = summary_reason[:247] + "..." - - await evaluator.log_round(current_attempt + 1, score, summary_reason, "Final answer passed quality gate.", sub_events=round_sub_events, duration=total_success_duration) - - # PERSISTENCE: Save this audit to the message for historical drill-down - if last_assistant_msg_id: - current_history = [] - try: - cmd_res = evaluator.assistant.dispatch_single(evaluator.mesh_node_id, "cat .cortex/history.log", session_id=evaluator.sync_workspace_id) - current_history = json.loads(cmd_res.get("stdout", "[]")) - except: pass - - db.query(Message).filter(Message.id == last_assistant_msg_id).update({ - "message_metadata": { - "evaluation": { - "rubric": rubric_content, - "feedback": success_feedback, - "history": current_history, - "score": score, - "passed": True - } - } - }) - if not safe_commit(): return - - break # Success! - - # Check Rework Limits - if current_attempt >= max_rework_attempts: - instance.evaluation_status = "failed_limit" - instance.last_error = f"Co-Worker Gate: Quality fell below {rework_threshold}% after {max_rework_attempts} attempts." - if not safe_commit(): return - - # M3: Aggregate total failure duration and truncated summary for timeline - total_fail_duration = sum(e.get("duration", 0) for e in round_sub_events) - summary_reason = justification.split('\n\n')[0] if '\n\n' in justification else justification - if len(summary_reason) > 250: summary_reason = summary_reason[:247] + "..." 
- - await evaluator.log_round(current_attempt + 1, score, summary_reason, "Failed quality gate after max attempts.", sub_events=round_sub_events, duration=total_fail_duration) - - # PERSISTENCE: Save this status to the message for historical drill-down - if last_assistant_msg_id: - current_history = [] - try: - cmd_res = evaluator.assistant.dispatch_single(evaluator.mesh_node_id, "cat .cortex/history.log", session_id=evaluator.sync_workspace_id) - current_history = json.loads(cmd_res.get("stdout", "[]")) - except: pass - - db.query(Message).filter(Message.id == last_assistant_msg_id).update({ - "message_metadata": { - "evaluation": { - "rubric": rubric_content, - "feedback": f"# Evaluation Failed (Max Attempts)\n\n**Score**: {score}/100\n\n**Justification**:\n{justification}", - "history": current_history, - "score": score, - "passed": False - } - } - }) - if not safe_commit(): return - - break # No more reworks - - # Stage Delta (Gap Analysis) - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - instance.evaluation_status = f"๐Ÿง  Co-Worker (Rd {current_attempt + 1}): Analyzing reasoning delta..." 
- if not safe_commit(): return - - if registry and instance.mesh_node_id: - registry.emit(instance.mesh_node_id, "status_update", {"evaluation_status": instance.evaluation_status}) - - # Fetch history for context - cmd_res = evaluator.assistant.dispatch_single(instance.mesh_node_id, "cat .cortex/history.log", session_id=evaluator.sync_workspace_id) - hist_log = [] - try: hist_log = json.loads(cmd_res.get("stdout", "[]")) - except: pass - - delta_start = time.time() - if current_attempt == 2: - directive_feedback = await evaluator.generate_compaction_summary(prompt, hist_log) - round_sub_events.append({"name": "Context compaction", "duration": round(time.time() - delta_start, 2), "timestamp": time.time()}) - else: - directive_feedback = await evaluator.evaluate_delta(prompt, rubric_content, justification, hist_log, final_reasoning, partner_context=partner_ctx) - round_sub_events.append({"name": "Delta analysis", "duration": round(time.time() - delta_start, 2), "timestamp": time.time()}) - - # M3: Categorization & Duration Metrics - full_audit_stream = f"# Co-Worker Review (Attempt {current_attempt + 1})\n\n**Justification**:\n{justification}\n\n---\n\n{directive_feedback}" - evaluator.assistant.write(evaluator.mesh_node_id, ".cortex/feedback.md", full_audit_stream, session_id=evaluator.sync_workspace_id) - - # Extract high-density summary for timeline - summary_reason = justification.split('\n\n')[0] if '\n\n' in justification else justification - if len(summary_reason) > 250: - summary_reason = summary_reason[:247] + "..." 
- - # Calculate total round duration - total_round_duration = sum(e.get("duration", 0) for e in round_sub_events) - - # Log this round with summary and duration - await evaluator.log_round(current_attempt + 1, score, summary_reason, directive_feedback, sub_events=round_sub_events, duration=total_round_duration) - - # PERSISTENCE: Save this audit to the message for historical drill-down - if last_assistant_msg_id: - current_history = [] - try: - cmd_res = evaluator.assistant.dispatch_single(evaluator.mesh_node_id, "cat .cortex/history.log", session_id=evaluator.sync_workspace_id) - current_history = json.loads(cmd_res.get("stdout", "[]")) - except: pass - - db.query(Message).filter(Message.id == last_assistant_msg_id).update({ - "message_metadata": { - "evaluation": { - "rubric": rubric_content, - "feedback": full_audit_stream, - "history": current_history, - "score": score, - "passed": False - } - } - }) - if not safe_commit(): return - - # Trigger next iteration - current_prompt = f"### CO-WORKER DIRECTIVE (ATTEMPT {current_attempt + 1})\n\n{directive_feedback}\n\nProceed with rework." - current_attempt += 1 - - db.add(Message(session_id=session_id, sender="system", content=f"โš ๏ธ **Co-Worker**: Quality check FAILED (Score: {score}/100). 
Requesting rework...")) - db.query(AgentInstance).filter(AgentInstance.id == agent_id).update({"evaluation_status": f"โš ๏ธ Rework Triggered ({score}%)"}) - if not safe_commit(): return - continue # Start next loop iteration - else: - break # No co-worker or no answer - - except Exception as e: - import traceback - print(f"[AgentExecutor] RAG attempt failed for {agent_id}: {e}") - print(traceback.format_exc()) - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if instance: - instance.status = "error_suspended" - instance.last_error = str(e) - # Even on error, try to sync tokens used so far - instance.total_input_tokens = (instance.total_input_tokens or 0) + total_task_input_tokens - instance.total_output_tokens = (instance.total_output_tokens or 0) + total_task_output_tokens - if not safe_commit(): return - return { - "status": "error", - "response": f"Execution failed: {str(e)}", - "reasoning": instance.last_reasoning if instance else "" - } - - # Final loop cleanup & Stats - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if instance: - # Update metrics regardless of final status (as long as we finished the loop) - elapsed = int(time.time() - loop_start) - instance.total_running_time_seconds = (instance.total_running_time_seconds or 0) + elapsed - - # Success calculation - final_score = getattr(instance, 'latest_quality_score', 0) or 0 - threshold = rework_threshold or 80 - - if instance.status == "active": - instance.status = "idle" - # Only increment successful runs if we didn't end in an error state and passed threshold (or were unchecked) - if final_score >= threshold or not co_worker_enabled: - instance.successful_runs = (instance.successful_runs or 0) + 1 - - # Clear reasoning as the task is now complete - instance.last_reasoning = None - if not safe_commit(): return - - if evaluator: - total_elapsed = time.time() - loop_start - await evaluator.log_event("Process Completed", f"Lifecycle 
finished successfully after {current_attempt + 1} rounds.", duration=total_elapsed) - + # Phase 3: Finalization + await self._finalize_execution() return final_result except Exception as e: - import traceback - print(f"[AgentExecutor] Unhandled loop error: {e}") - print(traceback.format_exc()) - instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id).first() - if instance: - instance.status = "error_suspended" - instance.last_error = f"Unhandled loop error: {str(e)}" - if not safe_commit(): return - return { - "status": "error", - "response": "Internal server error during execution.", - "reasoning": instance.last_reasoning if instance else "" - } + logger.error(f"[AgentExecutor] Unhandled loop error: {e}\n{traceback.format_exc()}") + return await self._handle_fatal_error(e) finally: heartbeat_task.cancel() - db.close() + + def _safe_commit(self) -> bool: + """Commits current DB changes with robust error handling for concurrent deletions.""" + try: + self.db.commit() + return True + except (ObjectDeletedError, StaleDataError): + logger.info(f"[AgentExecutor] Agent {self.agent_id} deleted or modified externally. Exiting.") + self.db.rollback() + return False + except Exception as e: + logger.error(f"[AgentExecutor] Commit failed: {e}") + self.db.rollback() + raise + + async def _initialize_instance(self, prompt: str) -> bool: + """Loads and prepares the agent instance for execution.""" + self.instance = self.db.query(AgentInstance).filter(AgentInstance.id == self.agent_id).first() + if not self.instance or not prompt: + return False + + self.template = self.db.query(AgentTemplate).filter(AgentTemplate.id == self.instance.template_id).first() + if not self.template: + self.instance.status = "error_suspended" + self.instance.last_error = f"Template '{self.instance.template_id}' not found." 
+ self._safe_commit() + return False + + # Initialize base metrics and status + self.instance.last_heartbeat = datetime.utcnow() + self.instance.status = "active" + self.instance.last_error = None + self.instance.total_runs = (self.instance.total_runs or 0) + 1 + + # Area 4.2: Optional history clearing + session = self.instance.session + if session and getattr(session, "auto_clear_history", False): + self.db.query(Message).filter(Message.session_id == session.id).delete(synchronize_session=False) + + return self._safe_commit() + + async def _heartbeat_loop(self): + """Maintains the 'active' lease in the background.""" + while True: + await asyncio.sleep(60) + try: + # We use a nested session to avoid interfering with the main thread's flushing + async with asyncio.Lock(): # Simple guard + inner_db = SessionLocal() + try: + obj = inner_db.query(AgentInstance).filter(AgentInstance.id == self.agent_id).first() + if not obj or obj.status not in ["active", "starting"]: + break + obj.last_heartbeat = datetime.utcnow() + inner_db.commit() + finally: + inner_db.close() + except Exception: + break + + async def _setup_evaluation(self, prompt: str, skip_coworker: bool) -> Optional[asyncio.Task]: + """Configures the Co-Worker Auditor and launches rubric generation if enabled.""" + if skip_coworker or not getattr(self.template, "co_worker_quality_gate", False): + return None + + # Resolve provider for evaluation + provider = self._resolve_eval_provider() + workspace_id = self.instance.session.sync_workspace_id if self.instance.session else str(self.instance.session_id) + + self.evaluator = HarnessEvaluator(self.db, self.agent_id, self.instance.mesh_node_id, workspace_id, provider, self.services) + + # Update UI status + self.instance.status = "starting" + self.instance.evaluation_status = "๐Ÿ“‹ Co-Worker: Initiating parallel rubric & mission setup..." 
+ self.instance.current_rework_attempt = 0 + self._safe_commit() + + # Parallel rubric generation task + return asyncio.create_task(self._rubric_generator_bg(prompt, provider, workspace_id)) + + def _resolve_eval_provider(self): + """Determines the LLM provider configuration for the Auditor.""" + from app.config import settings + session = self.instance.session + provider_name = getattr(session, "provider_name", None) or settings.ACTIVE_LLM_PROVIDER + + base_key = provider_name.split("/")[0] if "/" in provider_name else provider_name + llm_prefs = {} + if session and session.user and session.user.preferences: + llm_prefs = session.user.preferences.get("llm", {}).get("providers", {}).get(base_key, {}) + + if (not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key"))) and self.user_service: + system_prefs = self.user_service.get_system_settings(self.db) + system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(base_key, {}) + if not system_provider_prefs or not system_provider_prefs.get("model"): + active_key = system_prefs.get("llm", {}).get("active_provider") + if active_key: + system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(active_key, {}) + provider_name = active_key + + if system_provider_prefs: + merged = system_provider_prefs.copy() + if llm_prefs: merged.update({k: v for k, v in llm_prefs.items() if v}) + llm_prefs = merged + + return get_llm_provider( + provider_name, + model_name="" if "/" in provider_name else llm_prefs.get("model", ""), + api_key_override=llm_prefs.get("api_key"), + **{k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} + ) + + async def _rubric_generator_bg(self, prompt: str, provider, workspace_id) -> Optional[str]: + """Background helper for non-blocking rubric generation.""" + bg_db = SessionLocal() + try: + bg_evaluator = HarnessEvaluator(bg_db, self.agent_id, self.instance.mesh_node_id, workspace_id, provider, self.services) + await 
bg_evaluator.initialize_cortex() + return await bg_evaluator.generate_rubric(prompt) + except Exception as e: + logger.error(f"[AgentExecutor] Rubric generation failed: {e}") + return None + finally: + bg_db.close() + + async def _run_rework_loop(self, prompt: str, rubric_task: Optional[asyncio.Task]) -> Dict[str, Any]: + """Core loop handling primary execution and Auditor-driven reworks.""" + max_reworks = getattr(self.template, "max_rework_attempts", 3) + rework_threshold = getattr(self.template, "rework_threshold", 80) + + current_prompt = prompt + current_attempt = 0 + rubric_content = "" + loop_start_time = time.time() + + while current_attempt <= max_reworks: + round_metrics = {"start": time.time(), "sub_events": [], "input_tokens": 0, "output_tokens": 0, "tool_counts": {}} + + try: + # 1. Main Agent Execution (Streaming) + result = await self._execute_main_agent(current_prompt, current_attempt, round_metrics) + if not result: break + + # 2. Evaluation Phase + if self.evaluator and result.get("response"): + # Ensure rubric is ready + if rubric_task and not rubric_content: + rubric_content = await rubric_task or "# Evaluation Rubric\nTask completion required." + rubric_task = None + + passed, audit_summary = await self._perform_quality_audit( + current_prompt, rubric_content, result, current_attempt, rework_threshold, round_metrics + ) + + if passed: break + + # Plan Rework + current_prompt = f"### CO-WORKER DIRECTIVE (ATTEMPT {current_attempt + 1})\n\n{audit_summary}\n\nProceed with rework." 
+ current_attempt += 1 + else: + break # No auditor or no response to evaluate + + except Exception as e: + logger.error(f"[AgentExecutor] Round {current_attempt} failed: {e}\n{traceback.format_exc()}") + return await self._handle_round_error(round_metrics, e) + + # If we exited the loop without a break (pass), we hit the limit + score = getattr(self.instance, 'latest_quality_score', 0) or 0 + threshold = getattr(self.template, 'rework_threshold', 80) or 80 + if score < threshold and getattr(self.template, "co_worker_quality_gate", False): + self.instance.evaluation_status = f"๐Ÿšซ failed_limit ({score}%)" + self._safe_commit() + + return result + + async def _execute_main_agent(self, prompt: str, attempt: int, metrics: Dict) -> Optional[Dict]: + """Orchestrates the LLM stream for the primary task.""" + self.instance.last_reasoning = "" + self.instance.status = "active" + self.instance.evaluation_status = f"๐Ÿค– Main Agent (Rd {attempt + 1}): Executing..." + if not self._safe_commit(): return None + + final_answer = "" + last_msg_id = None + registry = getattr(self.services.rag_service, "node_registry_service", None) + + # Buffer management for performance + sync_buffer = "" + last_sync = time.time() + + async for event in self.services.rag_service.chat_with_rag( + db=self.db, + session_id=self.instance.session_id, + prompt=prompt, + provider_name=self.instance.session.provider_name if self.instance.session else None, + load_faiss_retriever=False, + user_service=self.user_service + ): + e_type = event.get("type") + if e_type == "finish": + last_msg_id = event.get("message_id") + metrics["tool_counts"] = event.get("tool_counts", {}) + if metrics["input_tokens"] == 0: metrics["input_tokens"] = event.get("input_tokens", 0) + if metrics["output_tokens"] == 0: metrics["output_tokens"] = event.get("output_tokens", 0) + final_answer = event.get("full_answer", "") + elif e_type == "token_counted": + usage = event.get("usage", {}) + metrics["input_tokens"] += 
usage.get("prompt_tokens", 0) + metrics["output_tokens"] += usage.get("completion_tokens", 0) + elif e_type in ("reasoning", "content"): + content = event.get("content", "") + if e_type == "content": final_answer += content + + sync_buffer += content + if time.time() - last_sync > 2.0 or len(sync_buffer) > 200: + self.instance.last_reasoning = (self.instance.last_reasoning or "") + sync_buffer + sync_buffer = "" + last_sync = time.time() + if not self._safe_commit(): return None + + if registry and self.instance.mesh_node_id: + registry.emit(self.instance.mesh_node_id, "reasoning", {"content": content, "agent_id": self.agent_id, "session_id": self.instance.session_id}) + + # Final persistence refresh + self.instance.last_reasoning = (self.instance.last_reasoning or "") + sync_buffer + self.instance.total_input_tokens = (self.instance.total_input_tokens or 0) + metrics["input_tokens"] + self.instance.total_output_tokens = (self.instance.total_output_tokens or 0) + metrics["output_tokens"] + self._merge_tool_counts(metrics["tool_counts"]) + self._safe_commit() + + metrics["duration"] = time.time() - metrics["start"] + metrics["sub_events"].append({"name": "Agent execution", "duration": round(metrics["duration"], 2), "timestamp": time.time()}) + metrics["last_msg_id"] = last_msg_id + + return { + "response": final_answer, + "reasoning": self._compress_reasoning(self.instance.last_reasoning or "") + } + + async def _perform_quality_audit(self, prompt, rubric, result, attempt, threshold, metrics) -> (bool, str): + """Runs the Auditor's blind rating and delta analysis.""" + self.instance.evaluation_status = f"๐Ÿ•ต๏ธ Co-Worker (Rd {attempt + 1}): Auditing result..." 
+ self._safe_commit() + + # Auditor technical context + available_tools = [] + if hasattr(self.services, "tool_service"): + tools = self.services.tool_service.get_available_tools(self.db, self.instance.user_id, session_id=self.instance.session_id) + available_tools = [t["function"]["name"] for t in tools] + + partner_ctx = {"system_prompt": self.template.system_prompt_content, "skills": available_tools} + + # 1. Blind Rating + blind_eval = await self.evaluator.evaluate_blind(prompt, rubric, result["response"], partner_context=partner_ctx) + score = blind_eval.get("score", 0) + just_msg = blind_eval.get("justification", "") + metrics["sub_events"].append({"name": "Co-Worker review", "duration": round(blind_eval.get("duration", 0), 2), "timestamp": time.time()}) + + self.instance.latest_quality_score = score + self._safe_commit() + + if score >= threshold: + await self._record_audit_passed(score, just_msg, rubric, metrics, attempt) + return True, just_msg + + # 2. Delta Analysis (Directive for rework) + self.instance.evaluation_status = f"๐Ÿง  Co-Worker (Rd {attempt + 1}): Analyzing delta..." 
+ self._safe_commit() + + # Fetch history for context + hist_log = self._fetch_tester_history() + + delta_start = time.time() + if attempt >= 2: # Compaction trigger + directive = await self.evaluator.generate_compaction_summary(prompt, hist_log) + metrics["sub_events"].append({"name": "Context compaction", "duration": round(time.time() - delta_start, 2), "timestamp": time.time()}) + else: + directive = await self.evaluator.evaluate_delta(prompt, rubric, just_msg, hist_log, result["reasoning"], partner_context=partner_ctx) + metrics["sub_events"].append({"name": "Delta analysis", "duration": round(time.time() - delta_start, 2), "timestamp": time.time()}) + + await self._record_audit_failed(score, just_msg, directive, rubric, metrics, attempt, hist_log) + return False, directive + + def _fetch_tester_history(self) -> List: + """Retrieves raw history from the Auditor's workspace logs.""" + try: + res = self.evaluator.assistant.dispatch_single(self.instance.mesh_node_id, "cat .cortex/history.log", session_id=self.evaluator.sync_workspace_id) + return json.loads(res.get("stdout", "[]")) + except: return [] + + async def _record_audit_passed(self, score, justification, rubric, metrics, attempt): + """Records a successful quality gate pass in history and DB.""" + self.instance.evaluation_status = f"โœ… PASSED (Score {score}%)" + self._safe_commit() + + feedback = f"# Evaluation Passed\n\n**Score**: {score}/100\n\n**Justification**:\n{justification}" + self.evaluator.assistant.write(self.instance.mesh_node_id, ".cortex/feedback.md", feedback, session_id=self.evaluator.sync_workspace_id) + + duration = sum(e.get("duration", 0) for e in metrics["sub_events"]) + summary = self._truncate_text(justification, 250) + await self.evaluator.log_round(attempt + 1, score, summary, "Final answer passed quality gate.", sub_events=metrics["sub_events"], duration=duration) + + self._update_message_metadata(metrics.get("last_msg_id"), rubric, feedback, score, passed=True) + + async def 
_record_audit_failed(self, score, justification, directive, rubric, metrics, attempt, history): + """Records a quality gate failure and triggers rework protocol.""" + full_audit = f"# Co-Worker Review (Attempt {attempt + 1})\n\n**Justification**:\n{justification}\n\n---\n\n{directive}" + self.evaluator.assistant.write(self.instance.mesh_node_id, ".cortex/feedback.md", full_audit, session_id=self.evaluator.sync_workspace_id) + + duration = sum(e.get("duration", 0) for e in metrics["sub_events"]) + summary = self._truncate_text(justification, 250) + await self.evaluator.log_round(attempt + 1, score, summary, directive, sub_events=metrics["sub_events"], duration=duration) + + self.db.add(Message(session_id=self.instance.session_id, sender="system", content=f"โš ๏ธ **Co-Worker**: Quality check FAILED ({score}/100). Requesting rework...")) + self._update_message_metadata(metrics.get("last_msg_id"), rubric, full_audit, score, passed=False, history=history) + + self.instance.evaluation_status = f"โš ๏ธ Rework Triggered ({score}%)" + self._safe_commit() + + def _update_message_metadata(self, msg_id, rubric, feedback, score, passed, history=None): + """Enriches the assistant message with deep evaluation metadata.""" + if not msg_id: return + self.db.query(Message).filter(Message.id == msg_id).update({ + "message_metadata": { + "evaluation": { + "rubric": rubric, "feedback": feedback, "history": history or [], + "score": score, "passed": passed + } + } + }) + self._safe_commit() + + def _merge_tool_counts(self, round_counts: Dict): + """Merges round-level tool metrics into the global agent stats.""" + if not round_counts: return + counts = (self.instance.tool_call_counts or {}).copy() + for tool, results in round_counts.items(): + if tool not in counts: counts[tool] = {"calls": 0, "successes": 0, "failures": 0} + if isinstance(counts[tool], int): counts[tool] = {"calls": counts[tool], "successes": counts[tool], "failures": 0} + + c_inc = results.get("calls", results) if 
isinstance(results, dict) else results + s_inc = results.get("successes", results) if isinstance(results, dict) else results + f_inc = results.get("failures", 0) if isinstance(results, dict) else 0 + + counts[tool]["calls"] += c_inc + counts[tool]["successes"] += s_inc + counts[tool]["failures"] += f_inc + + self.instance.tool_call_counts = counts + from sqlalchemy.orm.attributes import flag_modified + flag_modified(self.instance, "tool_call_counts") + + async def _finalize_execution(self): + """Performs final state cleanup, metric aggregation, and logging.""" + if self.instance.status == "active": + self.instance.status = "idle" + score = getattr(self.instance, 'latest_quality_score', 0) or 0 + threshold = getattr(self.template, 'rework_threshold', 80) or 80 + if score >= threshold or not getattr(self.template, "co_worker_quality_gate", False): + self.instance.successful_runs = (self.instance.successful_runs or 0) + 1 + + self.instance.last_reasoning = None + self._safe_commit() + + if self.evaluator: + await self.evaluator.log_event("Process Completed", "Lifecycle finished successfully.") + + async def _handle_round_error(self, metrics, error): + """Gracefully handles errors within a specific rework round.""" + self.instance.status = "error_suspended" + self.instance.last_error = str(error) + self.instance.total_input_tokens = (self.instance.total_input_tokens or 0) + metrics["input_tokens"] + self.instance.total_output_tokens = (self.instance.total_output_tokens or 0) + metrics["output_tokens"] + self._safe_commit() + return {"status": "error", "response": f"Execution failed: {str(error)}", "reasoning": ""} + + async def _handle_fatal_error(self, error): + """Gracefully handles unhandled exceptions across the entire lifecycle.""" + self.instance = self.db.query(AgentInstance).filter(AgentInstance.id == self.agent_id).first() + if self.instance: + self.instance.status = "error_suspended" + self.instance.last_error = f"Unhandled fatal error: {str(error)}" + 
self._safe_commit() + return {"status": "error", "response": "Internal server error during execution."} + + @staticmethod + def _truncate_text(text: str, length: int) -> str: + """Splits text by double newline to get first paragraph or truncates to length.""" + summary = text.split('\n\n')[0] if '\n\n' in text else text + return f"{summary[:length-3]}..." if len(summary) > length else summary @staticmethod def _compress_reasoning(text: str) -> str: - """Deduplicates turn markers and collapses boilerplate using high-perf string logic.""" + """Deduplicates turn markers and collapses boilerplate using centralized patterns.""" if not text: return "" lines = text.splitlines(keepends=True) - cleaned = [] - last_important_line = "" + cleaned, last_line = [], "" - # Collapse repeating Turn headers and Strategy boilerplate without regex for line in lines: l_strip = line.strip() if not l_strip: cleaned.append(line) continue - l_lower = l_strip.lower() - - # Deduplicate the system thinking marker - is_turn_marker = "๐Ÿ›ฐ๏ธ" in line and "[turn" in l_lower and "thinking" in l_lower - if is_turn_marker: - if "๐Ÿ›ฐ๏ธ" in last_important_line and "[turn" in last_important_line.lower(): - continue - - # Collapse Strategy boilerplate - is_strategy = "strategy:" in l_lower and "executing orchestrated tasks" in l_lower - if is_strategy: - if "strategy:" in last_important_line.lower() and "executing orchestrated tasks" in last_important_line.lower(): - continue - + # Use central patterns to detect boilerplate without inline regex + if TURN_THINKING_MARKER.search(line) and TURN_THINKING_MARKER.search(last_line): + continue + if STRATEGY_BOILERPLATE.search(line) and STRATEGY_BOILERPLATE.search(last_line): + continue + cleaned.append(line) - last_important_line = l_strip + last_line = l_strip return "".join(cleaned).strip() - diff --git a/ai-hub/app/core/orchestration/harness_evaluator.py b/ai-hub/app/core/orchestration/harness_evaluator.py index 5d3c3c2..7d2c8e3 100644 --- 
a/ai-hub/app/core/orchestration/harness_evaluator.py +++ b/ai-hub/app/core/orchestration/harness_evaluator.py @@ -1,16 +1,22 @@ import logging import json import time -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Tuple import os +import re from app.db.models.agent import AgentInstance, AgentTemplate from app.db import models from app.core.orchestration import Architect +from app.core._regex import ANSI_ESCAPE, FINAL_SCORE logger = logging.getLogger(__name__) class HarnessEvaluator: + """ + Quality gate auditor that evaluates agent performance against dynamic rubrics. + Refactored for structural clarity and modular evaluation flows. + """ def __init__(self, db, agent_id, mesh_node_id, sync_workspace_id, llm_provider, services): self.db = db self.agent_id = agent_id @@ -19,527 +25,141 @@ self.llm_provider = llm_provider self.services = services - # Resolve orchestrator assistant from services container self.orchestrator = getattr(services, "orchestrator", None) - self.assistant = None - if self.orchestrator: - self.assistant = self.orchestrator.assistant + self.assistant = self.orchestrator.assistant if self.orchestrator else None async def initialize_cortex(self): - """Creates .cortex/ directory and initializes history.log if missing.""" - if not self.assistant or not self.mesh_node_id: - logger.warning(f"[HarnessEvaluator] Assistant or mesh_node_id missing; skipping .cortex init.") - return - - # Ensure directory exists - self.assistant.dispatch_single( - self.mesh_node_id, - "mkdir -p .cortex", - session_id=self.sync_workspace_id - ) - - # Reset history for a fresh evaluation session - self.assistant.write( - self.mesh_node_id, - ".cortex/history.log", - "[]", - session_id=self.sync_workspace_id - ) - - # Reset feedback for the new run - self.assistant.write( - self.mesh_node_id, - ".cortex/feedback.md", - "# Session Started\n", - session_id=self.sync_workspace_id - ) + """Initializes the .cortex/ state directory on 
the mesh node.""" + if not self.assistant or not self.mesh_node_id: return + self.assistant.dispatch_single(self.mesh_node_id, "mkdir -p .cortex", session_id=self.sync_workspace_id) + self.assistant.write(self.mesh_node_id, ".cortex/history.log", "[]", session_id=self.sync_workspace_id) + self.assistant.write(self.mesh_node_id, ".cortex/feedback.md", "# Session Started\n", session_id=self.sync_workspace_id) async def log_event(self, name: str, details: str = "", duration: float = 0, event_type: str = "event", metadata: Dict = None): - """Records a generic event to the history log.""" + """Appends a lifecycle event to the history log.""" + history = self._read_history() + history.append({ + "type": event_type, "name": name, "details": details, + "duration": round(duration, 2), "timestamp": time.time(), "metadata": metadata or {} + }) + self._write_history(history) + + def _read_history(self) -> List[Dict]: + """Safely reads history.log from the node.""" + if not self.assistant: return [] + try: + res = self.assistant.dispatch_single(self.mesh_node_id, "cat .cortex/history.log", session_id=self.sync_workspace_id, timeout=5) + return json.loads(res.get("stdout", "[]")) if res.get("status") == "SUCCESS" else [] + except: return [] + + def _write_history(self, history: List[Dict]): + """Safely writes history.log back to the node.""" if not self.assistant: return try: - cmd_res = self.assistant.dispatch_single( - self.mesh_node_id, - "cat .cortex/history.log", - session_id=self.sync_workspace_id, - timeout=5 - ) - - history = [] - if cmd_res.get("status") == "SUCCESS": - try: - history = json.loads(cmd_res.get("stdout", "[]")) - except: - history = [] - - history.append({ - "type": event_type, - "name": name, - "details": details, - "duration": round(duration, 2), - "timestamp": time.time(), - "metadata": metadata or {} - }) - - self.assistant.write( - self.mesh_node_id, - ".cortex/history.log", - json.dumps(history, indent=2), - session_id=self.sync_workspace_id - ) + 
self.assistant.write(self.mesh_node_id, ".cortex/history.log", json.dumps(history, indent=2), session_id=self.sync_workspace_id) except Exception as e: - logger.error(f"[HarnessEvaluator] Event logging failed: {e}") + logger.error(f"[HarnessEvaluator] History write failed: {e}") async def ensure_coworker_ground_truth(self): - """ - SWARM ALIGNMENT: Ensures .coworker.md exists on the node. - If missing, it distills the main agent's instructions into a ground-truth document. - """ + """Bootstraps .coworker.md alignment doc if missing on node.""" if not self.assistant or not self.mesh_node_id: return - - # 1. Existence Check - check = self.assistant.dispatch_single( - self.mesh_node_id, - "ls .coworker.md", - session_id=self.sync_workspace_id, - timeout=5 - ) - if check.get("status") == "SUCCESS": - return # Already exists + check = self.assistant.dispatch_single(self.mesh_node_id, "ls .coworker.md", session_id=self.sync_workspace_id, timeout=5) + if check.get("status") == "SUCCESS": return - logger.info(f"[HarnessEvaluator] .coworker.md missing on {self.mesh_node_id}. Generating ground truth...") - - # 2. Source Material Discovery - source_instruction = "" + instr = self._get_agent_instructions() + sys_p = "You are the Swarm Alchemist. Distill instructions into a high-density .coworker.md project edict file." 
+ try: + prediction = await self.llm_provider.acompletion(messages=[{"role":"system","content":sys_p},{"role":"user","content":f"Instructions:\n{instr}"}], stream=False) + self.assistant.write(self.mesh_node_id, ".coworker.md", prediction.choices[0].message.content, session_id=self.sync_workspace_id) + except Exception as e: logger.error(f"Ground truth failed: {e}") + + def _get_agent_instructions(self) -> str: + """Resolves the primary system instructions for the agent instance.""" try: instance = self.db.query(AgentInstance).filter(AgentInstance.id == self.agent_id).first() - if instance and instance.template: - # Priority: session override -> template path - if instance.session and instance.session.system_prompt_override: - source_instruction = instance.session.system_prompt_override - else: - # If it's a file path, we'd need to read it. If it's the raw text, use it. - # For now, we use a fallback if it looks like a short slug. - raw = instance.template.system_prompt_path or "" - if len(raw) > 100: - source_instruction = raw - else: - from app.core.orchestration.profiles import DEFAULT_PROMPT_TEMPLATE - source_instruction = DEFAULT_PROMPT_TEMPLATE - except Exception as e: - logger.warning(f"[HarnessEvaluator] Failed to fetch agent instructions for ground truth: {e}") - from app.core.orchestration.profiles import DEFAULT_PROMPT_TEMPLATE - source_instruction = DEFAULT_PROMPT_TEMPLATE + if instance and instance.session and instance.session.system_prompt_override: + return instance.session.system_prompt_override + return instance.template.system_prompt_content if instance and instance.template else "Follow standard protocols." + except: return "No instructions found." - # 3. AI Distillation - system_prompt = """You are the Swarm Alchemist. -Your goal is to distill the provided "Agent System Instructions" into a concise, high-density ".coworker.md" file. -This file acts as the "Ground Truth" for other agents working in the same swarm. - -It MUST include: -1. 
**Core Edicts**: The non-negotiable rules of this project. -2. **Architecture**: Key technical constraints (e.g., pathing, tool usage, sync folders). -3. **Alignment**: How a 'perfectly aligned' implementation should look. - -Keep it under 1000 tokens. Format as Markdown.""" - - user_prompt = f"Agent Instructions:\n{source_instruction}\n\nGenerate the .coworker.md Aligned Knowledge Base now." - - try: - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - prediction = await self.llm_provider.acompletion(messages=messages, stream=False) - ground_truth = prediction.choices[0].message.content - - # 4. Persistence - self.assistant.write( - self.mesh_node_id, - ".coworker.md", - ground_truth, - session_id=self.sync_workspace_id - ) - logger.info(f"[HarnessEvaluator] Successfully bootstrapped .coworker.md for {self.agent_id}") - await self.log_event("Ground Truth Sync", ".coworker.md bootstrapped from Agent Template.") - except Exception as e: - logger.error(f"[HarnessEvaluator] Ground truth generation failed: {e}") - - async def generate_rubric(self, initial_prompt: str): - """Stage 1: Pre-Execution. 
Generate a task-specific rubric.md.""" + async def generate_rubric(self, initial_prompt: str) -> Optional[str]: + """Generates a task-specific evaluation rubric.""" if not self.assistant: return None - - # CLEANUP: Purge old rubric from the node to ensure a fresh start for the new request - try: - self.assistant.dispatch_single( - self.mesh_node_id, - "rm .cortex/rubric.md", - session_id=self.sync_workspace_id, - timeout=5 - ) - logger.debug(f"[HarnessEvaluator] Purged stale rubric on {self.mesh_node_id}") - except: - pass - - # BOOTSTRAP: Ensure Ground Truth / .coworker.md exists before generating rubric await self.ensure_coworker_ground_truth() - start = time.time() - - # --- File-Based Knowledge Discovery (Aligned with ClaudeCode) --- - coworker_context = "" + ctx = self._read_node_file(".coworker.md") + sys_p = f"You are a Quality Architect. Context:\n{ctx}\nGenerate a '# Evaluation Rubric' (0-100) for this task." try: - cmd_res = self.assistant.dispatch_single( - self.mesh_node_id, - "cat .coworker.md", - session_id=self.sync_workspace_id, - timeout=5 - ) - if cmd_res.get("status") == "SUCCESS": - coworker_context = f"\n\nPROJECT-SPECIFIC CONTEXT (from .coworker.md):\n{cmd_res.get('stdout', '')}" - except: - pass # Silently continue if cat fails - - system_prompt = f"""You are a Quality Control Architect for a live infrastructure swarm. -Your task is to analyze a user request and generate a specific Evaluation Rubric in Markdown. - -## Context Discovery (Architectural Constraints): -{coworker_context or "No specific .coworker.md found. Use general best practices for modern infrastructure and swarm orchestration."} - -The Rubric MUST include: -1. **Expectations**: A checklist of specific results the agent should satisfy for this specific task. -2. **Core Rubric**: A quantitative scoring guide (0-100) across these dimensions: - - **Quality**: Tone, structure, and readability. - - **Accuracy**: Completeness and technical correctness. 
- - **Efficiency (Non-AI Alike)**: Adherence to edicts. - -Format as a clean Markdown file with exactly one '# Evaluation Rubric' title.""" - - user_prompt = f"Target Prompt: \"{initial_prompt}\"\n\nConstruct the request-specific rubric.md now." - - try: - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - prediction = await self.llm_provider.acompletion(messages=messages, stream=False) - rubric_content = prediction.choices[0].message.content - - # Save to node - self.assistant.write( - self.mesh_node_id, - ".cortex/rubric.md", - rubric_content, - session_id=self.sync_workspace_id - ) - - await self.log_event("Rubric Generation", "Task-specific evaluation criteria established.", duration=time.time() - start) - return rubric_content + res = await self.llm_provider.acompletion(messages=[{"role":"system","content":sys_p},{"role":"user","content":f"Task: {initial_prompt}"}], stream=False) + content = res.choices[0].message.content + self.assistant.write(self.mesh_node_id, ".cortex/rubric.md", content, session_id=self.sync_workspace_id) + return content except Exception as e: - logger.error(f"[HarnessEvaluator] Rubric generation failed: {e}") + logger.error(f"Rubric fail: {e}") return None - async def evaluate_blind(self, initial_prompt: str, rubric_content: str, result_content: str, partner_context: Dict = None) -> Dict[str, Any]: - """Stage 2A: The Blind Rating (Absolute Objectivity). Uses tools to inspect results.""" - start = time.time() - - partner_info = "" - if partner_context: - partner_info = f""" -## PARTNER PROFILE (Main Agent context): -SYSTEM PROMPT: -{partner_context.get('system_prompt', 'N/A')} - -AVAILABLE SKILLS (Tools): -{partner_context.get('skills', 'N/A')} ---- -""" - - system_prompt = f"""You are the Co-Worker Evaluator (Blind Auditor). -Your goal is to perform a BLIND evaluation of the Main Agent's work. -You have NO knowledge of previous rounds or internal reasoning. 
You only see the goal and the result. - -Original Request: {initial_prompt} -{partner_info} - -Current Result: ---- -{result_content} ---- - -Rubric: -{rubric_content} - -EDICTS: -- Don't add features/refactors beyond what was asked. -- Don't add docstrings/comments not explicitly requested. -- Don't create helpers/utilities for one-time operations. - -MISSION: -1. Review the Current Result above. -2. If the task involves code or files, explore the workspace using your tools (ls, cat, etc.) to verify truth. -3. Assign a numerical score (0-100) and a brief justification. -Your final response MUST end with exactly: FINAL_SCORE: [number]""" - - res = await self._run_evaluator_agent(system_prompt, "Perform Blind Evaluation of the result state.") - res["duration"] = time.time() - start - return res - - async def evaluate_delta(self, initial_prompt: str, rubric_content: str, blind_justification: str, history_log: List[Dict[str, Any]], transcript: str, partner_context: Dict = None) -> str: - """Stage 2B: The Delta Analysis. Identifies gaps by comparing result to reasoning transcript.""" - start = time.time() - - historical_context = "Historical Rework Instructions (Gap Context):\n" - for entry in history_log: - if entry.get("type") == "attempt": - historical_context += f"- Attempt {entry['round']}: {entry.get('reason', 'N/A')}\n" - - system_prompt = f"""You are the Co-Worker Quality Architect (Delta Analyst). -The Blind Evaluator assigned a score based solely on the file result, but now we must bridge the gap. -You see the FULL mental transcript of how the Main Agent reached this state. 
- -Original Request: {initial_prompt} - -Rubric: -{rubric_content} - -Blind Evaluation Justification: -{blind_justification} - -{historical_context} - -Main Agent Execution Transcript: ---- -{transcript} ---- - -## PARTNER PROFILE (Main Agent context): -SYSTEM PROMPT: -{partner_context.get('system_prompt', 'N/A') if partner_context else 'N/A'} - -AVAILABLE SKILLS (Tools): -{partner_context.get('skills', 'N/A') if partner_context else 'N/A'} ---- - -MISSION: -1. Compare the Blind Justification with the Execution Transcript. -2. Identify why the Agent failed to meet the criteria despite its reasoning. -3. Identify the 'Delta' (what was intended vs. what was actually committed). -4. Spot 'Gold-Plating' (did they do extra work not asked for?). -5. Format feedback as ACTIONABLE COMMANDS (Directives). -6. If this is a repeat failure, provide a solution sketch or absolute directive. -Example: "Directive: Refactor auth.py:L10 to use settings.API_KEY instead of hardcoded string." - -Format as Markdown. 
Start with '# Rework Instructions'.""" - + def _read_node_file(self, path: str) -> str: + """Helper to cat a file from the node.""" try: - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": "Analyze the delta and generate the Directive-based rework feedback."} - ] - prediction = await self.llm_provider.acompletion(messages=messages, stream=False) - feedback = prediction.choices[0].message.content - - # Content as return-only to allow orchestrator to aggregate reports - return feedback - except Exception as e: - logger.error(f"[HarnessEvaluator] Delta analysis failed: {e}") - return f"Error during Delta Analysis: {str(e)}" + res = self.assistant.dispatch_single(self.mesh_node_id, f"cat {path}", session_id=self.sync_workspace_id, timeout=5) + return res.get("stdout", "") if res.get("status") == "SUCCESS" else "" + except: return "" + + async def evaluate_blind(self, prompt: str, rubric: str, result: str, partner_context: Dict = None) -> Dict[str, Any]: + """Performs an objective blind audit of the final response.""" + sys_p = f"You are the Blind Auditor. Request: {prompt}\nResult: {result}\nRubric: {rubric}\nRate 0-100. End with: FINAL_SCORE: [number]" + return await self._run_evaluator_agent(sys_p, "Audit the result state.") + + async def evaluate_delta(self, prompt: str, rubric: str, blind_just: str, history: List[Dict], transcript: str, partner_context: Dict = None) -> str: + """Analyzes the delta between reasoning and final output to generate rework directives.""" + hist_text = "\n".join([f"- {h['round']}: {h['reason']}" for h in history if h['type'] == 'attempt']) + sys_p = f"You are the Delta Architect. Request: {prompt}\nRubric: {rubric}\nAuditor: {blind_just}\nHistory: {hist_text}\nExecution: {transcript}\nGenerate '# Rework Instructions' as actionable directives." 
+ try: + res = await self.llm_provider.acompletion(messages=[{"role":"system","content":sys_p},{"role":"user","content":"Generate directives."}], stream=False) + return res.choices[0].message.content + except Exception as e: return f"Delta Error: {e}" async def _run_evaluator_agent(self, system_prompt: str, user_request: str) -> Dict[str, Any]: - """Utility to run a context-stripped Architect loop for verification.""" + """Runs a sub-architect to perform complex multi-tool audit logic.""" architect = Architect() - - # Resolve tools for the evaluator (same as parent session) - tool_service = getattr(self.services.rag_service, "tool_service", None) - tools = [] - user_id = "agent-system" - if tool_service: - instance = self.db.query(AgentInstance).filter(AgentInstance.id == self.agent_id).first() - if instance and instance.session: - user_id = instance.session.user_id - tools = tool_service.get_available_tools(self.db, user_id, feature="agent_harness", session_id=instance.session_id) - - # --- Global Blueprint Discovery (Hub Manifesto) --- - global_manifesto = "" - import os - manifesto_path = "/app/blueprints/manifesto.md" - if os.path.exists(manifesto_path): - try: - with open(manifesto_path, "r") as f: - global_manifesto = f"\n\nGLOBAL HUB MANIFESTO (Project Vision):\n{f.read()}" - except: - pass - - # --- Dynamic Knowledge Snapshot (Discovery Step) --- - dynamic_snapshot = "" - try: - # We take a 'Live Snapshot' of the node's status (similar to ClaudeCode's gitStatus) - snap_res = self.assistant.dispatch_single( - self.mesh_node_id, - "uname -a && uptime && df -h /", - session_id=self.sync_workspace_id, - timeout=5 - ) - if snap_res.get("status") == "SUCCESS": - dynamic_snapshot = f"\n\nLIVE SYSTEM SNAPSHOT (Discovery Step):\n{snap_res.get('stdout', '')}" - except: - pass - - # --- File-Based Knowledge Discovery --- - context_chunks = [] - # Inject Global Manifesto if found - if global_manifesto: - context_chunks.append({ - "id": "hub_manifesto", - "content": 
global_manifesto, - "metadata": {"source": "hub_blueprints", "priority": "critical"} - }) - - # Inject the live snapshot as a transient knowledge chunk - if dynamic_snapshot: - context_chunks.append({ - "id": "runtime_telemetry", - "content": dynamic_snapshot, - "metadata": {"source": "runtime_discovery", "priority": "high"} - }) - - try: - cmd_res = self.assistant.dispatch_single( - self.mesh_node_id, - "cat .coworker.md", - session_id=self.sync_workspace_id, - timeout=5 - ) - if cmd_res.get("status") == "SUCCESS": - context_chunks.append({ - "id": ".coworker.md", - "content": cmd_res.get("stdout", ""), - "metadata": {"source": "local_filesystem", "priority": "high"} - }) - except: - pass - final_answer = "" - score = 0 - final_answer = "" - # Run Architect with a strictly limited profile to ensure snappy evaluation - # We pass no history to ensure "Blind" context + # Gathering context chunks + chunks = [{"id": "hub_manifesto", "content": self._read_local_manifesto(), "metadata": {"priority":"critical"}}] + coworker = self._read_node_file(".coworker.md") + if coworker: chunks.append({"id": ".coworker.md", "content": coworker, "metadata": {"priority":"high"}}) + async for event in architect.run( - question=user_request, - context_chunks=context_chunks, - history=[], - llm_provider=self.llm_provider, - tool_service=tool_service, - tools=tools, - db=self.db, - user_id=user_id, - sync_workspace_id=self.sync_workspace_id, - session_id=None, # Evaluation shouldn't append to session Message table - feature_name="agent_harness", - session_override=system_prompt + question=user_request, context_chunks=chunks, history=[], llm_provider=self.llm_provider, + db=self.db, sync_workspace_id=self.sync_workspace_id, feature_name="agent_harness", session_override=system_prompt ): if event["type"] == "content": final_answer += event["content"] - # Stream to feedback.md for UI visibility during evaluation - if self.assistant: - self.assistant.write( - self.mesh_node_id, - 
".cortex/feedback.md", - f"# ๐Ÿ•ต๏ธ Co-Worker Audit in Progress...\n\n{final_answer}\n\n*(Analyzing results against rubric...)*", - session_id=self.sync_workspace_id - ) - elif event["type"] == "reasoning": - # Also include reasoning thoughts in the live feedback - thought = event["content"] - if self.assistant: - # Prepend reasoning if we want, or just append. - # Let's just use final_answer for now to keep it clean, - # but maybe add a header for the thoughts. - pass - elif event["type"] == "error": - logger.error(f"[HarnessEvaluator] Sub-evaluator fault: {event['content']}") + if self.assistant: self.assistant.write(self.mesh_node_id, ".cortex/feedback.md", f"# Audit in Progress...\n\n{final_answer}", session_id=self.sync_workspace_id) - import re - score = 0 - score_match = re.search(r"FINAL_SCORE:\s*(\d+)", final_answer) - if score_match: - try: score = int(score_match.group(1)) - except: score = 0 - - return { - "score": score, - "justification": final_answer - } + score_match = FINAL_SCORE.search(final_answer) + return {"score": int(score_match.group(1)) if score_match else 0, "justification": final_answer} - async def generate_compaction_summary(self, initial_prompt: str, history_log: List[Dict[str, Any]]) -> str: - """Micro-compaction strategy: Distills multiple failed rework attempts into a high-density directive.""" - if not self.llm_provider or not history_log: - return "COMPACTED DIRECTIVE: Resolve all remaining implementation gaps identified in previous rounds." - - failure_path = "" - for i, entry in enumerate(history_log): - if entry.get("type") == "attempt": - failure_path += f"Attempt {entry.get('round', i+1)} (Score: {entry.get('score', 0)}): {entry.get('reason', 'Unknown failure')}\n" - - system_prompt = """You are a Quality Control Compactor. -You have analyzed 2+ failed attempts to solve a task. -Your goal is to distill all previous critiques, failures, and delta-analysis reports into a SINGLE, high-density 'Compacted Directive'. 
-Remove all repetitive context, filler, and conversational markers. Focus only on the 'Critical Delta' that remains unfixed. - -Format exactly as: -# COMPACTED DIRECTIVE -[Dense listing of required fixes]""" - - user_prompt = f"Original Mission: {initial_prompt}\n\nFailure History:\n{failure_path}\n\nGenerate the Compacted Directive now." - + def _read_local_manifesto(self) -> str: + """Reads the global hub manifesto for architectural alignment.""" try: - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - prediction = await self.llm_provider.acompletion(messages=messages, stream=False) + with open("/app/blueprints/manifesto.md", "r") as f: return f.read() + except: return "" + + async def generate_compaction_summary(self, prompt: str, history: List[Dict]) -> str: + """Compacts multiple failure reports into a single dense rework directive.""" + fail_path = "\n".join([f"Attempt {h.get('round')}: {h.get('reason')}" for h in history if h['type'] == 'attempt']) + sys_p = "You are a Quality Compactor. Distill failures into a '# COMPACTED DIRECTIVE'." + try: + prediction = await self.llm_provider.acompletion(messages=[{"role":"system","content":sys_p},{"role":"user","content":f"Mission: {prompt}\nHistory:\n{fail_path}"}], stream=False) return prediction.choices[0].message.content - except Exception as e: - logger.error(f"[HarnessEvaluator] Compaction fault: {e}") - return "# COMPACTED DIRECTIVE\n- Critical implementation failure. Perform deep audit and stabilize core logic." + except: return "# COMPACTED DIRECTIVE\n- Stabilize core logic and address all remaining gaps." 
async def log_round(self, round_num: int, score: int, reason: str, feedback: str = "", sub_events: List[Dict] = None, duration: float = 0): - """Append-only record-keeping in history.log for a full attempt round.""" - if not self.assistant: return - try: - cmd_res = self.assistant.dispatch_single( - self.mesh_node_id, - "cat .cortex/history.log", - session_id=self.sync_workspace_id, - timeout=5 - ) - - history = [] - if cmd_res.get("status") == "SUCCESS": - try: - history = json.loads(cmd_res.get("stdout", "[]")) - except Exception as je: - logger.warning(f"[HarnessEvaluator] history.log corruption detected, resetting: {je}") - history = [] - - history.append({ - "type": "attempt", - "round": round_num, - "score": score, - "reason": reason, - "feedback": feedback, - "duration": round(duration, 2), - "timestamp": time.time(), - "sub_events": sub_events or [] - }) - - self.assistant.write( - self.mesh_node_id, - ".cortex/history.log", - json.dumps(history, indent=2), - session_id=self.sync_workspace_id - ) - except Exception as e: - logger.error(f"[HarnessEvaluator] Critical fault during log_round: {e}") + """Records a completed evaluation round in history.log.""" + history = self._read_history() + history.append({ + "type": "attempt", "round": round_num, "score": score, "reason": reason, "feedback": feedback, + "duration": round(duration, 2), "timestamp": time.time(), "sub_events": sub_events or [] + }) + self._write_history(history) diff --git a/ai-hub/app/core/services/agent.py b/ai-hub/app/core/services/agent.py new file mode 100644 index 0000000..84c3302 --- /dev/null +++ b/ai-hub/app/core/services/agent.py @@ -0,0 +1,232 @@ +import os +import uuid +import logging +import secrets +from typing import List, Optional +from sqlalchemy.orm import Session, joinedload +from fastapi import HTTPException +from app.api import schemas +from app.db import models +from app.db.models.agent import AgentTemplate, AgentInstance, AgentTrigger +from app.api.dependencies import 
ServiceContainer + +logger = logging.getLogger(__name__) + +class AgentService: + def __init__(self, services: ServiceContainer = None): + self.services = services + + def get_agent_instance(self, db: Session, agent_id: str, user_id: str) -> AgentInstance: + instance = db.query(AgentInstance).options( + joinedload(AgentInstance.template), + joinedload(AgentInstance.session) + ).filter( + AgentInstance.id == agent_id, + AgentInstance.user_id == user_id + ).first() + if not instance: + raise HTTPException(status_code=404, detail="Agent not found") + self.ensure_workspace_binding(db, instance) + return instance + + def list_user_agents(self, db: Session, user_id: str) -> List[AgentInstance]: + agents = db.query(AgentInstance).options( + joinedload(AgentInstance.template), + joinedload(AgentInstance.session) + ).filter( + AgentInstance.user_id == user_id + ).all() + + changed = False + for instance in agents: + if self.ensure_workspace_binding(db, instance): + changed = True + + if changed: + db.commit() + return agents + + def ensure_workspace_binding(self, db: Session, instance: AgentInstance) -> bool: + if not instance or not instance.session: + return False + + workspace_id = self._derive_workspace_id(instance.current_workspace_jail, instance.session_id) + desired_jail = f"/tmp/cortex/{workspace_id}/" + + changed = False + if instance.session.sync_workspace_id != workspace_id: + instance.session.sync_workspace_id = workspace_id + changed = True + + if instance.current_workspace_jail != desired_jail: + instance.current_workspace_jail = desired_jail + changed = True + + if changed: + db.flush() + try: + orchestrator = getattr(self.services, "orchestrator", None) + if orchestrator and instance.mesh_node_id: + orchestrator.assistant.push_workspace(instance.mesh_node_id, workspace_id) + orchestrator.assistant.control_sync(instance.mesh_node_id, workspace_id, action="START") + orchestrator.assistant.control_sync(instance.mesh_node_id, workspace_id, action="UNLOCK") + 
except Exception as e: + logger.error(f"Failed to heal workspace binding for agent {instance.id}: {e}") + + return changed + + def _derive_workspace_id(self, jail_path: str | None, session_id: int | None) -> str: + if jail_path: + base = os.path.basename(jail_path.rstrip("/")) + if base: return base + if session_id is not None: + return f"session-{session_id}" + return f"agent-{uuid.uuid4().hex[:8]}" + + def deploy_agent(self, db: Session, user_id: str, request: schemas.DeployAgentRequest) -> dict: + from app.config import settings + + # 1. Create Template + template = AgentTemplate( + name=request.name, + description=request.description, + system_prompt_path=request.system_prompt, + user_id=user_id, + max_loop_iterations=request.max_loop_iterations, + co_worker_quality_gate=request.co_worker_quality_gate, + rework_threshold=request.rework_threshold, + max_rework_attempts=request.max_rework_attempts + ) + db.add(template) + db.flush() + + # Resolve provider + resolved_provider = request.provider_name + if not resolved_provider: + sys_prefs = self.services.user_service.get_system_settings(db) + resolved_provider = sys_prefs.get('llm', {}).get('active_provider', settings.ACTIVE_LLM_PROVIDER) + + # 2. 
Create Session + new_session = models.Session( + user_id=user_id, + provider_name=resolved_provider, + feature_name="agent_harness", + is_locked=True, + system_prompt_override=request.system_prompt, + attached_node_ids=[request.mesh_node_id] if getattr(request, "mesh_node_id", None) else [] + ) + db.add(new_session) + db.flush() + + workspace_id = f"agent_{template.id[:8]}" + new_session.sync_workspace_id = workspace_id + workspace_jail = f"/tmp/cortex/{workspace_id}/" + db.flush() + + # Bootstrap Orchestrator + try: + orchestrator = getattr(self.services, "orchestrator", None) + if orchestrator and request.mesh_node_id: + orchestrator.assistant.push_workspace(request.mesh_node_id, workspace_id) + orchestrator.assistant.control_sync(request.mesh_node_id, workspace_id, action="START") + orchestrator.assistant.control_sync(request.mesh_node_id, workspace_id, action="UNLOCK") + except Exception as e: + logger.error(f"Failed to bootstrap Orchestrator Sync for Agent Deploy: {e}") + + # 3. Create Instance + instance = AgentInstance( + template_id=template.id, + user_id=user_id, + session_id=new_session.id, + mesh_node_id=request.mesh_node_id, + status="idle", + current_workspace_jail=workspace_jail + ) + db.add(instance) + db.flush() + + # 4. 
Trigger + trigger = AgentTrigger( + instance_id=instance.id, + trigger_type=request.trigger_type or "manual", + cron_expression=request.cron_expression, + interval_seconds=request.interval_seconds, + default_prompt=request.default_prompt + ) + if trigger.trigger_type == "webhook": + trigger.webhook_secret = secrets.token_hex(16) + db.add(trigger) + db.flush() + + db.commit() + db.refresh(instance) + + def create_template(self, db: Session, user_id: str, request: schemas.AgentTemplateCreate) -> AgentTemplate: + template = AgentTemplate(**request.model_dump()) + template.user_id = user_id + db.add(template) + db.commit() + db.refresh(template) + return template + + def create_instance(self, db: Session, user_id: str, request: schemas.AgentInstanceCreate) -> AgentInstance: + template = db.query(AgentTemplate).filter(AgentTemplate.id == request.template_id).first() + if not template: + raise HTTPException(status_code=404, detail="Template not found") + instance = AgentInstance(**request.model_dump()) + instance.user_id = user_id + db.add(instance) + db.commit() + db.refresh(instance) + return instance + + def update_status(self, db: Session, agent_id: str, user_id: str, status: str) -> AgentInstance: + instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id, AgentInstance.user_id == user_id).first() + if not instance: raise HTTPException(status_code=404, detail="Instance not found") + instance.status = status + if status == "idle": + instance.last_error = None + instance.evaluation_status = None + db.commit() + db.refresh(instance) + return instance + + def update_config(self, db: Session, agent_id: str, user_id: str, request: schemas.AgentConfigUpdate) -> AgentInstance: + from app.db.models.session import Session as SessionModel + instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id, AgentInstance.user_id == user_id).first() + if not instance: raise HTTPException(status_code=404, detail="Instance not found") + + template = 
db.query(AgentTemplate).filter(AgentTemplate.id == instance.template_id).first() + if template: + if request.name is not None: template.name = request.name + if request.system_prompt is not None: template.system_prompt_path = request.system_prompt + if request.max_loop_iterations is not None: template.max_loop_iterations = request.max_loop_iterations + if request.co_worker_quality_gate is not None: template.co_worker_quality_gate = request.co_worker_quality_gate + if request.rework_threshold is not None: template.rework_threshold = request.rework_threshold + if request.max_rework_attempts is not None: template.max_rework_attempts = request.max_rework_attempts + + if request.mesh_node_id is not None: instance.mesh_node_id = request.mesh_node_id + + if instance.session_id: + session = db.query(SessionModel).filter(SessionModel.id == instance.session_id).first() + if session: + if request.system_prompt is not None: session.system_prompt_override = request.system_prompt + if hasattr(request, 'provider_name') and request.provider_name is not None: session.provider_name = request.provider_name + if hasattr(request, 'model_name') and request.model_name is not None: session.model_name = request.model_name + if request.mesh_node_id is not None: + try: + self.services.session_service.attach_nodes(db, session.id, schemas.NodeAttachRequest(node_ids=[request.mesh_node_id] if request.mesh_node_id else [])) + except: session.attached_node_ids = [request.mesh_node_id] if request.mesh_node_id else [] + if hasattr(request, 'restrict_skills') and request.restrict_skills is not None: session.restrict_skills = request.restrict_skills + if hasattr(request, 'is_locked') and request.is_locked is not None: session.is_locked = request.is_locked + if hasattr(request, 'auto_clear_history') and request.auto_clear_history is not None: session.auto_clear_history = request.auto_clear_history + + db.commit() + db.refresh(instance) + return instance + + def delete_agent(self, db: Session, agent_id: 
str, user_id: str): + instance = db.query(AgentInstance).filter(AgentInstance.id == agent_id, AgentInstance.user_id == user_id).first() + if not instance: raise HTTPException(status_code=404, detail="Agent not found") + db.delete(instance) + db.commit() diff --git a/ai-hub/app/core/services/mesh.py b/ai-hub/app/core/services/mesh.py index 8de05ad..cd9c80d 100644 --- a/ai-hub/app/core/services/mesh.py +++ b/ai-hub/app/core/services/mesh.py @@ -10,16 +10,184 @@ from app.db import models from app.api import schemas +from app.api.dependencies import ServiceContainer +from app.core.grpc.utils.crypto import sign_payload +from app.protos import agent_pb2 logger = logging.getLogger(__name__) class MeshService: - def __init__(self, services=None): + def __init__(self, services: ServiceContainer = None): self.services = services # Setup Jinja2 templates self.templates_dir = os.path.join(os.path.dirname(__file__), "..", "templates", "provisioning") self.jinja_env = jinja2.Environment(loader=jinja2.FileSystemLoader(self.templates_dir)) if os.path.exists(self.templates_dir) else None + # --- Admin Logic --- + + def register_node(self, request: schemas.AgentNodeCreate, admin_id: str, db: Session) -> models.AgentNode: + existing = db.query(models.AgentNode).filter(models.AgentNode.node_id == request.node_id).first() + if existing: + raise HTTPException(status_code=409, detail=f"Node '{request.node_id}' already exists.") + + invite_token = secrets.token_urlsafe(32) + node = models.AgentNode( + node_id=request.node_id, + display_name=request.display_name, + description=request.description, + registered_by=admin_id, + skill_config=request.skill_config.model_dump(), + invite_token=invite_token, + last_status="offline", + ) + db.add(node) + db.commit() + db.refresh(node) + return node + + def update_node(self, node_id: str, update: schemas.AgentNodeUpdate, db: Session) -> models.AgentNode: + node = self.get_node_or_404(node_id, db) + if update.display_name is not None: 
node.display_name = update.display_name + if update.description is not None: node.description = update.description + if update.skill_config is not None: + node.skill_config = update.skill_config.model_dump() + try: + self.services.orchestrator.push_policy(node_id, node.skill_config) + except Exception as e: + logger.warning(f"Could not push live policy to {node_id}: {e}") + + if update.is_active is not None: node.is_active = update.is_active + db.commit() + db.refresh(node) + return node + + def delete_node(self, node_id: str, db: Session): + node = self.get_node_or_404(node_id, db) + self.services.node_registry_service.deregister(node_id) + db.delete(node) + db.commit() + + def grant_access(self, node_id: str, grant: schemas.NodeAccessGrant, admin_id: str, db: Session): + existing = db.query(models.NodeGroupAccess).filter( + models.NodeGroupAccess.node_id == node_id, + models.NodeGroupAccess.group_id == grant.group_id + ).first() + if existing: + existing.access_level = grant.access_level + existing.granted_by = admin_id + else: + access = models.NodeGroupAccess( + node_id=node_id, group_id=grant.group_id, + access_level=grant.access_level, granted_by=admin_id, + ) + db.add(access) + db.commit() + + # --- User Logic --- + + def list_accessible_nodes(self, user_id: str, db: Session) -> List[models.AgentNode]: + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found.") + + if user.role == "admin": + return db.query(models.AgentNode).filter(models.AgentNode.is_active == True).all() + + accesses = db.query(models.NodeGroupAccess).filter(models.NodeGroupAccess.group_id == user.group_id).all() + node_ids = set([a.node_id for a in accesses]) + if user.group and user.group.policy: + for nid in user.group.policy.get("nodes", []): node_ids.add(nid) + + return db.query(models.AgentNode).filter( + models.AgentNode.node_id.in_(list(node_ids)), + models.AgentNode.is_active == True + 
).all() + + def dispatch_task(self, node_id: str, command: str, user_id: str, db: Session, session_id: str = "", task_id: str = None, timeout_ms: int = 30000): + self.require_node_access(user_id, node_id, db) + registry = self.services.node_registry_service + live = registry.get_node(node_id) + if not live: + raise HTTPException(status_code=503, detail=f"Node '{node_id}' is not connected.") + + t_id = task_id or str(uuid.uuid4()) + registry.emit(node_id, "task_assigned", {"command": command, "session_id": session_id}, task_id=t_id) + + task_req = agent_pb2.TaskRequest( + task_id=t_id, payload_json=command, signature=sign_payload(command), + timeout_ms=timeout_ms, session_id=session_id or "" + ) + live.send_message(agent_pb2.ServerTaskMessage(task_request=task_req), priority=1) + registry.emit(node_id, "task_start", {"command": command}, task_id=t_id) + return t_id + + # --- Utilities --- + + def get_node_or_404(self, node_id: str, db: Session) -> models.AgentNode: + node = db.query(models.AgentNode).filter(models.AgentNode.node_id == node_id).first() + if not node: + raise HTTPException(status_code=404, detail=f"Node '{node_id}' not found.") + return node + + def node_to_admin_detail(self, node: models.AgentNode) -> schemas.AgentNodeAdminDetail: + registry = self.services.node_registry_service + live = registry.get_node(node.node_id) + status = live._compute_status() if live else node.last_status or "offline" + stats = schemas.AgentNodeStats(**live.stats) if live else schemas.AgentNodeStats() + return schemas.AgentNodeAdminDetail( + node_id=node.node_id, + display_name=node.display_name, + description=node.description, + skill_config=node.skill_config or {}, + capabilities=node.capabilities or {}, + invite_token=node.invite_token, + is_active=node.is_active, + last_status=status, + last_seen_at=node.last_seen_at, + created_at=node.created_at, + registered_by=node.registered_by, + group_access=[ + schemas.NodeAccessResponse( + id=a.id, node_id=a.node_id, 
group_id=a.group_id, + access_level=a.access_level, granted_at=a.granted_at + ) for a in (node.group_access or []) + ], + stats=stats, + ) + + def generate_node_config_yaml(self, node: models.AgentNode, skill_overrides: dict = None) -> str: + from app.config import settings + import yaml + + hub_url = settings.GRPC_EXTERNAL_ENDPOINT or os.getenv("HUB_PUBLIC_URL", "http://127.0.0.1:8000") + hub_grpc = settings.GRPC_TARGET_ORIGIN or os.getenv("HUB_GRPC_ENDPOINT", "127.0.0.1:50051") + secret_key = os.getenv("SECRET_KEY", "dev-secret-key-1337") + + skill_cfg = node.skill_config or {} + if isinstance(skill_cfg, str): + try: skill_cfg = json.loads(skill_cfg) + except: skill_cfg = {} + + if skill_overrides: + for skill, cfg in skill_overrides.items(): + skill_cfg.setdefault(skill, {}).update(cfg) + + config_data = { + "node_id": node.node_id, + "node_description": node.display_name, + "hub_url": hub_url, + "grpc_endpoint": hub_grpc, + "invite_token": node.invite_token, + "auth_token": node.invite_token, + "secret_key": secret_key, + "skills": skill_cfg, + "sync_root": "/tmp/cortex-sync", + "tls": settings.GRPC_TLS_ENABLED + } + + header = f"# Cortex Hub - Agent Node Configuration\n# Generated for node '{node.node_id}'\n\n" + return header + yaml.dump(config_data, sort_keys=False, default_flow_style=False) + # Extracted from nodes.py def require_node_access(self, user_id: str, node_id: str, db: Session): user = db.query(models.User).filter(models.User.id == user_id).first() @@ -66,55 +234,27 @@ ) def generate_provisioning_script(self, node: models.AgentNode, config_yaml: str, base_url: str) -> str: - if not self.jinja_env: - return "Error: Templates directory not found." 
- try: - template = self.jinja_env.get_template("provision.py.j2") - return template.render( - node_id=node.node_id, - config_yaml=config_yaml, - base_url=base_url, - invite_token=node.invite_token - ) - except Exception as e: - logger.error(f"Failed to generate provisioning script: {e}") - return f"Error: {e}" + return self._render_provision_template("provision.py.j2", node, config_yaml, base_url) def generate_provisioning_sh(self, node: models.AgentNode, config_yaml: str, base_url: str) -> str: - if not self.jinja_env: - return "Error: Templates directory not found." - try: - template = self.jinja_env.get_template("provision.sh.j2") - return template.render( - node_id=node.node_id, - config_yaml=config_yaml, - base_url=base_url, - invite_token=node.invite_token - ) - except Exception as e: - logger.error(f"Failed to generate provisioning script: {e}") - return f"Error: {e}" + return self._render_provision_template("provision.sh.j2", node, config_yaml, base_url) def generate_provisioning_ps1(self, node: models.AgentNode, config_yaml: str, base_url: str, grpc_url: str = "") -> str: - if not self.jinja_env: - return "Error: Templates directory not found." + params = {"grpc_url": grpc_url or base_url.replace("http://", "").replace("https://", "")} + return self._render_provision_template("provision.ps1.j2", node, config_yaml, base_url, **params) + + def _render_provision_template(self, template_name: str, node: models.AgentNode, config_yaml: str, base_url: str, **kwargs) -> str: + if not self.jinja_env: return "Error: Templates directory not found." 
try: - template = self.jinja_env.get_template("provision.ps1.j2") - return template.render( - node_id=node.node_id, - config_yaml=config_yaml, - base_url=base_url, - grpc_url=grpc_url or base_url.replace("http://", "").replace("https://", ""), - invite_token=node.invite_token + return self.jinja_env.get_template(template_name).render( + node_id=node.node_id, config_yaml=config_yaml, + base_url=base_url, invite_token=node.invite_token, **kwargs ) except Exception as e: - logger.error(f"Failed to generate provisioning script: {e}") + logger.error(f"Failed to render {template_name}: {e}") return f"Error: {e}" def get_template_content(self, filename: str) -> str: - if not self.jinja_env: - return "" - try: - return self.jinja_env.get_template(filename).render() - except: - return "" + if not self.jinja_env: return "" + try: return self.jinja_env.get_template(filename).render() + except: return "" diff --git a/ai-hub/app/core/services/node_registry.py b/ai-hub/app/core/services/node_registry.py index 0567c29..7c2b68f 100644 --- a/ai-hub/app/core/services/node_registry.py +++ b/ai-hub/app/core/services/node_registry.py @@ -94,30 +94,52 @@ msg.signature = sign_bytes(msg_bytes) item = (priority, time.time(), msg) - + self._dispatch_to_queue(item) + + def _dispatch_to_queue(self, item): + """Internal helper to bridge sync/async queue access with backpressure.""" def _blocking_put(): try: - # Reduced timeout from 300s to 5s to avoid blocking gRPC threads for too long. - # 5s timeout provides backpressure while protecting against deadlocks. self.queue.put(item, block=True, timeout=5.0) except queue.Full: - logger.warning(f"[๐Ÿ“‹โš ๏ธ] Message dropped for {self.node_id}: outbound queue FULL. 
Node may be unresponsive.") + logger.warning(f"[๐Ÿ“‹โš ๏ธ] Message dropped for {self.node_id}: outbound queue FULL.") except Exception as e: logger.error(f"[๐Ÿ“‹โŒ] Sync error sending to {self.node_id}: {e}") try: - # Check if we are in an async loop (FastAPI context) loop = asyncio.get_running_loop() if loop.is_running(): - # Run blocking_put in the default executor (unbounded) to avoid blocking event loop loop.run_in_executor(None, _blocking_put) return - except RuntimeError: - pass # Not in async loop - - # Standard sync put (from gRPC thread) + except RuntimeError: pass _blocking_put() + def append_history(self, event_type: str, data: Any): + """Maintains the terminal history buffer for AI context awareness.""" + from app.core._regex import ANSI_ESCAPE + + if event_type == "task_assigned" and isinstance(data, dict): + cmd = data.get("command") + if cmd and not (isinstance(cmd, str) and cmd.startswith('{"tty"')): + self.terminal_history.append(f"$ {cmd}\n") + + elif event_type == "task_stdout" and isinstance(data, str): + self._append_clean_output(data) + + elif event_type == "skill_event" and isinstance(data, dict) and data.get("type") == "output": + self._append_clean_output(data.get("data", "")) + + elif event_type == "reasoning" and isinstance(data, dict): + content = data.get("content", "") + if content: self.terminal_history.append(content) + + def _append_clean_output(self, output: str): + from app.core._regex import ANSI_ESCAPE + clean = ANSI_ESCAPE.sub('', output) + if len(clean) > 100_000: + clean = clean[:100_000] + "\n[... 
Output Truncated ...]\n" + self.terminal_history.append(clean) + def update_stats(self, stats: dict): self.stats.update(stats) self.last_heartbeat_at = datetime.utcnow() @@ -136,9 +158,7 @@ def _compute_status(self) -> str: delta = (datetime.utcnow() - self.last_heartbeat_at).total_seconds() - if delta > 30: - return "stale" - return "online" + return "stale" if delta > 30 else "online" def is_healthy(self) -> bool: """True if the node has reported metrics recently and has an active stream.""" @@ -386,16 +406,19 @@ # Emit heartbeat event to live UI self.emit(node_id, "heartbeat", stats) + def emergency_reset(self) -> int: + """Emergency cleanup of all live nodes and DB status reset.""" + self.reset_all_statuses() + return self.clear_memory_cache() + # ------------------------------------------------------------------ # # Event Bus # # ------------------------------------------------------------------ # def emit(self, node_id: str, event_type: str, data: Any = None, task_id: str = ""): """ - Emit a rich structured execution event. - Delivered to: - - Per-node WS subscribers โ†’ powers the single-node execution pane - - Per-user WS subscribers โ†’ powers the global multi-node execution bus + Emit a rich structured execution event and update live state. + Delivered to WS subscribers for real-time UI visibility. 
""" with self._lock: node = self._nodes.get(node_id) @@ -403,60 +426,22 @@ node_qs = list(self._node_listeners.get(node_id, [])) user_qs = list(self._user_listeners.get(user_id, [])) if user_id else [] - if user_id and not user_qs and event_type in ["node_online", "node_offline"]: - logger.debug(f"[Registry] emit({event_type}) for node {node_id}: No user listeners found for user {user_id}") + if node: node.append_history(event_type, data) event = { "event": event_type, "label": EVENT_TYPES.get(event_type, event_type), - "node_id": node_id, - "user_id": user_id, - "task_id": task_id, - "timestamp": datetime.utcnow().isoformat(), - "data": data, + "node_id": node_id, "user_id": user_id, "task_id": task_id, + "timestamp": datetime.utcnow().isoformat(), "data": data, } - # M6: Store terminal history locally for AI reading - # We only store raw shell output and the commands themselves to keep the context clean. - if node: - if event_type == "task_assigned" and isinstance(data, dict): - cmd = data.get("command") - if cmd: - # Skip TTY keypress echos (manual typing) to keep AI context clean - # We usually only care about the final result or purposeful command execution - # If it's a JSON dict for tty, it's likely a character-by-character input - is_tty_char = isinstance(cmd, str) and cmd.startswith('{"tty"') - if not is_tty_char: - node.terminal_history.append(f"$ {cmd}\n") - elif event_type == "task_stdout" and isinstance(data, str): - # Use pre-compiled global regex to avoid overhead on every token - from app.core._regex import ANSI_ESCAPE - clean_output = ANSI_ESCAPE.sub('', data) - if len(clean_output) > 100_000: - clean_output = clean_output[:100_000] + "\n[... 
Output Truncated ...]\n" - node.terminal_history.append(clean_output) - elif event_type == "skill_event" and isinstance(data, dict): - if data.get("type") == "output": - output_data = data.get("data", "") - from app.core._regex import ANSI_ESCAPE - clean_output = ANSI_ESCAPE.sub('', output_data) - if len(clean_output) > 100_000: - clean_output = clean_output[:100_000] + "\n[... Output Truncated ...]\n" - node.terminal_history.append(clean_output) - elif event_type == "reasoning" and isinstance(data, dict): - content = data.get("content", "") - if content: - # Append reasoning as a distinct "thought" block in terminal history - node.terminal_history.append(content) - + # Broadcast to all unique listeners seen = set() - for q in node_qs + user_qs: + for q in (node_qs + user_qs): if id(q) not in seen: seen.add(id(q)) - try: - q.put_nowait(event) - except Exception: - pass + try: q.put_nowait(event) + except: pass # ------------------------------------------------------------------ # # WS Subscriptions # diff --git a/ai-hub/app/core/services/preference.py b/ai-hub/app/core/services/preference.py index 66e129a..5a5c5b3 100644 --- a/ai-hub/app/core/services/preference.py +++ b/ai-hub/app/core/services/preference.py @@ -284,3 +284,144 @@ statuses=user.preferences.get("statuses", {}) ) + def export_config_yaml(self, user, reveal_secrets: bool) -> str: + import yaml + from app.core.grpc.utils.crypto import encrypt_value + + prefs_dict = copy.deepcopy(user.preferences) if user.preferences else {} + sensitive_keys = ["api_key", "client_secret", "webhook_secret", "password", "key_content", "key_file"] + + def process_export(obj): + if isinstance(obj, dict): + res = {} + for k, v in obj.items(): + if v is None: continue + if k in sensitive_keys and v: + res[k] = v if reveal_secrets else encrypt_value(v) + else: + res[k] = process_export(v) + return res + elif isinstance(obj, list): + return [process_export(x) for x in obj] + return obj + + export_data = { + "llm": 
prefs_dict.get("llm", {"providers": {}, "active_provider": "deepseek"}), + "tts": prefs_dict.get("tts", {"providers": {}, "active_provider": settings.TTS_PROVIDER}), + "stt": prefs_dict.get("stt", {"providers": {}, "active_provider": settings.STT_PROVIDER}) + } + + # Backfill from settings if empty + if not export_data["llm"].get("providers"): + export_data["llm"]["providers"] = { + "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME}, + "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME} + } + + return yaml.dump(process_export(export_data), sort_keys=False, default_flow_style=False) + + async def import_config_yaml(self, db, user, content: bytes) -> schemas.UserPreferences: + import yaml + from app.core.grpc.utils.crypto import decrypt_value + from sqlalchemy.orm.attributes import flag_modified + + try: data = yaml.safe_load(content) + except Exception as e: raise Exception(f"Invalid YAML: {e}") + + def process_import(obj): + if isinstance(obj, dict): return {k: process_import(v) for k, v in obj.items()} + elif isinstance(obj, str): return decrypt_value(obj) + elif isinstance(obj, list): return [process_import(x) for x in obj] + return obj + + data = process_import(data) + user.preferences = { + "llm": data.get("llm", {}), + "tts": data.get("tts", {}), + "stt": data.get("stt", {}), + "statuses": {} + } + flag_modified(user, "preferences") + db.commit() + return schemas.UserPreferences(llm=user.preferences["llm"], tts=user.preferences["tts"], stt=user.preferences["stt"]) + + async def verify_provider(self, db, user, req: schemas.VerifyProviderRequest, section: str) -> schemas.VerifyProviderResponse: + from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider + + # Admin or personal key check + is_masked = not req.api_key or "***" in str(req.api_key) + if is_masked and user.role != "admin": + return schemas.VerifyProviderResponse(success=False, message="Forbidden: 
Admin only for masked keys") + + actual_key = req.api_key + prefs = user.preferences.get(section, {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {} + + if is_masked: + actual_key = prefs.get("api_key") + if not actual_key: + s_prefs = self.services.user_service.get_system_settings(db) + actual_key = s_prefs.get(section, {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") + if not actual_key: + if section == "llm": actual_key = settings.DEEPSEEK_API_KEY + elif section == "tts": actual_key = settings.TTS_API_KEY + else: actual_key = settings.STT_API_KEY + + try: + if section == "llm": + llm = get_llm_provider(req.provider_name, model_name=req.model or "", api_key_override=actual_key) + await llm.acompletion(prompt="Hello") + elif section == "tts": + p = get_tts_provider(req.provider_name, api_key=actual_key, model_name=req.model or "", voice_name=req.voice or "") + await p.generate_speech("Test") + else: + get_stt_provider(req.provider_name, api_key=actual_key, model_name=req.model or "") + return schemas.VerifyProviderResponse(success=True, message="Success!") + except Exception as e: + return schemas.VerifyProviderResponse(success=False, message=str(e)) + + def resolve_llm_provider(self, db, user, provider_name: str, model_name: str = None) -> Any: + """ + Unified resolution for LLM providers with full fallback chain: + User Preference -> System Override (Admin UI) -> Config Defaults (YAML/Env) + """ + from app.core.providers.factory import get_llm_provider + + base_key = provider_name.split("/")[0] if provider_name else "" + if not base_key and user: + base_key = user.preferences.get("llm", {}).get("active_provider", "deepseek") + provider_name = base_key + + llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(base_key, {}) if user else {} + + # Resolve Key Fallbacks + if not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key")): + u_svc = 
getattr(self.services, "user_service", None) + if u_svc: + system_prefs = u_svc.get_system_settings(db) + system_prov = system_prefs.get("llm", {}).get("providers", {}).get(base_key, {}) + + # Fallback to system's active_provider if specified provider is missing key + if (not system_prov or not system_prov.get("api_key")) and system_prefs.get("llm", {}).get("active_provider"): + active_key = system_prefs["llm"]["active_provider"] + system_prov, provider_name = system_prefs["llm"]["providers"].get(active_key, {}), active_key + base_key = active_key + + if system_prov: + merged = system_prov.copy() + merged.update({k: v for k, v in llm_prefs.items() if v}) + llm_prefs = merged + + # Resolve Model Override (handles 'provider/model' syntax) + resolved_model = provider_name.split("/")[1] if "/" in provider_name else (model_name or llm_prefs.get("model", "")) + resolved_provider_name = provider_name.split("/")[0] if "/" in provider_name else provider_name + + try: + return get_llm_provider( + resolved_provider_name, + model_name=resolved_model, + api_key_override=llm_prefs.get("api_key"), + **{k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} + ), resolved_provider_name + except Exception as e: + logger.error(f"Failed to resolve LLM provider: {e}") + return None, resolved_provider_name diff --git a/ai-hub/app/core/services/rag.py b/ai-hub/app/core/services/rag.py index 168971c..0bbdfb9 100644 --- a/ai-hub/app/core/services/rag.py +++ b/ai-hub/app/core/services/rag.py @@ -1,8 +1,6 @@ import logging -import re -from typing import List, Tuple, Optional - -logger = logging.getLogger(__name__) +import time +from typing import List, Optional, Dict, Any, AsyncGenerator, Tuple from sqlalchemy.orm import Session, joinedload from app.db import models @@ -11,13 +9,17 @@ from app.core.providers.factory import get_llm_provider from app.core.orchestration import Architect from app.core.orchestration.profiles import get_profile +from app.core._regex import 
ANSI_ESCAPE +from app.db.session import async_db_op + +logger = logging.getLogger(__name__) class RAGService: """ - Service for orchestrating conversational RAG pipelines. - Manages chat interactions and message history for a session. + Orchestrates conversational RAG pipelines. + Decomposed into manageable components for maintainability. """ - def __init__(self, retrievers: List[Retriever], prompt_service = None, tool_service = None, node_registry_service = None): + def __init__(self, retrievers: List[Retriever], prompt_service=None, tool_service=None, node_registry_service=None): self.retrievers = retrievers self.prompt_service = prompt_service self.tool_service = tool_service @@ -31,317 +33,162 @@ prompt: str, provider_name: str, load_faiss_retriever: bool = False, - user_service = None, + user_service=None, user_id: Optional[str] = None - ): - """ - Processes a user prompt within a session, yields events in real-time, - and saves the chat history at the end. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() - - if not session: - raise ValueError(f"Session with ID {session_id} not found.") - - # Save user message - user_message = models.Message(session_id=session_id, sender="user", content=prompt) - db.add(user_message) - db.commit() - db.refresh(user_message) - - # Auto-title the session from the very first user message - if session.title in (None, "New Chat Session", ""): - session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "") - - # Resolve provider - extract base key for settings lookup - # e.g. 
"gemini/gemini-1.5-flash" -> base key "gemini" - base_provider_key = provider_name.split("/")[0] if "/" in provider_name else provider_name - - llm_prefs = {} - user = session.user - if user and user.preferences: - llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(base_provider_key, {}) - - if (not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key"))) and user_service: - system_prefs = user_service.get_system_settings(db) - system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(base_provider_key, {}) - # If the user passed a generic string like 'gemini' but there is no block explicitly for it, - # try to fallback to the explicitly defined active provider to avoid throwing 400s - if not system_provider_prefs or not system_provider_prefs.get("model"): - active_prov_key = system_prefs.get("llm", {}).get("active_provider") - if active_prov_key: - system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(active_prov_key, {}) - if system_provider_prefs: - provider_name = active_prov_key # Remap so factory gets the right full config - - if system_provider_prefs: - merged = system_provider_prefs.copy() - if llm_prefs: merged.update({k: v for k, v in llm_prefs.items() if v}) - llm_prefs = merged - - api_key_override = llm_prefs.get("api_key") - - # If provider_name already contains an explicit model (e.g. "gemini/gemini-1.5-flash"), - # do NOT override it with the model from system settings (which might be "gemini-1.5-flash") - if "/" in provider_name: - model_name_override = "" # Let factory extract model from provider_name - else: - # M3: Priority order: - # 1. Message-level override (done via provider_name) - # 2. Session-level override (persisted in DB) - # 3. 
User-level preference - model_name_override = session.model_name or llm_prefs.get("model", "") - - - - kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} - llm_provider = get_llm_provider( - provider_name, - model_name=model_name_override, - api_key_override=api_key_override, - **kwargs - ) + ) -> AsyncGenerator[Dict[str, Any], None]: + """Entry point for the RAG pipeline.""" + session = self._resolve_session(db, session_id, prompt) + llm_provider, resolved_provider_name = self._resolve_provider(db, session, provider_name, user_service) context_chunks = [] - if load_faiss_retriever: - if self.faiss_retriever: - context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db)) + if load_faiss_retriever and self.faiss_retriever: + context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db)) - architect = Architect() - - tools = [] - if self.tool_service: - tools = self.tool_service.get_available_tools(db, session.user_id, feature=session.feature_name, session_id=session.id) - + mesh_context = self._gather_mesh_context(db, session, user_service) + tools = self.tool_service.get_available_tools(db, session.user_id, feature=session.feature_name, session_id=session.id) if self.tool_service else [] profile = get_profile(session.feature_name) - mesh_context = "" - if session.attached_node_ids and profile.include_mesh_context: - nodes = db.query(models.AgentNode).filter(models.AgentNode.node_id.in_(session.attached_node_ids)).all() - if nodes: - mesh_context = "Attached Agent Nodes (Infrastructure):\n" - for node in nodes: - mesh_context += f"- Node ID: {node.node_id}\n" - mesh_context += f" Name: {node.display_name}\n" - mesh_context += f" Description: {node.description or 'No description provided.'}\n" - mesh_context += f" Status: {node.last_status}\n" - - caps = node.capabilities or {} - if caps.get("local_ip"): - mesh_context += f" Local IP: {caps.get('local_ip')}\n" - if caps.get("arch"): - mesh_context += 
f" Architecture: {caps['arch']} ({caps.get('os', 'unknown')})\n" - if caps.get("gpu") and caps["gpu"] != "none": - mesh_context += f" GPU: {caps['gpu']}\n" - # Privilege level โ€” critical for knowing whether to use sudo - # Values are stored as strings ("true"/"false") due to protobuf map - is_root = caps.get("is_root") - has_sudo = caps.get("has_sudo") - if is_root == "true" or is_root is True: - mesh_context += f" Privilege Level: root (skip sudo โ€” run all commands directly)\n" - elif has_sudo == "true" or has_sudo is True: - mesh_context += f" Privilege Level: standard user with passwordless sudo\n" - elif is_root == "false" or is_root is False: - mesh_context += f" Privilege Level: standard user (sudo NOT available โ€” avoid privileged ops)\n" - # If neither field exists yet (old node version), omit to avoid confusion - - shell_config = (node.skill_config or {}).get("shell", {}) - if shell_config.get("enabled"): - sandbox = shell_config.get("sandbox") or {} - mode = sandbox.get("mode", "PERMISSIVE") - allowed = sandbox.get("allowed_commands", []) - denied = sandbox.get("denied_commands", []) - - mesh_context += f" Terminal Sandbox Mode: {mode}\n" - if mode == "STRICT": - mesh_context += f" AI Permitted Commands (Allow-list): {', '.join(allowed) if allowed else 'None'}\n" - elif mode == "PERMISSIVE": - mesh_context += f" AI Restricted Commands (Blacklist): {', '.join(denied) if denied else 'None'}\n" - - if mode == "STRICT" and not allowed: - mesh_context += " โš ๏ธ Warning: All shell commands are currently blocked by sandbox policy.\n" - - # AI Visibility: Recent terminal history - registry = getattr(self, "node_registry_service", None) - if not registry and user_service: - registry = getattr(user_service, "node_registry_service", None) - - if registry: - live = registry.get_node(node.node_id) - if live and live.terminal_history: - # Grab recent chunks and join - # defensive join: only take enough chunks for ~4000 chars total - chunks = [] - total_len = 0 - 
for chunk in reversed(list(live.terminal_history)[-40:]): - # Defensive handle for non-string chunks (e.g. dicts from some serialization) - chunk_str = "" - if isinstance(chunk, str): - chunk_str = chunk - elif isinstance(chunk, dict): - chunk_str = chunk.get("output") or chunk.get("content") or str(chunk) - else: - chunk_str = str(chunk) - - chunks.insert(0, chunk_str) - total_len += len(chunk_str) - if total_len > 4000: break - - history_blob = "".join(chunks) - - # Use pre-compiled regex from utility - from app.core._regex import ANSI_ESCAPE - clean_history = ANSI_ESCAPE.sub('', history_blob) - - # Limit to 2000 chars to avoid bloating the context - if len(clean_history) > 2000: - clean_history = "...[truncated]...\n" + clean_history[-2000:] - - mesh_context += " Recent Terminal Output:\n" - mesh_context += " ```\n" - mesh_context += f" {clean_history}" - if not clean_history.endswith('\n'): mesh_context += "\n" - mesh_context += " ```\n" - mesh_context += "\n" - - logger.info(f"[RAG] Mesh Context gathered. 
Length: {len(mesh_context)} chars.") - if mesh_context: - logger.info(f"[RAG] Mesh Context excerpt: {mesh_context[:200]}...") - - # Accumulators for the DB save at the end - full_answer = "" - full_reasoning = "" - tool_counts = {} - input_tokens = 0 - output_tokens = 0 - current_assistant_msg = None - - # Stream from specialized Architect - async for event in architect.run( - question=prompt, - history=session.messages, - context_chunks = context_chunks, - llm_provider = llm_provider, - prompt_service = self.prompt_service, - tool_service = self.tool_service, - tools = tools, - mesh_context = mesh_context, - db = db, - user_id = user_id or session.user_id, - sync_workspace_id = session.sync_workspace_id or str(session_id), - session_id = session_id, - feature_name = session.feature_name, - prompt_slug = profile.default_prompt_slug, - session_override = session.system_prompt_override - ): - if event["type"] == "content": - full_answer += event["content"] - elif event["type"] == "reasoning": - full_reasoning += event["content"] - elif event["type"] == "tool_start": - t_name = event.get("name") - if t_name: - if t_name not in tool_counts: - tool_counts[t_name] = {"calls": 0, "successes": 0, "failures": 0} - tool_counts[t_name]["calls"] += 1 - elif event["type"] == "tool_result": - t_name = event.get("name") - if t_name and t_name in tool_counts: - result_data = event.get("result") - if result_data and (not isinstance(result_data, dict) or result_data.get("success") is False): - tool_counts[t_name]["failures"] += 1 - else: - tool_counts[t_name]["successes"] += 1 - elif event["type"] == "token_counted": - usage = event.get("usage", {}) - if hasattr(usage, "get"): - input_tokens += usage.get("prompt_tokens", 0) - output_tokens += usage.get("completion_tokens", 0) - - # Forward the event to the API stream - yield event - - # Handle background updates for persistent UI observability - if event["type"] in ("content", "reasoning"): - # Initialize assistant message in DB on 
first token if it doesn't exist - if not current_assistant_msg: - current_assistant_msg = models.Message( - session_id=session_id, - sender="assistant", - content="" - ) - db.add(current_assistant_msg) - from app.db.session import async_db_op - await async_db_op(db.commit) - - # Update local accumulators - if event["type"] == "content": - current_assistant_msg.content += event["content"] - elif event["type"] == "reasoning" and hasattr(current_assistant_msg, "reasoning_content"): - if not current_assistant_msg.reasoning_content: - current_assistant_msg.reasoning_content = "" - current_assistant_msg.reasoning_content += event["content"] - - # Forward to Swarm Registry so the Swarm Control / Node Dash views see it too - registry = getattr(self, "node_registry_service", None) - if registry and session.attached_node_ids: - for node_id in session.attached_node_ids: - registry.emit(node_id, "reasoning", { - "content": event.get("content", ""), - "session_id": session_id, - "type": event["type"] - }) - - # Commit every 50 tokens or when it makes sense UI-wise. - # Frequent commits block the async event loop with synchronous disk I/O. 
- if (input_tokens + output_tokens) % 50 == 0: - try: - from app.db.session import async_db_op - await async_db_op(db.commit) - except: - await async_db_op(db.rollback) - - # Final cleanup of the transient assistant message state - if current_assistant_msg: - assistant_message = current_assistant_msg - assistant_message.content = full_answer - if full_reasoning and hasattr(assistant_message, "reasoning_content"): - assistant_message.reasoning_content = full_reasoning - from app.db.session import async_db_op - await async_db_op(db.commit) - else: - # Fallback if no tokens were yielded but we reached the end - assistant_message = models.Message( - session_id=session_id, - sender="assistant", - content=full_answer - ) - if full_reasoning and hasattr(assistant_message, "reasoning_content"): - assistant_message.reasoning_content = full_reasoning - db.add(assistant_message) - from app.db.session import async_db_op - await async_db_op(db.commit) - - # Yield a final finish event with metadata - yield { - "type": "finish", - "message_id": assistant_message.id, - "provider": provider_name, - "full_answer": full_answer, - "tool_counts": tool_counts, - "input_tokens": input_tokens, - "output_tokens": output_tokens + # Accumulators + state = { + "answer": "", "reasoning": "", "tool_counts": {}, + "usage": {"input": 0, "output": 0}, "msg": None } + architect = Architect() + async for event in architect.run( + question=prompt, history=session.messages, context_chunks=context_chunks, + llm_provider=llm_provider, prompt_service=self.prompt_service, tool_service=self.tool_service, + tools=tools, mesh_context=mesh_context, db=db, user_id=user_id or session.user_id, + sync_workspace_id=session.sync_workspace_id or str(session_id), session_id=session_id, + feature_name=session.feature_name, prompt_slug=profile.default_prompt_slug, + session_override=session.system_prompt_override + ): + await self._process_event(db, session_id, event, state) + yield event + + # Final persistence + 
assistant_msg = await self._finalize_assistant_message(db, session_id, state) + yield { + "type": "finish", "message_id": assistant_msg.id, "provider": resolved_provider_name, + "full_answer": state["answer"], "tool_counts": state["tool_counts"], + "input_tokens": state["usage"]["input"], "output_tokens": state["usage"]["output"] + } + + def _resolve_session(self, db: Session, session_id: int, prompt: str) -> models.Session: + """Fetches and initializes the session state.""" + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + if not session: raise ValueError(f"Session {session_id} not found.") + + # Save user message + db.add(models.Message(session_id=session_id, sender="user", content=prompt)) + if session.title in (None, "New Chat Session", ""): + session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "") + + db.commit() + return session + + def _resolve_provider(self, db: Session, session: models.Session, provider_name: str, user_service) -> Tuple[Any, str]: + """Resolves LLM provider with user-preference and system-level fallbacks.""" + if not self.prompt_service or not self.prompt_service.services.preference_service: + from app.core.providers.factory import get_llm_provider + return get_llm_provider(provider_name), provider_name + + return self.prompt_service.services.preference_service.resolve_llm_provider( + db, session.user, provider_name, model_name=session.model_name + ) + + def _gather_mesh_context(self, db: Session, session: models.Session, user_service) -> str: + """Aggregates technical infrastructure context from attached agent nodes.""" + profile = get_profile(session.feature_name) + if not session.attached_node_ids or not profile.include_mesh_context: + return "" + + nodes = db.query(models.AgentNode).filter(models.AgentNode.node_id.in_(session.attached_node_ids)).all() + ctx = "Attached Agent Nodes (Infrastructure):\n" + for node in nodes: + ctx += f"- 
Node ID: {node.node_id}\n Name: {node.display_name}\n" + ctx += f" Status: {node.last_status}\n" + + caps = node.capabilities or {} + if caps.get("arch"): ctx += f" Arch: {caps['arch']} ({caps.get('os', 'unknown')})\n" + + # Privilege inference + is_root, has_sudo = caps.get("is_root") == "true", caps.get("has_sudo") == "true" + ctx += f" Privilege: {'root' if is_root else 'sudo-user' if has_sudo else 'standard'}\n" + + # Sandbox status + sb = (node.skill_config or {}).get("shell", {}).get("sandbox", {}) + if sb: ctx += f" Sandbox: {sb.get('mode', 'PERMISSIVE')}\n" + + # Live terminal tailing + registry = self.node_registry_service or (user_service.node_registry_service if user_service else None) + if registry: + ctx += self._render_node_history(registry, node.node_id) + + return ctx + + def _render_node_history(self, registry, node_id: str) -> str: + """Extracts and cleans the recent terminal history for a specific node.""" + live = registry.get_node(node_id) + if not live or not live.terminal_history: return "" + + chunks, total_len = [], 0 + for chunk in reversed(list(live.terminal_history)[-40:]): + c_str = chunk if isinstance(chunk, str) else chunk.get("output", str(chunk)) if isinstance(chunk, dict) else str(chunk) + chunks.insert(0, c_str) + total_len += len(c_str) + if total_len > 4000: break + + clean = ANSI_ESCAPE.sub('', "".join(chunks)) + if len(clean) > 2000: clean = "...[truncated]...\n" + clean[-2000:] + return f" Recent Terminal Output:\n ```\n {clean}\n ```\n" + + async def _process_event(self, db, session_id, event, state): + """Updates internal state and DB progress based on pipeline events.""" + e_type = event["type"] + if e_type == "content": state["answer"] += event["content"] + elif e_type == "reasoning": state["reasoning"] += event["content"] + elif e_type == "tool_start": + name = event.get("name") + if name: state["tool_counts"][name] = state["tool_counts"].get(name, {"calls":0, "successes":0, "failures":0}); 
state["tool_counts"][name]["calls"] += 1 + elif e_type == "tool_result": + name, res = event.get("name"), event.get("result") + if name and name in state["tool_counts"]: + if res and (not isinstance(res, dict) or res.get("success") is False): state["tool_counts"][name]["failures"] += 1 + else: state["tool_counts"][name]["successes"] += 1 + elif e_type == "token_counted": + u = event.get("usage", {}) + state["usage"]["input"] += u.get("prompt_tokens", 0); state["usage"]["output"] += u.get("completion_tokens", 0) + + # Persistent UI Observability: Commit assistant chunks occasionally + if e_type in ("content", "reasoning"): + await self._update_assistant_db(db, session_id, event, state) + + async def _update_assistant_db(self, db, session_id, event, state): + """Incrementally saves the assistant's response to the DB for real-time frontend visibility.""" + if not state["msg"]: + state["msg"] = models.Message(session_id=session_id, sender="assistant", content="") + db.add(state["msg"]) + await async_db_op(db.commit) + + if event["type"] == "content": state["msg"].content += event["content"] + elif event["type"] == "reasoning" and hasattr(state["msg"], "reasoning_content"): + state["msg"].reasoning_content = (state["msg"].reasoning_content or "") + event["content"] + + if (state["usage"]["input"] + state["usage"]["output"]) % 50 == 0: + try: await async_db_op(db.commit) + except: await async_db_op(db.rollback) + + async def _finalize_assistant_message(self, db, session_id, state) -> models.Message: + """Ensures the final assistant message is correctly persisted and closed.""" + msg = state["msg"] or models.Message(session_id=session_id, sender="assistant", content="") + msg.content = state["answer"] + if hasattr(msg, "reasoning_content"): msg.reasoning_content = state["reasoning"] + if not state["msg"]: db.add(msg) + await async_db_op(db.commit) + return msg def get_message_history(self, db: Session, session_id: int) -> List[models.Message]: - """ - Retrieves all 
messages for a given session, ordered by creation time. - """ - session = db.query(models.Session).options( - joinedload(models.Session.messages) - ).filter(models.Session.id == session_id).first() - - return sorted(session.messages, key=lambda msg: msg.created_at) if session else None \ No newline at end of file + """Retrieves and sorts the conversational history for a session.""" + session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first() + return sorted(session.messages, key=lambda m: m.created_at) if session else None \ No newline at end of file diff --git a/ai-hub/app/core/services/session.py b/ai-hub/app/core/services/session.py index b858953..41e288b 100644 --- a/ai-hub/app/core/services/session.py +++ b/ai-hub/app/core/services/session.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.exc import SQLAlchemyError +from typing import Dict, List, Optional, Any from app.db import models from app.api import schemas @@ -14,7 +15,8 @@ self.services = services def _mount_skills_to_workspace(self, db: Session, session: models.Session): - if not session.sync_workspace_id: return + """Standardizes workspace skill availability by linking local skill files.""" + if not session.sync_workspace_id or not self.services: return try: orchestrator = getattr(self.services, "orchestrator", None) tool_service = getattr(self.services, "tool_service", None) @@ -29,25 +31,26 @@ from app.core.skills.fs_loader import fs_loader from app.config import settings - all_fs_skills = fs_loader.get_all_skills() - for fs_skill in all_fs_skills: + for fs_skill in fs_loader.get_all_skills(): skill_name = fs_skill.get("name") if skill_name in valid_tool_names: - feature = fs_skill.get("features", ["chat"])[0] - skill_id = fs_skill.get("id", "").replace("fs-", "") - skill_path = os.path.join(settings.DATA_DIR, "skills", feature, skill_id) - link_path = 
os.path.join(skills_dir, skill_name) - - if os.path.exists(skill_path): - if not os.path.exists(link_path): - try: - os.symlink(skill_path, link_path, target_is_directory=True) - except OSError: - pass + self._create_skill_symlink(fs_skill, skills_dir, settings.DATA_DIR) except Exception as e: logger.error(f"Failed to mount skills to workspace: {e}") + def _create_skill_symlink(self, fs_skill: dict, skills_dir: str, data_dir: str): + """Internal helper to safely create skill symlinks.""" + skill_name = fs_skill.get("name") + feature = fs_skill.get("features", ["chat"])[0] + skill_id = fs_skill.get("id", "").replace("fs-", "") + skill_path = os.path.join(data_dir, "skills", feature, skill_id) + link_path = os.path.join(skills_dir, skill_name) + + if os.path.exists(skill_path) and not os.path.exists(link_path): + try: os.symlink(skill_path, link_path, target_is_directory=True) + except OSError: pass + def create_session( self, db: Session, @@ -77,9 +80,9 @@ raise def auto_attach_default_nodes(self, db: Session, session: models.Session, request: schemas.SessionCreate): + """Automatically attaches a user's default nodes to a new session.""" user = db.query(models.User).filter(models.User.id == request.user_id).first() - if not user: - return session + if not user: return session node_prefs = (user.preferences or {}).get("nodes", {}) default_nodes = node_prefs.get("default_node_ids", []) @@ -90,53 +93,22 @@ db.commit() db.refresh(session) - try: - if self.services and hasattr(self.services, "orchestrator") and self.services.orchestrator.mirror: - self.services.orchestrator.mirror.get_workspace_path(session.sync_workspace_id) - except Exception as mirror_err: - logger.error(f"Failed to pre-initialize server mirror: {mirror_err}") - if default_nodes: session.attached_node_ids = list(default_nodes) - session.node_sync_status = { - nid: {"status": "pending", "last_sync": None} - for nid in default_nodes - } + session.node_sync_status = {nid: {"status": "pending", 
"last_sync": None} for nid in default_nodes} db.commit() - db.refresh(session) - registry = getattr(self.services, "node_registry_service", None) orchestrator = getattr(self.services, "orchestrator", None) - - try: - assistant = orchestrator.assistant if orchestrator else None - source = node_config.get("source", "empty") - path = node_config.get("path", "") - + if orchestrator and orchestrator.assistant: for nid in default_nodes: - if registry: - try: - registry.emit(nid, "info", { - "message": f"Auto-attached to session {session.id}", - "workspace_id": session.sync_workspace_id, + try: + self._trigger_orchestrator_sync(orchestrator.assistant, nid, session.sync_workspace_id, node_config) + if self.services.node_registry_service: + self.services.node_registry_service.emit(nid, "info", { + "message": f"Auto-attached to session {session.id}", "workspace_id": session.sync_workspace_id, }) - except Exception: pass + except Exception as e: logger.error(f"Auto-attach sync failed for {nid}: {e}") - if assistant: - try: - if source == "server": - assistant.push_workspace(nid, session.sync_workspace_id) - elif source == "empty": - assistant.push_workspace(nid, session.sync_workspace_id) - assistant.control_sync(nid, session.sync_workspace_id, action="START") - assistant.control_sync(nid, session.sync_workspace_id, action="UNLOCK") - elif source == "node_local": - assistant.request_manifest(nid, session.sync_workspace_id, path=path or ".") - assistant.control_sync(nid, session.sync_workspace_id, action="START", path=path or ".") - except Exception as sync_err: - logger.error(f"Failed to trigger sync for node {nid}: {sync_err}") - except Exception as e: - logger.error(f"Failed to initialize orchestrator sync: {e}") self._mount_skills_to_workspace(db, session) return session @@ -188,47 +160,18 @@ else: try: assistant = orchestrator.assistant - config = request.config or schemas.NodeWorkspaceConfig(source="empty") - old_config = session.sync_config or {} - - strategy_changed = 
False - if old_config and (config.source != old_config.get("source") or \ - config.path != old_config.get("path") or \ - config.source_node_id != old_config.get("source_node_id")): - strategy_changed = True - session.sync_config = config.model_dump() db.commit() if strategy_changed: - for nid in old_node_ids: - assistant.clear_workspace(nid, session.sync_workspace_id) - if getattr(orchestrator, "mirror", None): - orchestrator.mirror.purge(session.sync_workspace_id) + for nid in old_node_ids: assistant.clear_workspace(nid, session.sync_workspace_id) + if getattr(orchestrator, "mirror", None): orchestrator.mirror.purge(session.sync_workspace_id) else: - for nid in detached_nodes: - assistant.clear_workspace(nid, session.sync_workspace_id) + for nid in detached_nodes: assistant.clear_workspace(nid, session.sync_workspace_id) for nid in request.node_ids: - if config.source == "server": - assistant.push_workspace(nid, session.sync_workspace_id) - assistant.control_sync(nid, session.sync_workspace_id, action="LOCK") - elif config.source == "empty": - assistant.push_workspace(nid, session.sync_workspace_id) - assistant.control_sync(nid, session.sync_workspace_id, action="START") - assistant.control_sync(nid, session.sync_workspace_id, action="UNLOCK") - elif config.source == "node_local": - if config.source_node_id == nid: - assistant.request_manifest(nid, session.sync_workspace_id, path=config.path or ".") - assistant.control_sync(nid, session.sync_workspace_id, action="START", path=config.path or ".") - assistant.control_sync(nid, session.sync_workspace_id, action="UNLOCK") - else: - assistant.control_sync(nid, session.sync_workspace_id, action="START") - assistant.control_sync(nid, session.sync_workspace_id, action="LOCK") - assistant.push_workspace(nid, session.sync_workspace_id) - - if config.read_only_node_ids and nid in config.read_only_node_ids: - assistant.control_sync(nid, session.sync_workspace_id, action="LOCK") + try: 
self._trigger_orchestrator_sync(assistant, nid, session.sync_workspace_id, session.sync_config) + except Exception as e: logger.error(f"Manual sync failed for {nid}: {e}") except Exception as e: logger.error(f"Failed to trigger session node sync: {e}") @@ -246,4 +189,117 @@ for nid in session.attached_node_ids ], sync_config=session.sync_config or {} - ) \ No newline at end of file + ) + + def get_token_usage(self, db: Session, session_id: int) -> schemas.SessionTokenUsageResponse: + """Centralized token counter with effective LLM model resolution.""" + session = db.query(models.Session).filter(models.Session.id == session_id).first() + if not session: raise HTTPException(status_code=404, detail="Session not found.") + + messages = self.services.rag_service.get_message_history(db=db, session_id=session_id) + combined_text = " ".join([m.content for m in messages]) + + # Resolve effective LLM model + user = session.user + if not user: + from app.config import settings + admin_email = settings.SUPER_ADMINS[0] if settings.SUPER_ADMINS else None + user = db.query(models.User).filter(models.User.email == admin_email).first() + + provider, provider_name = self.services.preference_service.resolve_llm_provider(db, user, session.provider_name, session.model_name) + + from app.core.providers.factory import get_model_limit + try: + # Fallback model name if provider object is complex + m_name = getattr(provider, "model", "") if hasattr(provider, "model") else (session.model_name or "") + token_limit = get_model_limit(provider_name, model_name=m_name) + except: + token_limit = 100000 + + from app.core.orchestration.validator import Validator + validator = Validator(token_limit=token_limit) + token_count = validator.get_token_count(combined_text) + percentage = round((token_count / token_limit) * 100, 2) + + return schemas.SessionTokenUsageResponse(token_count=token_count, token_limit=token_limit, percentage=percentage) + + def archive_session(self, db: Session, session_id: int): 
+ """Archives a session and purges associated node workspaces.""" + session = db.query(models.Session).filter(models.Session.id == session_id).first() + if not session: raise HTTPException(status_code=404, detail="Session not found.") + if session.is_locked: raise HTTPException(status_code=403, detail="Session is locked.") + + session.is_archived = True + wid = session.sync_workspace_id + db.commit() + if wid: self._broadcast_workspace_purge([wid]) + + def archive_all_feature_sessions(self, db: Session, user_id: str, feature_name: str) -> int: + """Archives all non-locked sessions for a specific feature and user.""" + sessions = db.query(models.Session).filter( + models.Session.user_id == user_id, models.Session.feature_name == feature_name, + models.Session.is_archived == False, models.Session.is_locked == False + ).all() + + wids = [s.sync_workspace_id for s in sessions if s.sync_workspace_id] + count = len(sessions) + for s in sessions: s.is_archived = True + db.commit() + + if wids: self._broadcast_workspace_purge(wids) + return count + + def _broadcast_workspace_purge(self, workspace_ids: List[str]): + """Helper to send PURGE commands to all active nodes and clean up Hub local mirror.""" + import shutil + from app.config import settings + from app.protos import agent_pb2 + + orchestrator = getattr(self.services, "orchestrator", None) + if not orchestrator: return + + live_nodes = orchestrator.registry.list_nodes() + for nid in live_nodes: + live = orchestrator.registry.get_node(nid) + if not live: continue + for wid in workspace_ids: + try: + live.send_message(agent_pb2.ServerTaskMessage( + file_sync=agent_pb2.FileSyncMessage( + session_id=wid, + control=agent_pb2.SyncControl(action=agent_pb2.SyncControl.PURGE) + ) + ), priority=0) + except: pass + + for wid in workspace_ids: + path = os.path.join(settings.DATA_DIR, "mirrors", wid) + if os.path.exists(path): + shutil.rmtree(path, ignore_errors=True) + + def _trigger_orchestrator_sync(self, assistant, nid, 
workspace_id, config): + """Unified sync dispatcher based on configured source strategy.""" + source = config.get("source", "empty") + path = config.get("path", ".") + source_nid = config.get("source_node_id") + read_only = (nid in config.get("read_only_node_ids", [])) if config.get("read_only_node_ids") else False + + if source == "server": + assistant.push_workspace(nid, workspace_id) + assistant.control_sync(nid, workspace_id, action="LOCK") + elif source == "empty": + assistant.push_workspace(nid, workspace_id) + assistant.control_sync(nid, workspace_id, action="START") + assistant.control_sync(nid, workspace_id, action="UNLOCK") + elif source == "node_local": + if source_nid == nid: + assistant.request_manifest(nid, workspace_id, path=path) + assistant.control_sync(nid, workspace_id, action="START", path=path) + assistant.control_sync(nid, workspace_id, action="UNLOCK") + else: + assistant.control_sync(nid, workspace_id, action="START") + assistant.control_sync(nid, workspace_id, action="LOCK") + assistant.push_workspace(nid, workspace_id) + + if read_only: + assistant.control_sync(nid, workspace_id, action="LOCK") \ No newline at end of file diff --git a/ai-hub/app/core/services/tool.py b/ai-hub/app/core/services/tool.py index cc05c92..40826cd 100644 --- a/ai-hub/app/core/services/tool.py +++ b/ai-hub/app/core/services/tool.py @@ -9,6 +9,15 @@ from app.core.tools.registry import tool_registry import time +from app.core._regex import ( + SKILL_CONFIG_JSON, SKILL_DESC_OVERRIDE, + SKILL_PARAM_TABLE, SKILL_BASH_LOGIC +) +import json +import yaml +import litellm +import shlex + logger = logging.getLogger(__name__) class ToolService: @@ -23,177 +32,119 @@ tool_registry.load_plugins() def get_available_tools(self, db: Session, user_id: str, feature: str = None, session_id: int = None) -> List[Dict[str, Any]]: - """ - Retrieves all tools the user is authorized to use, optionally filtered by feature. 
- """ - allowed_skill_names = None - if session_id and db: - session_obj = db.query(models.Session).filter(models.Session.id == session_id).first() - if session_obj and getattr(session_obj, "restrict_skills", False): - allowed_skill_names = set() - if session_obj.allowed_skill_names: - allowed_skill_names.update(session_obj.allowed_skill_names) - if getattr(session_obj, "skills", None): - allowed_skill_names.update(s.name for s in session_obj.skills) - - # 1. Fetch system/local skills and filter by feature if requested - local_skills = self._local_skills.values() - if feature: - local_skills = [s for s in local_skills if feature in getattr(s, "features", ["chat"])] + allowed_skill_names = self._get_allowed_skills(db, session_id) - tools = [s.to_tool_definition() for s in local_skills] + # 1. Local Skills + local_skills = [s.to_tool_definition() for s in self._local_skills.values() + if not feature or feature in getattr(s, "features", ["chat"])] if allowed_skill_names is not None: - tools = [t for t in tools if t["function"]["name"] in allowed_skill_names] + local_skills = [t for t in local_skills if t["function"]["name"] in allowed_skill_names] - # 2. Add FS-defined skills (System skills or user-owned) + # 2. 
VFS/FS Skills from app.core.skills.fs_loader import fs_loader - all_fs_skills = fs_loader.get_all_skills() - - class _DictObj: - def __init__(self, d): - for k, v in d.items(): - setattr(self, k, v) - db_skills = [] - for fs_skill in all_fs_skills: - if fs_skill.get("is_enabled", True) and (fs_skill.get("is_system") or fs_skill.get("owner_id") == user_id): - if feature and feature not in fs_skill.get("features", ["chat"]): - continue - # Map virtual files array to object arrays for the legacy parsing logic - fs_skill["files"] = [_DictObj(f) for f in fs_skill.get("files", [])] - db_skills.append(_DictObj(fs_skill)) - - if allowed_skill_names is not None: - db_skills = [ds for ds in db_skills if ds.name in allowed_skill_names] + max_md_len = self._resolve_model_max_len(db, user_id) - import litellm - max_md_len = 1000 + for fs_skill in fs_loader.get_all_skills(): + if not fs_skill.get("is_enabled", True): continue + if not (fs_skill.get("is_system") or fs_skill.get("owner_id") == user_id): continue + if feature and feature not in fs_skill.get("features", ["chat"]): continue + if allowed_skill_names is not None and fs_skill["name"] not in allowed_skill_names: continue + if any(t["function"]["name"] == fs_skill["name"] for t in local_skills): continue + + db_skills.append(self._parse_vfs_skill(fs_skill)) + + return local_skills + db_skills + + def _get_allowed_skills(self, db: Session, session_id: int) -> Optional[set]: + if not session_id or not db: return None + session_obj = db.query(models.Session).filter(models.Session.id == session_id).first() + if not session_obj or not getattr(session_obj, "restrict_skills", False): return None + + allowed = set() + if session_obj.allowed_skill_names: allowed.update(session_obj.allowed_skill_names) + if getattr(session_obj, "skills", None): allowed.update(s.name for s in session_obj.skills) + return allowed + + def _resolve_model_max_len(self, db: Session, user_id: str) -> int: + from app.config import settings + m_name = 
settings.ACTIVE_LLM_PROVIDER try: - # Attempt to resolve the active user's model configuration dynamically to get exact context sizes user = db.query(models.User).filter(models.User.id == user_id).first() if db else None - from app.config import settings - m_name = settings.ACTIVE_LLM_PROVIDER if user and user.preferences: - # User preference override m_name = user.preferences.get("llm_model", m_name) + if "/" not in m_name: + p = user.preferences.get("llm_provider", settings.ACTIVE_LLM_PROVIDER) + m_name = f"{p}/{m_name}" - # M6: Fix litellm mapping by ensuring provider prefix - if "/" not in m_name: - provider = user.preferences.get("llm_provider", settings.ACTIVE_LLM_PROVIDER) if user and user.preferences else settings.ACTIVE_LLM_PROVIDER - m_name = f"{provider}/{m_name}" - - model_info = litellm.get_model_info(m_name) - if model_info: - max_tokens = model_info.get("max_input_tokens", 8192) - # Cap a single skill's instruction block at 5% of the total context window to leave room - # for chat history and other plugins, with an absolute roof of 40k chars. 
(1 token ~= 4 chars) - max_md_len = max(min(int(max_tokens * 4 * 0.05), 40000), 1000) + info = litellm.get_model_info(m_name) + if info: + max_t = info.get("max_input_tokens", 8192) + return max(min(int(max_t * 4 * 0.05), 40000), 1000) except Exception as e: - logger.warning(f"Dynamic tool schema truncation failed to query model size: {e}") + logger.warning(f"Tool schema truncation fail: {e}") + return 1000 - for ds in db_skills: - # Prevent duplicates if name overlaps with local - if any(t["function"]["name"] == ds.name for t in tools): - continue + def _parse_vfs_skill(self, fs_skill: dict) -> dict: + name = fs_skill["name"] + description = fs_skill.get("description", "") + parameters = {"type": "object", "properties": {}, "required": []} + + # Binary/VFS normalization + class _Obj: + def __init__(self, d): + for k, v in d.items(): setattr(self, k, v) + files = [_Obj(f) for f in fs_skill.get("files", [])] + + skill_md = next((f for f in files if f.file_path == "SKILL.md"), None) + if skill_md and skill_md.content: + content = str(skill_md.content) + exec_file = next((f.file_path for f in files if f.file_path.endswith((".sh", ".py")) or "run." 
in f.file_path), "") + exec_cmd = f"bash .skills/{name}/{exec_file}" if exec_file.endswith(".sh") else f"python3 .skills/{name}/{exec_file}" if exec_file.endswith(".py") else f".skills/{name}/{exec_file}" - # --- Lazy-Loading VFS Pattern (Phase 3 - Skills as Folders) --- - # Extract parameters from SKILL.md frontmatter instead of legacy DB config column - description = ds.description or "" - parameters = {} + description += f"\n\n[Native VFS Skill - Execute via: `{exec_cmd}`]\n{content}" - skill_md_file = next((f for f in ds.files if f.file_path == "SKILL.md"), None) if ds.files else None + # YAML Frontmatter + if content.startswith("---"): + try: + parts = content.split("---", 2) + if len(parts) >= 3: + fm = yaml.safe_load(parts[1]) + parameters = fm.get("config", {}).get("parameters", parameters) + except: pass - if skill_md_file and skill_md_file.content: - skill_content_str = str(skill_md_file.content) - exec_file = "" - for f in ds.files: - if f.file_path.endswith(".sh") or f.file_path.endswith(".py") or "run." 
in f.file_path: - exec_file = f.file_path - break - exec_cmd = f"bash .skills/{ds.name}/{exec_file}" if exec_file.endswith(".sh") else f"python3 .skills/{ds.name}/{exec_file}" if exec_file.endswith(".py") else f".skills/{ds.name}/{exec_file}" + # Regex Parsers + if not parameters or not parameters.get("properties"): + mig_match = SKILL_CONFIG_JSON.search(content) + if mig_match: + try: parameters = json.loads(mig_match.group(1).strip()) + except: pass + + if not parameters or not parameters.get("properties"): + desc_match = SKILL_DESC_OVERRIDE.search(content) + if desc_match: + description = f"{desc_match.group(1).strip()}\n\n[Native VFS Skill - Execute via: `{exec_cmd}`]\n{content}" + + table_match = SKILL_PARAM_TABLE.search(content) + if table_match: + parameters = {"type": "object", "properties": {}, "required": []} + for row in table_match.group(1).strip().split('\n'): + cols = [c.strip() for c in row.split('|')][1:-1] + if len(cols) >= 4: + p_n = cols[0].replace('`', '').strip() + parameters["properties"][p_n] = {"type": cols[1].strip(), "description": cols[2].strip()} + if cols[3].strip().lower() in ['yes', 'true', '1', 'y']: + parameters["required"].append(p_n) - description += ( - f"\n\n[Native VFS Skill - Execute via: `{exec_cmd}`]\n" - f"{skill_content_str}" - ) - - # Parse YAML frontmatter to get the tool schema parameters - if skill_content_str.startswith("---"): - try: - import yaml - parts = skill_content_str.split("---", 2) - if len(parts) >= 3: - fm = yaml.safe_load(parts[1]) - parameters = fm.get("config", {}).get("parameters", {}) - except Exception as e: - logger.warning(f"Error parsing SKILL.md frontmatter for {ds.name}: {e}") - - # If no parameters found in frontmatter, try parsing markdown directly - if not parameters: - try: - import re - # Parse legacy migrated json configs - mig_match = re.search(r"### Tool Config JSON\s+```(?:yaml|json)\s+(.+?)\s+```", skill_content_str, re.DOTALL | re.IGNORECASE) - if mig_match: - try: - import json - 
parameters = json.loads(mig_match.group(1).strip()) - except: - pass - - if not parameters: - # Parse Description override (optional) - desc_match = re.search(r"\*\*Description:\*\*\s*(.*?)(?=\n\n|\n#|$)", skill_content_str, re.DOTALL | re.IGNORECASE) - if desc_match: - extracted_desc = desc_match.group(1).strip() - description = ( - f"{extracted_desc}\n\n[Native VFS Skill - Execute via: `{exec_cmd}`]\n" - f"{skill_content_str}" - ) - - # Parse Parameters Table - table_pattern = r"\|\s*Name\s*\|\s*Type\s*\|\s*Description\s*\|\s*Required\s*\|\n(?:\|[-:\s]+\|[-:\s]+\|[-:\s]+\|[-:\s]+\|\n)(.*?)(?=\n\n|\n#|$)" - param_table_match = re.search(table_pattern, skill_content_str, re.DOTALL | re.IGNORECASE) - if param_table_match: - parameters = {"type": "object", "properties": {}, "required": []} - rows = param_table_match.group(1).strip().split('\n') - for row in rows: - if not row.strip() or '|' not in row: continue - cols = [c.strip() for c in row.split('|')][1:-1] - if len(cols) >= 4: - p_name = cols[0].replace('`', '').strip() - p_type = cols[1].strip() - p_desc = cols[2].strip() - p_req = cols[3].strip().lower() in ['yes', 'true', '1', 'y'] - - parameters["properties"][p_name] = { - "type": p_type, - "description": p_desc - } - if p_req: - parameters["required"].append(p_name) - except Exception as e: - logger.warning(f"Error parsing SKILL.md markdown for {ds.name}: {e}") - - # Automatically inject logical node parameters into the schema for all tools - if not parameters: parameters = {"type": "object", "properties": {}, "required": []} + # Inject Node Selector + if "node_id" not in parameters.get("properties", {}): if "properties" not in parameters: parameters["properties"] = {} - if "node_id" not in parameters["properties"]: - parameters["properties"]["node_id"] = { - "type": "string", - "description": "Optional specific mesh node ID to execute this on. Leave empty to auto-use the session's first attached node. 
DO NOT invent or guess node IDs (e.g., node1), they must match the actual Node IDs in the context exactly." - } - - tools.append({ - "type": "function", - "function": { - "name": ds.name, - "description": description, - "parameters": parameters - } - }) + parameters["properties"]["node_id"] = { + "type": "string", + "description": "Mesh node ID for execution. Leave empty to use session default." + } - return tools + return {"type": "function", "function": {"name": name, "description": description, "parameters": parameters}} async def call_tool(self, tool_name: str, arguments: Dict[str, Any], db: Session = None, user_id: str = None, session_id: str = None, session_db_id: int = None, on_event = None, provider_name: str = None) -> Any: """ @@ -310,47 +261,9 @@ else: resolved_sid = session_id_arg - logger.info(f"[ToolService] Executing {skill.name} on {node_id or 'swarm'} (Resolved Session: {resolved_sid})") - - if db and user_id: - user = db.query(models.User).filter(models.User.id == user_id).first() - if user: - # Preference priority: - # 1. Passed provider (inherited from parent agent loop) - # 2. User preference - # 3. 
System default - if provider_name: - p_name = provider_name.split("/")[0] if "/" in provider_name else provider_name - actual_m_name = provider_name.split("/")[1] if "/" in provider_name else "" - else: - p_name = user.preferences.get("llm_provider", "gemini") - actual_m_name = user.preferences.get("llm_model", "") - - # Fetch provider-specific keys from user or system defaults - llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(p_name, {}) - user_service = getattr(self._services, "user_service", None) - - if (not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key"))) and user_service: - system_prefs = user_service.get_system_settings(db) - system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(p_name, {}) - if system_provider_prefs: - merged = system_provider_prefs.copy() - if llm_prefs: merged.update({k: v for k, v in llm_prefs.items() if v}) - llm_prefs = merged - - api_key_override = llm_prefs.get("api_key") - # actual_m_name is already set from provider_name or preferences above - # Fallback to provider default if still empty - if not actual_m_name: - actual_m_name = llm_prefs.get("model", "") - - kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} - - try: - llm_provider = get_llm_provider(p_name, model_name=actual_m_name, api_key_override=api_key_override, **kwargs) - logger.info(f"[ToolService] AI Sub-Agent enabled using {p_name}/{actual_m_name}") - except Exception as e: - logger.warning(f"[ToolService] Could not init LLM for sub-agent: {e}") + llm_provider = self._resolve_llm_for_sub_agent(db, user_id, provider_name) + if llm_provider: + logger.info(f"[ToolService] AI Sub-Agent enabled for {skill.name}") # Define the task function and arguments for the SubAgent task_fn = None @@ -529,3 +442,10 @@ logger.warning(f"Failed to persist browser data to workspace: {sse}") return res + def _resolve_llm_for_sub_agent(self, db: Session, user_id: str, provider_name: str) 
-> Optional[Any]: + if not db or not user_id or not self._services: return None + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user: return None + + provider, _ = self._services.preference_service.resolve_llm_provider(db, user, provider_name) + return provider diff --git a/ai-hub/integration_tests/conftest.py b/ai-hub/integration_tests/conftest.py index a2ef1ce..af1fd44 100644 --- a/ai-hub/integration_tests/conftest.py +++ b/ai-hub/integration_tests/conftest.py @@ -5,8 +5,19 @@ import httpx from datetime import datetime +from dotenv import load_dotenv + +# Try to find .env in current or parent dirs +env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".env")) +if os.path.exists(env_path): + load_dotenv(dotenv_path=env_path) +else: + load_dotenv() # Fallback + BASE_URL = os.getenv("SYNC_TEST_BASE_URL", "http://127.0.0.1:8002/api/v1") -ADMIN_EMAIL = os.getenv("SUPER_ADMINS", "axieyangb@gmail.com").split(',')[0] +# Primary admin for integration tests: always try the FIRST one in the list +_admins = os.getenv("SUPER_ADMINS", "axieyangb@gmail.com").split(',') +ADMIN_EMAIL = _admins[0].strip() ADMIN_PASSWORD = os.getenv("CORTEX_ADMIN_PASSWORD", "admin") NODE_1 = os.getenv("SYNC_TEST_NODE1", "test-node-1") NODE_2 = os.getenv("SYNC_TEST_NODE2", "test-node-2") @@ -22,7 +33,7 @@ 5. Spin up node docker containers with correct tokens. """ print("\n[conftest] Starting Mesh Integration Setup...") - client = httpx.Client(timeout=30.0) + client = httpx.Client(timeout=90.0) # 1. 
Login print(f"[conftest] Logging in as {ADMIN_EMAIL}...") diff --git a/ai-hub/integration_tests/test_node_purge.py b/ai-hub/integration_tests/test_node_purge.py new file mode 100644 index 0000000..0f90d53 --- /dev/null +++ b/ai-hub/integration_tests/test_node_purge.py @@ -0,0 +1,59 @@ +import os +import httpx +import pytest +import uuid + +BASE_URL = os.getenv("SYNC_TEST_BASE_URL", "http://127.0.0.1:8002/api/v1") + +def _get_user_id() -> str: + return os.getenv("SYNC_TEST_USER_ID", "c4401d34-8784-4d6e-93a0-c702bd202b66") + +def _headers(): + return {"X-User-ID": _get_user_id(), "Authorization": os.getenv("SYNC_TEST_AUTH_TOKEN", "")} + +def test_node_self_purge_logic(): + """ + Verifies that a node can deregister itself using its invite_token via the /purge endpoint. + """ + user_id = _get_user_id() + node_id = f"purge-test-{uuid.uuid4().hex[:8]}" + + payload = { + "node_id": node_id, + "display_name": "Purge Test Node", + "is_active": True, + "skill_config": {"shell": {"enabled": True}} + } + + with httpx.Client(timeout=15.0) as client: + # 1. Register the node + r_create = client.post(f"{BASE_URL}/nodes/admin", params={"admin_id": user_id}, json=payload, headers=_headers()) + assert r_create.status_code == 200 + node_data = r_create.json() + invite_token = node_data["invite_token"] + + # Verify it exists + r_check = client.get(f"{BASE_URL}/nodes/admin/{node_id}", params={"admin_id": user_id}, headers=_headers()) + assert r_check.status_code == 200 + + # 2. Attempt purge with INVALID token + r_fail = client.post(f"{BASE_URL}/nodes/purge", params={"node_id": node_id, "token": "wrong-token"}) + assert r_fail.status_code == 401 + + # 3. Attempt purge with VALID token + r_purge = client.post(f"{BASE_URL}/nodes/purge", params={"node_id": node_id, "token": invite_token}) + assert r_purge.status_code == 200 + assert "deregistered" in r_purge.json()["message"] + + # 4. 
Verify the node is GONE from the Hub + r_gone = client.get(f"{BASE_URL}/nodes/admin/{node_id}", params={"admin_id": user_id}, headers=_headers()) + assert r_gone.status_code == 404 + + # 5. Verify it's gone from the mesh registry (indirectly) + r_list = client.get(f"{BASE_URL}/nodes/admin", params={"admin_id": user_id}, headers=_headers()) + assert not any(n["node_id"] == node_id for n in r_list.json()) + +def test_purge_non_existent_node(): + with httpx.Client() as client: + r = client.post(f"{BASE_URL}/nodes/purge", params={"node_id": "ghost-node", "token": "any"}) + assert r.status_code == 401 diff --git a/frontend/src/features/nodes/pages/NodesPage.js b/frontend/src/features/nodes/pages/NodesPage.js index 52b9527..1d53d60 100644 --- a/frontend/src/features/nodes/pages/NodesPage.js +++ b/frontend/src/features/nodes/pages/NodesPage.js @@ -14,6 +14,7 @@ const [error, setError] = useState(null); const [showCreateModal, setShowCreateModal] = useState(false); const [nodeToDelete, setNodeToDelete] = useState(null); + const [attemptPurge, setAttemptPurge] = useState(false); const [newNode, setNewNode] = useState({ node_id: '', display_name: '', description: '', skill_config: { shell: { enabled: true }, sync: { enabled: true } } }); const [expandedTerminals, setExpandedTerminals] = useState({}); // node_id -> boolean const [expandedNodes, setExpandedNodes] = useState({}); // node_id -> boolean @@ -124,8 +125,25 @@ const confirmDeleteNode = async () => { if (!nodeToDelete) return; try { + if (attemptPurge) { + // To perform a remote purge, we dispatch the purge command. + console.log(`[๐Ÿš€] Dispatching purge command to ${nodeToDelete}...`); + + const nodeObj = nodes.find(n => n.node_id === nodeToDelete); + const caps = nodeObj?.capabilities || {}; + const isWindows = caps.os === 'windows'; + const cmd = isWindows ? 
'python purge.py' : 'python3 purge.py'; + + // We use the dispatch API to send the purge command + // This is a "fire and forget" action because the node will deregister itself. + import('../../../services/apiService').then(api => { + api.dispatchNodeTask(nodeToDelete, { command: cmd, user_id: user.id }); + }).catch(err => console.error("Failed to dispatch purge:", err)); + } + await adminDeleteNode(nodeToDelete); setNodeToDelete(null); + setAttemptPurge(false); fetchData(); } catch (err) { alert(err.message); @@ -885,6 +903,27 @@

Are you sure you want to completely deregister node {nodeToDelete}? This will permanently remove all access grants for this node.

+ + {/* Purge Option */} +
+ +