# cortex-hub / ai-hub / app / core / orchestration / architect.py
import logging
import queue
import time
from typing import List, Dict, Any, Optional

from app.db import models
from .memory import ContextManager
from .stream import StreamProcessor
from .body import ToolExecutor
from .guards import SafetyGuard

class Architect:
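    """
    The Master-Architect Orchestrator.

    Decomposed successor to RagPipeline; 100% regex-free. Drives a turn-based
    LLM loop and delegates to ContextManager (message preparation and history
    compression), StreamProcessor (stream chunk handling), ToolExecutor (tool
    dispatch) and SafetyGuard (cancellation, loop and watchdog checks).
    """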

    def __init__(self, context_manager: Optional[ContextManager] = None):
        self.memory = context_manager or ContextManager()
        self.stream = None # Created during run()

    async def run(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: List[models.Message],
        llm_provider,
        prompt_service = None,
        tool_service = None,
        tools: List[Dict[str, Any]] = None,
        mesh_context: str = "",
        db = None,
        user_id: Optional[str] = None,
        sync_workspace_id: Optional[str] = None,
        session_id: Optional[int] = None,
        feature_name: str = "chat",
        prompt_slug: str = "rag-pipeline",
        session_override: Optional[str] = None
    ):
        """Dispatches an autonomous orchestration loop with turn-based strategy."""
        messages = self.memory.prepare_initial_messages(
            question, context_chunks, history, feature_name, mesh_context, sync_workspace_id,
            db=db, user_id=user_id, prompt_service=prompt_service, prompt_slug=prompt_slug,
            tools=tools, session_override=session_override
        )
        
        logging.info(f"[Architect] Starting loop. Prompt Size: {sum(len(m.get('content','') or '') for m in messages)} chars.")

        mesh_bridge = queue.Queue()
        registry = self._get_registry(tool_service)
        if registry and user_id: registry.subscribe_user(user_id, mesh_bridge)

        safety = SafetyGuard(db, session_id)
        from .profiles import get_profile
        profile = get_profile(feature_name)
        self.stream = StreamProcessor(profile=profile)
        turn = 0
        start_time = time.time()
        try:
            while turn < profile.autonomous_limit:
                turn += 1
                self.stream.reset_turn()

                if safety.check_cancellation():
                    yield {"type": "reasoning", "content": "\n> **🛑 User Interruption:** Terminating loop.\n"}
                    break
                
                messages = await self.memory.compress_history(messages, llm_provider)
                self._update_turn_marker(messages, turn)
                if profile.show_heartbeat: yield {"type": "status", "content": f"Turn {turn}: architecting next step"}

                prediction = await self._call_llm(llm_provider, messages, tools)
                if not prediction: break

                # A. Handle Stream Turn
                content, reasoning, tc_map, finish_reason = "", "", {}, None
                async for event in self._process_llm_stream(prediction, turn, profile):
                    e_type = event.get("type")
                    if e_type == "content": content += event["content"]
                    elif e_type == "reasoning": reasoning += event["content"]
                    elif e_type == "tool_calls_detected": tc_map.update(event["map"])
                    elif e_type == "finish_reason": finish_reason = event["reason"]
                    yield event

                # B. Decision Branch: Tools or Exit?
                if not tc_map:
                    events = []
                    should_continue = await self._handle_no_tools_branch(finish_reason, content, reasoning, profile, safety, tool_service, sync_workspace_id, messages, events)
                    for e in events: yield e
                    if should_continue:
                        continue # Watchdog or continuation triggered
                    break # Natural exit

                # C. Execute Tools
                processed_tc = list(tc_map.values())
                if safety.detect_loop(processed_tc):
                    yield {"type": "reasoning", "content": "\n> **🚨 Loop Guard:** Loop detected.\n"}
                    messages.append({"role": "user", "content": "STUCK: Change strategy."})
                    continue

                yield {"type": "status", "content": f"Dispatching {len(processed_tc)} tools"}
                messages.append(self._format_assistant_msg(content, reasoning, processed_tc))

                executor = ToolExecutor(tool_service, user_id, db, sync_workspace_id, session_id, provider_name=getattr(llm_provider, "provider_name", None))
                async for event in executor.run_tools(processed_tc, safety, mesh_bridge):
                    if "role" in event: messages.append(event)
                    else: yield event
            
            elapsed = time.time() - start_time
            if turn >= profile.autonomous_limit:
                yield {"type": "status", "content": f"Autonomous limit reached after {elapsed:.1f}s. Please provide more instructions if needed."}
            else:
                yield {"type": "status", "content": f"Task complete in {elapsed:.1f}s"}

        except Exception as e:
            import traceback
            logging.error(f"[Architect] CRITICAL FAULT:\n{traceback.format_exc()}")
            yield {"type": "status", "content": "Fatal Orchestration Error"}
            yield {"type": "content", "content": f"\n\n> **🚨 Core Orchestrator Fault:** `{str(e)}`"}
        finally:
            if registry and user_id: registry.unsubscribe_user(user_id, mesh_bridge)
            # --- M7: Automatic Terminal Task Cancellation ---
            if tool_service and hasattr(tool_service, "_services") and sync_workspace_id:
                try: 
                    mesh = getattr(tool_service._services, "mesh_service", None)
                    if mesh: mesh.cancel_session(sync_workspace_id, user_id, db)
                except: pass

    async def _process_llm_stream(self, prediction, turn, profile):
        """Internal helper for processing raw LLM stream into architectural events."""
        tc_map, finish_reason = {}, None
        async for chunk in prediction:
            if getattr(chunk, "usage", None):
                yield {"type": "token_counted", "usage": getattr(chunk, "usage").model_dump() if hasattr(getattr(chunk, "usage"), "model_dump") else getattr(chunk, "usage")}
            
            if not chunk.choices: continue
            delta = chunk.choices[0].delta
            finish_reason = getattr(chunk.choices[0], "finish_reason", None) or chunk.choices[0].get("finish_reason")
            
            # Native reasoning (O-series)
            r = getattr(delta, "reasoning_content", None) or delta.get("reasoning_content")
            if r: yield {"type": "reasoning", "content": r}
            
            # Content & Thinking Tags
            c = getattr(delta, "content", None) or delta.get("content")
            if c:
                async for event in self.stream.process_chunk(c, turn):
                    if not (profile.buffer_content and event["type"] == "content"): yield event
            
            self._accumulate_tool_calls(delta, tc_map)

        async for event in self.stream.end_stream(turn):
             if not (profile.buffer_content and event["type"] == "content"): yield event

        # Standardize tool calls for JSON serialization in the event stream
        serializable_tc_map = {}
        for idx, tc in tc_map.items():
            if hasattr(tc, "model_dump"):
                serializable_tc_map[idx] = tc.model_dump()
            else:
                # Manual fallback for non-Pydantic objects
                serializable_tc_map[idx] = {
                    "id": getattr(tc, "id", None),
                    "type": "function",
                    "function": {
                        "name": getattr(tc.function, "name", ""),
                        "arguments": getattr(tc.function, "arguments", "")
                    }
                }

        yield {"type": "tool_calls_detected", "map": serializable_tc_map}
        yield {"type": "finish_reason", "reason": finish_reason}

    async def _handle_no_tools_branch(self, finish_reason, content, reasoning, profile, safety, tool_svc, ws_id, messages, events_out: list) -> bool:
        """Determines if a no-tool turn should exit or trigger a continuation/watchdog."""
        if finish_reason == "length":
            messages.append({"role": "user", "content": "You were cut off. Please continue."})
            return True

        if not content.strip():
            fallback = self.stream._apply_turn_header(reasoning.strip()) if reasoning.strip() else ""
            if not fallback.strip(): fallback = "Task complete. Check thought trace for details."
            events_out.append({"type": "content", "content": fallback})
        elif profile.buffer_content:
            events_out.append({"type": "content", "content": self.stream._apply_turn_header(content).strip()})

        if safety.should_activate_watchdog(self._get_assistant(tool_svc), ws_id):
            messages.append({"role": "user", "content": "WATCHDOG: .ai_todo.md is not empty. Proceed."})
            return True
        
        return False


    # --- Internal Helpers ---

    def _get_registry(self, tool_service):
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "registry", None)
        return None

    def _get_assistant(self, tool_service):
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "assistant", None)
        return None

    def _update_turn_marker(self, messages, turn):
        if messages[0]["role"] == "system":
            content = messages[0]["content"]
            marker_anchor = "[System: Current Turn:"
            if marker_anchor in content:
                 messages[0]["content"] = content.split(marker_anchor)[0].strip() + f"\n\n{marker_anchor} {turn}]"
            else:
                 messages[0]["content"] = content + f"\n\n{marker_anchor} {turn}]"

    async def _call_llm(self, llm_provider, messages, tools):
        kwargs = {"stream": True, "stream_options": {"include_usage": True}, "max_tokens": 4096}
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        try:
            return await llm_provider.acompletion(messages=messages, timeout=60, **kwargs)
        except Exception as e:
            logging.error(f"[Architect] LLM Exception: {e}")
            return None

    def _accumulate_tool_calls(self, delta, t_map):
        tc_deltas = getattr(delta, "tool_calls", None) or delta.get("tool_calls")
        if not tc_deltas: return
        for tcd in tc_deltas:
            idx = tcd.index
            if idx not in t_map:
                t_map[idx] = tcd
            else:
                if getattr(tcd, "id", None): t_map[idx].id = tcd.id
                if tcd.function.name: t_map[idx].function.name = tcd.function.name
                if tcd.function.arguments: t_map[idx].function.arguments += tcd.function.arguments

    def _format_assistant_msg(self, content, reasoning, tool_calls):
        if content:
             # Fast string cleaning instead of regex for assistant message formatting
             content = self.stream._apply_turn_header(content).strip()

        clean_tc = []
        for tc in tool_calls:
            # Handle both object and dict access (Migration to Serializable Swarm)
            tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
            tc_func = tc.get("function", {}) if isinstance(tc, dict) else getattr(tc, "function", None)
            func_name = tc_func.get("name") if isinstance(tc_func, dict) else getattr(tc_func, "name", "")
            func_args = tc_func.get("arguments") if isinstance(tc_func, dict) else getattr(tc_func, "arguments", "")
            
            clean_tc.append({
                "id": tc_id, "type": "function",
                "function": {"name": func_name, "arguments": func_args}
            })
        msg = {"role": "assistant", "content": content or None, "tool_calls": clean_tc}
        if reasoning: msg["reasoning_content"] = reasoning
        return msg