# cortex-hub / ai-hub / app / core / services / sub_agent.py
import asyncio
import time
import logging
import json
from app.protos import agent_pb2

logger = logging.getLogger(__name__)

class SubAgent:
    """
    A stateful watcher for a specific task on an agent node.
    Handles execution, result accumulation, and state monitoring.

    When an LLM provider, an assistant facade, and a sub-agent system prompt
    are all supplied, `run()` uses an AI-powered "Observe-Think-Act" loop;
    otherwise it falls back to plain blocking execution with retries.
    """
    def __init__(self, name: str, task_fn, args: dict, retries: int = 1, 
                 llm_provider=None, assistant=None, subagent_system_prompt: str = None,
                 on_event=None):
        """
        Args:
            name: Human-readable identifier used in log messages.
            task_fn: Blocking callable executed via ``asyncio.to_thread``,
                invoked with ``args`` as keyword arguments.
            args: Keyword arguments for ``task_fn`` (``node_id``, ``command``
                and ``timeout`` are also read by ``_run_ai_powered``).
            retries: Number of extra attempts after the first in standard mode.
            llm_provider: Object exposing async ``acompletion(...)``; required
                (with ``assistant`` and ``subagent_system_prompt``) for the
                AI-monitoring path.
            assistant: Facade exposing ``wait_for_swarm``/``dispatch_single``/
                ``dispatch_swarm`` plus ``registry`` and ``journal``, used by
                the AI loop.
            subagent_system_prompt: System prompt injected into each LLM
                analysis call.
            on_event: Optional async callback receiving "subagent_thought"
                event dicts for the UI.
        """
        self.name = name
        self.task_fn = task_fn
        self.args = args
        self.retries = retries
        self.llm = llm_provider
        self.assistant = assistant
        self.subagent_system_prompt = subagent_system_prompt
        self.on_event = on_event
        self.status = "PENDING"   # PENDING -> RUNNING -> COMPLETED/FAILED/ABORTED (plus transient RETRYING states)
        self.result = None        # last value returned by task_fn (dict or swarm map)
        self.start_time = None    # epoch seconds when run() started
        self.end_time = None      # epoch seconds when execution finished
        self.error = None         # last exception message from standard-mode execution
        self.task_id = None       # NOTE(review): never assigned inside this class — confirm external writers

    async def run(self):
        self.start_time = time.time()
        self.status = "RUNNING"
        
        # If AI-monitoring is disabled or not enough context, fallback to standard execution
        if not self.llm or not self.assistant or not self.subagent_system_prompt:
            return await self._run_standard()
        
        return await self._run_ai_powered()

    async def _run_standard(self):
        """Legacy blocking execution with simple retry logic."""
        for attempt in range(self.retries + 1):
            try:
                self.result = await asyncio.to_thread(self.task_fn, **self.args)
                if isinstance(self.result, dict) and self.result.get("error"):
                    err_msg = str(self.result.get("error")).lower()
                    is_busy = "busy" in err_msg
                    if is_busy or any(x in err_msg for x in ["timeout", "offline", "disconnected", "capacity", "rejected"]):
                        if attempt < self.retries:
                            backoff = (attempt + 1) * 3
                            self.status = f"RETRYING ({attempt+1}/{self.retries})"
                            await asyncio.sleep(backoff)
                            continue
                self.status = "COMPLETED"
                break
            except Exception as e:
                logger.error(f"SubAgent {self.name} execution error: {e}")
                self.error = str(e)
                if attempt < self.retries:
                    self.status = f"ERROR_RETRYING ({attempt+1}/{self.retries})"
                    await asyncio.sleep(2)
                else:
                    self.status = "FAILED"
        self.end_time = time.time()
        return self.result

    async def _run_ai_powered(self):
        """AI-powered 'Observe-Think-Act' loop for per-node task management."""
        logger.info(f"[🤖 SubAgent] Starting AI-powered monitoring for {self.name}")
        
        # 1. Initiate task with no_abort=True and respect requested timeout for init
        requested_timeout = int(self.args.get("timeout", 5))
        init_timeout = min(5, requested_timeout) if requested_timeout > 0 else 5
        init_args = {**self.args, "no_abort": True, "timeout": init_timeout}
        node_id = init_args.get("node_id") or "swarm"
        
        try:
            # Emit a 'start' thought immediately for every dispatch
            start_msg = f"🚀 Dispatching: `{self.args.get('command', '?')}` on `{node_id}` (init_timeout={init_timeout}s)"
            logger.info(f"    [🤖 SubAgent] {start_msg}")
            if self.on_event:
                await self.on_event({"type": "subagent_thought", "node_id": node_id, "content": start_msg})

            res = await asyncio.to_thread(self.task_fn, **init_args)
            
            # Swarm handling (might return map of node_id -> result)
            task_map = {}
            if "task_id" in res:
                task_map = {node_id: res["task_id"]}
            elif isinstance(res, dict) and not any(k in res for k in ["stdout", "error"]):
                # Looks like a swarm map
                task_map = {nid: r.get("task_id") for nid, r in res.items() if r.get("status") == "TIMEOUT_PENDING"}
            
            if not task_map:
                # Task completed immediately — emit a completion thought
                status_icon = "✅" if not (isinstance(res, dict) and res.get("error")) else "❌"
                stdout_preview = ""
                if isinstance(res, dict):
                    raw = res.get("stdout") or res.get("error") or ""
                    stdout_preview = raw.strip()[-300:] if len(raw.strip()) > 300 else raw.strip()
                done_msg = f"{status_icon} Quick-complete on `{node_id}`. Output preview: `{stdout_preview}`"
                if self.on_event:
                    await self.on_event({"type": "subagent_thought", "node_id": node_id, "content": done_msg})
                logger.info(f"[🤖 SubAgent] Task finished immediately or failed.")
                self.status = "COMPLETED"
                self.result = res
                self.end_time = time.time()
                return res

            # 2. Intelligence Loop
            max_loops = 50
            for loop in range(max_loops):
                # A. FAST-PATH HEURISTIC: Check for prompts before sleeping/AI analysis
                from app.core.grpc.services.assistant import TaskAssistant
                peek = await asyncio.to_thread(self.assistant.wait_for_swarm, task_map, timeout=0, no_abort=True)
                
                heuristic_action = self._check_heuristics(peek)
                if heuristic_action == "FINISH":
                    fast_path_reason = "Prompt detected - Finishing task."
                    logger.info(f"    [⚡ Fast-Path] {fast_path_reason} {self.name}")
                    
                    # Emit to UI
                    for nid, tid in task_map.items():
                        self.assistant.registry.emit(nid, "subagent_thought", fast_path_reason)
                        if tid:
                            self.assistant.journal.add_thought(tid, fast_path_reason)

                    self.result = await asyncio.to_thread(self.assistant.wait_for_swarm, task_map, timeout=2)
                    self.status = "COMPLETED"
                    break

                # B. AI Analysis Loop
                # Analyze with AI
                analysis = await self._analyze_progress(peek)
                action = analysis.get("action", "WAIT")
                reason = analysis.get("reason", "")
                
                # C. Smart Wait: AI determines how long to wait before next tick
                # Default to 5s if not specified, range 1s to 60s
                wait_time = analysis.get("next_wait", 5)
                try:
                    wait_time = max(1, min(60, int(wait_time)))
                except:
                    wait_time = 5

                logger.info(f"    [🔍 AI] Loop {loop}: Action={action} | Wait={wait_time}s | {reason}")
                
                # Emit thinking process and record in journal
                for nid, tid in task_map.items():
                    msg = f"{reason} (Next check in {wait_time}s)"
                    self.assistant.registry.emit(nid, "subagent_thought", msg)
                    if self.on_event:
                        await self.on_event({"type": "subagent_thought", "node_id": nid, "content": msg})
                    if tid:
                        self.assistant.journal.add_thought(tid, reason)

                if action == "FINISH" or all(r.get("status") not in ["RUNNING", "TIMEOUT_PENDING"] for r in peek.values()):
                    # One last blocking wait to gather final result if needed
                    self.result = await asyncio.to_thread(self.assistant.wait_for_swarm, task_map, timeout=10)
                    self.status = "COMPLETED"
                    # Emit completion summary
                    for nid in task_map:
                        res_preview = ""
                        if isinstance(self.result, dict):
                            node_res = self.result.get(nid, self.result)
                            raw = (node_res.get("stdout") or node_res.get("error") or "") if isinstance(node_res, dict) else str(node_res)
                            res_preview = raw.strip()[-400:] if len(raw.strip()) > 400 else raw.strip()
                        elapsed = int(time.time() - self.start_time)
                        done_msg = f"✅ Task complete on `{nid}` in {elapsed}s. Output:\n```\n{res_preview}\n```"
                        if self.on_event:
                            await self.on_event({"type": "subagent_thought", "node_id": nid, "content": done_msg})
                    break
                
                if action == "EXECUTE":
                    cmd = analysis.get("command")
                    target_nid = analysis.get("node_id") or analysis.get("node_ids")
                    if cmd and target_nid:
                        exec_reason = f"Branching Execution: Running '{cmd}' on {target_nid}"
                        logger.info(f"    [🚀 Branch] {exec_reason}")
                        
                        # Emit branch thinking
                        for nid in (target_nid if isinstance(target_nid, list) else [target_nid]):
                            self.assistant.registry.emit(nid, "subagent_thought", exec_reason)
                            if self.on_event:
                                await self.on_event({"type": "subagent_thought", "node_id": nid, "content": exec_reason})

                        # Dispatch new tasks
                        if isinstance(target_nid, list):
                            new_res = await asyncio.to_thread(self.assistant.dispatch_swarm, target_nid, cmd, no_abort=True)
                            for nid, r in new_res.items():
                                if r.get("task_id"):
                                    task_map[nid] = r["task_id"]
                        else:
                            new_res = await asyncio.to_thread(self.assistant.dispatch_single, target_nid, cmd, no_abort=True)
                            if new_res.get("task_id") and not cmd.startswith("!RAW:"):
                                task_map[target_nid] = new_res["task_id"]
                    
                    # Continue monitoring all tasks (old + new)
                    continue

                if action == "ABORT":
                    # Kill tasks
                    for nid, tid in task_map.items():
                        await asyncio.to_thread(self.assistant.registry.get_node(nid).queue.put, 
                                             agent_pb2.ServerTaskMessage(task_cancel=agent_pb2.TaskCancelRequest(task_id=tid)))
                    self.status = "ABORTED"
                    self.result = {"error": "AI aborted task", "reason": analysis.get("reason")}
                    break
                
                # Dynamic sleep based on AI recommendation OR Edge Signal
                await self._edge_aware_sleep(task_map, wait_time)
            
            self.status = "COMPLETED"
            self.end_time = time.time()
            return self.result

        except Exception as e:
            logger.exception("[🤖 SubAgent] AI Intelligence Loop Crashed")
            return await self._run_standard()

    async def _edge_aware_sleep(self, task_map, timeout):
        """Wait for timeout OR until any node in task_map signals a prompt."""
        # Find all prompt events for our tasks
        events = []
        with self.assistant.journal.lock:
            for tid in task_map.values():
                if tid in self.assistant.journal.tasks:
                    events.append(self.assistant.journal.tasks[tid]["prompt_event"])
        
        if not events:
            await asyncio.sleep(timeout)
            return

        def waiter():
            # Wait for ANY of the events to be set, or timeout
            start = time.time()
            while time.time() - start < timeout:
                for ev in events:
                    if ev.is_set():
                        return True
                time.sleep(0.1)
            return False

        # Run the multi-event wait in a thread to keep Hub event loop free
        await asyncio.to_thread(waiter)

    def _check_heuristics(self, peek_results: dict) -> str:
        """Detects common shell/REPL prompts in stdout to trigger early finish.

        Returns "FINISH" when every still-running task's output tail looks like
        an idle prompt (bash/zsh, python, node), otherwise "WAIT".
        """
        import re
        # Prompt signatures, matched against only the last ~100 chars of
        # stdout to keep the scan cheap.
        prompt_patterns = [re.compile(p) for p in (
            r"[\r\n].*[@\w\.\-]+:.*[#$]\s*$",  # bash/zsh: user@host:~$
            r">>>\s*$",                        # python
            r"\.\.\.\s*$",                      # python multi-line
            r">\s*$",                           # node/js
        )]

        if not peek_results:
            return "WAIT"

        for node_result in peek_results.values():
            # Tasks the node already reported as finished count as ready.
            if node_result.get("status") not in ("RUNNING", "TIMEOUT_PENDING"):
                continue
            tail = node_result.get("stdout", "")[-100:]
            if not any(rx.search(tail) for rx in prompt_patterns):
                return "WAIT"  # at least one task is still mid-output
        return "FINISH"

    async def _analyze_progress(self, peek_results):
        """Calls LLM to analyze the live stream and decide next move."""
        try:
            prompt = (
                f"SYSTEM PROMPT: {self.subagent_system_prompt}\n\n"
                f"CURRENT TERMINAL STATE (SWARM):\n"
                f"{json.dumps(peek_results, indent=2)}\n\n"
                "INSTRUCTIONS: Analyze the status. Respond ONLY with a JSON object:\n"
                "{\n"
                "  \"action\": \"WAIT\" | \"FINISH\" | \"ABORT\" | \"EXECUTE\",\n"
                "  \"reason\": \"...\",\n"
                "  \"next_wait\": <int_seconds_to_wait_before_next_analysis>,\n"
                "  \"command\": \"<optional_new_command_to_run_if_action_is_EXECUTE>\",\n"
                "  \"node_id\": \"<node_to_run_command_on>\"\n"
                "}\n"
                "Tip: Use 'EXECUTE' for branching agency (e.g. if node 1 is ready, run a new command on node 2). "
                "For long tasks, set next_wait to 10-20. For quick ticks, use 3-5."
            )
            response = await self.llm.acompletion(prompt=prompt, response_format={"type": "json_object"})
            return json.loads(response.choices[0].message.content)
        except:
            return {"action": "WAIT", "reason": "AI analysis unavailable, continuing default wait.", "next_wait": 5}

    def get_elapsed(self) -> int:
        """Return whole seconds between start and end (or now if still running)."""
        if not self.start_time:
            return 0
        finish = self.end_time or time.time()
        return int(finish - self.start_time)