# cortex-hub/ai-hub/app/core/orchestration/architect.py
import logging
import queue
import time
from typing import List, Dict, Any, Optional

from app.db import models
from .memory import ContextManager
from .stream import StreamProcessor
from .body import ToolExecutor
from .guards import SafetyGuard

class Architect:
    """
    The Master-Architect Orchestrator. 
    Decomposed successor to RagPipeline. 100% REGEX-FREE.
    """

    def __init__(self, context_manager: Optional[ContextManager] = None):
        self.memory = context_manager or ContextManager()
        self.stream = None # Created during run()

    async def run(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: List[models.Message],
        llm_provider,
        prompt_service = None,
        tool_service = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        mesh_context: str = "",
        db = None,
        user_id: Optional[str] = None,
        sync_workspace_id: Optional[str] = None,
        session_id: Optional[int] = None,
        feature_name: str = "chat",
        prompt_slug: str = "rag-pipeline",
        session_override: Optional[str] = None
    ):
        """
        Drive the autonomous agent loop and stream events to the caller.

        Async generator: yields event dicts keyed by "type" ("content",
        "reasoning", "status", "error", "token_counted") plus a payload.
        Each loop iteration is one LLM turn. The loop terminates on user
        cancellation, a natural answer with no tool calls (unless the
        watchdog vetoes it), a fatal fault, or after
        profile.autonomous_limit turns.

        NOTE(review): llm_provider must expose acompletion(); streamed
        chunks are accessed both attribute-style and via .get() below, so
        they are assumed to be either objects or dicts — confirm against
        the actual provider implementations.
        """
        # 1. Initialize Context & Messages
        messages = self.memory.prepare_initial_messages(
            question, context_chunks, history, feature_name, mesh_context, sync_workspace_id,
            db=db, user_id=user_id, prompt_service=prompt_service, prompt_slug=prompt_slug,
            tools=tools, session_override=session_override
        )

        # DEBUG: Log the total prompt size to detect bloated contexts
        total_chars = sum(len(m.get("content", "") or "") for m in messages)
        logging.info(f"[Architect] Starting autonomous loop (Turn 1). Prompt Size: {total_chars} chars across {len(messages)} messages.")

        # 2. Setup Mesh Observation
        # Thread-safe queue bridges background mesh events into the tool executor.
        mesh_bridge = queue.Queue()
        registry = self._get_registry(tool_service)
        if registry and user_id:
            registry.subscribe_user(user_id, mesh_bridge)

        # 3. Setup Guards
        safety = SafetyGuard(db, session_id)

        # 4. Main Autonomous Loop
        from .profiles import get_profile
        profile = get_profile(feature_name)
        self.stream = StreamProcessor(profile=profile)
        turn = 0

        try:
            while turn < profile.autonomous_limit:
                turn += 1
                self.stream.reset_turn()

                # A. Cancellation / Memory check
                if safety.check_cancellation():
                    yield {"type": "reasoning", "content": "\n> **🛑 User Interruption:** Terminating loop.\n"}
                    return
                # Compress older history so the prompt stays within budget.
                messages = await self.memory.compress_history(messages, llm_provider)

                # B. Turn Start Heartbeat
                self._update_turn_marker(messages, turn)
                if profile.show_heartbeat:
                    yield {"type": "status", "content": f"Turn {turn}: architecting next step..."}

                # C. LLM Call
                logging.info(f"[Architect] Turn {turn}: Calling LLM (Messages: {len(messages)})")
                llm_start_time = time.time()  # NOTE(review): recorded but never read in this method — confirm intent
                prediction = await self._call_llm(llm_provider, messages, tools)
                if not prediction:
                     logging.error(f"[Architect] Turn {turn}: LLM Provider returned None")
                     yield {"type": "error", "content": "LLM Provider failed to generate a response."}
                     return

                # D. Process Stream
                accumulated_content = ""
                accumulated_reasoning = ""
                tool_calls_map = {}  # index -> accumulating tool-call fragment
                finish_reason = None

                chunk_count = 0
                async for chunk in prediction:
                    chunk_count += 1
                    # Usage frames carry token accounting (include_usage is enabled in _call_llm).
                    if getattr(chunk, "usage", None):
                        yield {"type": "token_counted", "usage": getattr(chunk, "usage").model_dump() if hasattr(getattr(chunk, "usage"), "model_dump") else getattr(chunk, "usage")}

                    if not chunk.choices: continue
                    delta = chunk.choices[0].delta
                    # NOTE(review): the attr-or-.get() fallback assumes choices[0] may be a
                    # dict; a plain object without .get would raise AttributeError here.
                    finish_reason = getattr(chunk.choices[0], "finish_reason", None) or chunk.choices[0].get("finish_reason")

                    # Native reasoning (O-series)
                    r = getattr(delta, "reasoning_content", None) or delta.get("reasoning_content")
                    if r:
                        accumulated_reasoning += r
                        yield {"type": "reasoning", "content": r}

                    # Content & Thinking Tags
                    c = getattr(delta, "content", None) or delta.get("content")
                    if c:
                        async for event in self.stream.process_chunk(c, turn):
                            if event["type"] == "content": accumulated_content += event["content"]
                            if event["type"] == "reasoning": accumulated_reasoning += event["content"]

                            # Buffered profiles hold content back until turn end;
                            # reasoning is always streamed live.
                            if not profile.buffer_content:
                                yield event
                            else:
                                if event["type"] == "reasoning":
                                    yield event

                    self._accumulate_tool_calls(delta, tool_calls_map)

                # End Stream & Flush buffers
                async for event in self.stream.end_stream(turn):
                    if event["type"] == "content": accumulated_content += event["content"]
                    if event["type"] == "reasoning": accumulated_reasoning += event["content"]
                    if not profile.buffer_content:
                        yield event

                # Heartbeat Fallback
                # If the model jumped straight to tools without narrating, emit a
                # synthetic strategy line so the UI is not silent this turn.
                if tool_calls_map and not self.stream.header_sent and not profile.silent_stream:
                    fallback_text = f"Strategy: Executing orchestrated tasks in progress..."
                    async for event in self.stream.process_chunk(fallback_text, turn):
                         if event["type"] == "content": accumulated_content += event["content"]
                         yield event
                    async for event in self.stream.end_stream(turn):
                         if event["type"] == "content": accumulated_content += event["content"]
                         yield event

                # Branch: Tools or Exit?
                if not tool_calls_map:
                    # Truncated output: ask the model to continue instead of exiting.
                    if finish_reason == "length":
                        yield {"type": "reasoning", "content": "\n> **⚠️ System Note:** Response was truncated. Prompting continuation...\n"}
                        messages.append({"role": "user", "content": "You were cut off. Please continue."})
                        continue

                    if not accumulated_content.strip():
                        if accumulated_reasoning.strip():
                            # Clean reasoning without regex
                            fallback = self.stream._apply_turn_header(accumulated_reasoning.strip())
                            if not fallback.strip():
                                fallback = "I've completed the requested task. Please check the thought trace for details."
                        else:
                            fallback = "I've completed the requested task. Please check the thought trace for details."

                        yield {"type": "content", "content": fallback}
                    elif profile.buffer_content:
                        # Buffered profile: release the whole turn's content now.
                        content_to_yield = self.stream._apply_turn_header(accumulated_content)
                        yield {"type": "content", "content": content_to_yield.strip()}

                    # Watchdog can veto a natural exit while TODO items remain.
                    if safety.should_activate_watchdog(self._get_assistant(tool_service), sync_workspace_id):
                        yield {"type": "status", "content": "Watchdog: tasks remain. continuing..."}
                        messages.append({"role": "user", "content": "WATCHDOG: .ai_todo.md is not empty. Proceed."})
                        continue

                    yield {"type": "status", "content": ""}
                    return # Natural exit

                # F. Execute Tools
                processed_tc = list(tool_calls_map.values())
                if safety.detect_loop(processed_tc):
                    yield {"type": "reasoning", "content": "\n> **🚨 Loop Guard:** Loop detected.\n"}
                    messages.append({"role": "user", "content": "STUCK: Change strategy."})
                    continue

                yield {"type": "status", "content": f"Dispatching {len(processed_tc)} tools..."}

                # Record the assistant turn (with its tool calls) before executing them.
                messages.append(self._format_assistant_msg(accumulated_content, accumulated_reasoning, processed_tc))

                executor = ToolExecutor(
                    tool_service, user_id, db, sync_workspace_id, session_id,
                    provider_name=getattr(llm_provider, "provider_name", None)
                )
                async for event in executor.run_tools(processed_tc, safety, mesh_bridge):
                    if "role" in event: # It's a tool result for history
                         messages.append(event)
                    else:
                         yield event

        except Exception as e:
            import traceback
            logging.error(f"[Architect] CRITICAL FAULT:\n{traceback.format_exc()}")
            yield {"type": "status", "content": "Fatal Orchestration Error"}
            yield {"type": "content", "content": f"\n\n> **🚨 Core Orchestrator Fault:** `{str(e)}`"}
        finally:
            # Always detach the mesh subscription, even on fault or cancellation.
            if registry and user_id:
                registry.unsubscribe_user(user_id, mesh_bridge)

    # --- Internal Helpers ---

    def _get_registry(self, tool_service):
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "registry", None)
        return None

    def _get_assistant(self, tool_service):
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "assistant", None)
        return None

    def _update_turn_marker(self, messages, turn):
        if messages[0]["role"] == "system":
            content = messages[0]["content"]
            marker_anchor = "[System: Current Turn:"
            if marker_anchor in content:
                 messages[0]["content"] = content.split(marker_anchor)[0].strip() + f"\n\n{marker_anchor} {turn}]"
            else:
                 messages[0]["content"] = content + f"\n\n{marker_anchor} {turn}]"

    async def _call_llm(self, llm_provider, messages, tools):
        kwargs = {"stream": True, "stream_options": {"include_usage": True}, "max_tokens": 4096}
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        try:
            return await llm_provider.acompletion(messages=messages, timeout=60, **kwargs)
        except Exception as e:
            logging.error(f"[Architect] LLM Exception: {e}")
            return None

    def _accumulate_tool_calls(self, delta, t_map):
        tc_deltas = getattr(delta, "tool_calls", None) or delta.get("tool_calls")
        if not tc_deltas: return
        for tcd in tc_deltas:
            idx = tcd.index
            if idx not in t_map:
                t_map[idx] = tcd
            else:
                if getattr(tcd, "id", None): t_map[idx].id = tcd.id
                if tcd.function.name: t_map[idx].function.name = tcd.function.name
                if tcd.function.arguments: t_map[idx].function.arguments += tcd.function.arguments

    def _format_assistant_msg(self, content, reasoning, tool_calls):
        if content:
             # Fast string cleaning instead of regex for assistant message formatting
             content = self.stream._apply_turn_header(content).strip()

        clean_tc = []
        for tc in tool_calls:
            clean_tc.append({
                "id": tc.id, "type": "function",
                "function": {"name": tc.function.name, "arguments": tc.function.arguments}
            })
        msg = {"role": "assistant", "content": content or None, "tool_calls": clean_tc}
        if reasoning: msg["reasoning_content"] = reasoning
        return msg