# cortex-hub / ai-hub / app / core / orchestration / architect.py
import logging
import queue
from typing import List, Dict, Any, Optional
from app.db import models
from .memory import ContextManager
from .stream import StreamProcessor
from .body import ToolExecutor
from .guards import SafetyGuard

class Architect:
    """
    The Master-Architect Orchestrator.

    Decomposed successor to RagPipeline. Drives an autonomous
    reason -> tool-call -> observe loop around a streaming LLM provider,
    delegating context handling to ContextManager, stream parsing to
    StreamProcessor, tool dispatch to ToolExecutor, and cancellation /
    loop / watchdog checks to SafetyGuard.
    """

    def __init__(self, context_manager: Optional["ContextManager"] = None):
        # context_manager is injectable (e.g. for tests); defaults to a
        # fresh ContextManager.
        self.memory = context_manager or ContextManager()
        self.stream = StreamProcessor()

    async def run(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: "List[models.Message]",
        llm_provider,
        prompt_service=None,
        tool_service=None,
        tools: Optional[List[Dict[str, Any]]] = None,
        mesh_context: str = "",
        db=None,
        user_id: Optional[str] = None,
        sync_workspace_id: Optional[str] = None,
        session_id: Optional[int] = None,
        feature_name: str = "chat",
        prompt_slug: str = "rag-pipeline",
        max_turns: int = 500,
    ):
        """
        Async generator yielding stream events (dicts with "type" and
        "content" keys: status / reasoning / content / error).

        Loops up to ``max_turns`` times (default preserves the previous
        hard-coded cap of 500): call the LLM, stream its output, then
        either execute requested tools and continue, or exit when the
        model stops requesting tools and the watchdog agrees.
        """
        # 1. Initialize Context & Messages
        messages = self.memory.prepare_initial_messages(
            question, context_chunks, history, feature_name, mesh_context, sync_workspace_id,
            db=db, user_id=user_id, prompt_service=prompt_service, prompt_slug=prompt_slug
        )

        # 2. Setup Mesh Observation — a thread-safe queue bridging mesh
        # events from the registry into this user's run.
        mesh_bridge = queue.Queue()
        registry = self._get_registry(tool_service)
        if registry and user_id:
            registry.subscribe_user(user_id, mesh_bridge)

        # 3. Setup Guards (cancellation / repetition / watchdog)
        safety = SafetyGuard(db, session_id)

        # 4. Main Autonomous Loop
        try:
            turn = 0
            while turn < max_turns:
                turn += 1
                self.stream.reset_turn()

                # A. Cancellation / Memory check
                if safety.check_cancellation():
                    yield {"type": "reasoning", "content": "\n> **🛑 User Interruption:** Terminating loop.\n"}
                    return
                messages = self.memory.compress_history(messages)

                # B. Turn Start Heartbeat
                self._update_turn_marker(messages, turn)
                yield {"type": "status", "content": f"Turn {turn}: architecting next step..."}

                # C. LLM Call (None signals a provider failure)
                prediction = await self._call_llm(llm_provider, messages, tools)
                if not prediction:
                    yield {"type": "error", "content": "LLM Provider failed to generate a response."}
                    return

                # D. Process Stream
                accumulated_content = ""
                accumulated_reasoning = ""
                tool_calls_map = {}

                async for chunk in prediction:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta

                    # Native reasoning (O-series models)
                    r = self._delta_get(delta, "reasoning_content")
                    if r:
                        accumulated_reasoning += r
                        yield {"type": "reasoning", "content": r}

                    # Content & Thinking Tags — StreamProcessor splits raw
                    # text into content vs reasoning events.
                    c = self._delta_get(delta, "content")
                    if c:
                        async for event in self.stream.process_chunk(c, turn):
                            if event["type"] == "content":
                                accumulated_content += event["content"]
                            if event["type"] == "reasoning":
                                accumulated_reasoning += event["content"]
                            yield event

                    # Tool delta accumulation
                    self._accumulate_tool_calls(delta, tool_calls_map)

                # E. Branch: Tools or Exit?
                if not tool_calls_map:
                    # Watchdog Check — keep going while todo items remain.
                    if safety.should_activate_watchdog(self._get_assistant(tool_service), sync_workspace_id):
                        yield {"type": "status", "content": "Watchdog: tasks remain. continuing..."}
                        messages.append({"role": "user", "content": "WATCHDOG: .ai_todo.md has open items. Please continue until all are marked [COMPLETED]."})
                        continue
                    return  # Natural exit

                # F. Execute Tools
                processed_tc = list(tool_calls_map.values())
                if safety.detect_loop(processed_tc):
                    yield {"type": "reasoning", "content": "\n> **🚨 Loop Guard:** Repetitive plan detected. Retrying with warnings.\n"}
                    messages.append({"role": "user", "content": "LOOP GUARD: You are stuck. Change strategy."})
                    continue

                yield {"type": "status", "content": f"Architect analysis complete. Dispatching {len(processed_tc)} tools..."}

                # Append assistant message (with tool calls) to history
                messages.append(self._format_assistant_msg(accumulated_content, accumulated_reasoning, processed_tc))

                # Run parallel execution; tool results (dicts with "role")
                # go back into history, everything else is streamed out.
                executor = ToolExecutor(tool_service, user_id, db, sync_workspace_id)
                async for event in executor.run_tools(processed_tc, safety, mesh_bridge):
                    if "role" in event:
                        messages.append(event)
                    else:
                        yield event

        finally:
            # Always detach the mesh bridge, even on error or cancellation.
            if registry and user_id:
                registry.unsubscribe_user(user_id, mesh_bridge)

    # --- Internal Helpers ---

    @staticmethod
    def _delta_get(delta, field):
        """Read ``field`` from a streaming delta that may be either an
        attribute-style object or a plain dict; returns None when absent.

        Replaces the previous ``getattr(...) or delta.get(...)`` pattern,
        which raised AttributeError on object deltas lacking ``.get``.
        """
        value = getattr(delta, field, None)
        if value is None and isinstance(delta, dict):
            value = delta.get(field)
        return value

    def _get_registry(self, tool_service):
        # Best-effort reach into tool_service internals; None when the
        # orchestrator/registry chain is not wired up.
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "registry", None)
        return None

    def _get_assistant(self, tool_service):
        # Same best-effort traversal as _get_registry, for the assistant.
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "assistant", None)
        return None

    def _update_turn_marker(self, messages, turn):
        """Rewrite the system prompt's "[System: Current Turn: N]" marker
        in place. No-op when there is no leading system message
        (previously raised IndexError on an empty list)."""
        if messages and messages[0]["role"] == "system":
            base = messages[0]["content"].split("[System:")[0].strip()
            messages[0]["content"] = base + f"\n\n[System: Current Turn: {turn}]"

    async def _call_llm(self, llm_provider, messages, tools):
        """Invoke the provider's streaming completion; return the stream,
        or None on any provider exception (logged, not raised)."""
        kwargs = {"stream": True}
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        try:
            return await llm_provider.acompletion(messages=messages, **kwargs)
        except Exception as e:
            logging.error(f"[Architect] LLM Exception: {e}")
            return None

    def _accumulate_tool_calls(self, delta, t_map):
        """Merge streamed tool-call fragments into ``t_map`` keyed by index.

        OpenAI-style streams emit tool calls incrementally: the first
        fragment for an index carries id/name, later fragments append
        argument text.
        """
        tc_deltas = self._delta_get(delta, "tool_calls")
        if not tc_deltas:
            return
        for tcd in tc_deltas:
            idx = tcd.index
            if idx not in t_map:
                t_map[idx] = tcd
            else:
                if getattr(tcd, "id", None):
                    t_map[idx].id = tcd.id
                if tcd.function.name:
                    t_map[idx].function.name = tcd.function.name
                if tcd.function.arguments:
                    t_map[idx].function.arguments += tcd.function.arguments

    def _format_assistant_msg(self, content, reasoning, tool_calls):
        """Build the assistant history message: accumulated tool-call
        objects are flattened into plain OpenAI-format dicts; empty
        content becomes None; reasoning is attached only when present."""
        clean_tc = [
            {
                "id": tc.id, "type": "function",
                "function": {"name": tc.function.name, "arguments": tc.function.arguments},
            }
            for tc in tool_calls
        ]
        msg = {"role": "assistant", "content": content or None, "tool_calls": clean_tc}
        if reasoning:
            msg["reasoning_content"] = reasoning
        return msg