# cortex-hub / ai-hub / app / core / orchestration / architect.py
import logging
import queue
import re
import time
from typing import List, Dict, Any, Optional

from app.db import models

from .memory import ContextManager
from .stream import StreamProcessor
from .body import ToolExecutor
from .guards import SafetyGuard

class Architect:
    """
    The Master-Architect Orchestrator.
    Decomposed successor to RagPipeline.

    Drives the autonomous agent loop: prepares the message context, streams
    LLM output turn by turn, dispatches requested tool calls, and enforces
    safety guards (cancellation, loop detection, watchdog) until the model
    produces a final answer or the profile's turn limit is reached.
    """

    def __init__(self, context_manager: Optional[ContextManager] = None):
        # Conversation memory (initial prompt assembly + history compression).
        self.memory = context_manager or ContextManager()
        # StreamProcessor; created during run() once the profile is known.
        self.stream = None

    async def run(
        self,
        question: str,
        context_chunks: List[Dict[str, Any]],
        history: List[models.Message],
        llm_provider,
        prompt_service = None,
        tool_service = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        mesh_context: str = "",
        db = None,
        user_id: Optional[str] = None,
        sync_workspace_id: Optional[str] = None,
        session_id: Optional[int] = None,
        feature_name: str = "chat",
        prompt_slug: str = "rag-pipeline"
    ):
        """
        Async generator yielding UI events shaped {"type": ..., "content": ...}
        (types seen here: "content", "reasoning", "status", "error").

        Each turn: call the LLM with streaming, forward/accumulate its output,
        then either execute the requested tools and continue, or exit when the
        model stops requesting tools (subject to the watchdog guard).
        """
        # 1. Initialize Context & Messages
        messages = self.memory.prepare_initial_messages(
            question, context_chunks, history, feature_name, mesh_context, sync_workspace_id,
            db=db, user_id=user_id, prompt_service=prompt_service, prompt_slug=prompt_slug,
            tools=tools
        )

        # 2. Setup Mesh Observation
        mesh_bridge = queue.Queue()
        registry = self._get_registry(tool_service)
        if registry and user_id:
            registry.subscribe_user(user_id, mesh_bridge)

        # try/finally begins immediately after subscription so the user is
        # always unsubscribed, even if guard/profile setup below raises.
        try:
            # 3. Setup Guards
            safety = SafetyGuard(db, session_id)

            # 4. Main Autonomous Loop
            from .profiles import get_profile
            profile = get_profile(feature_name)
            self.stream = StreamProcessor(profile=profile)
            turn = 0
            session_start_time = time.time()

            while turn < profile.autonomous_limit:
                turn += 1
                turn_start_time = time.time()
                self.stream.reset_turn()

                # A. Cancellation / Memory check
                if safety.check_cancellation():
                    yield {"type": "reasoning", "content": "\n> **🛑 User Interruption:** Terminating loop.\n"}
                    return
                messages = self.memory.compress_history(messages)

                # B. Turn Start Heartbeat
                self._update_turn_marker(messages, turn)
                if profile.show_heartbeat:
                    yield {"type": "status", "content": f"Turn {turn}: architecting next step..."}

                # C. LLM Call
                prediction = await self._call_llm(llm_provider, messages, tools)
                if not prediction:
                    yield {"type": "error", "content": "LLM Provider failed to generate a response."}
                    return

                # D. Process Stream: accumulate content, reasoning and tool-call deltas
                accumulated_content = ""
                accumulated_reasoning = ""
                tool_calls_map = {}

                async for chunk in prediction:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta

                    # Native reasoning (O-series)
                    r = getattr(delta, "reasoning_content", None) or delta.get("reasoning_content")
                    if r:
                        accumulated_reasoning += r
                        yield {"type": "reasoning", "content": r}

                    # Content & Thinking Tags
                    c = getattr(delta, "content", None) or delta.get("content")
                    if c:
                        async for event in self.stream.process_chunk(c, turn):
                            if event["type"] == "content": accumulated_content += event["content"]
                            if event["type"] == "reasoning": accumulated_reasoning += event["content"]

                            if not profile.buffer_content:
                                yield event
                            else:
                                # In buffered mode (voice), we yield reasoning immediately but hold content
                                if event["type"] == "reasoning":
                                    yield event

                    # Tool delta accumulation
                    self._accumulate_tool_calls(delta, tool_calls_map)

                # End Stream & Flush buffers
                async for event in self.stream.end_stream(turn):
                    if event["type"] == "content": accumulated_content += event["content"]
                    if event["type"] == "reasoning": accumulated_reasoning += event["content"]
                    yield event

                # Heartbeat Fallback: if tools are being called but no visible
                # content was streamed, force a short bridge sentence.
                if tool_calls_map and not self.stream.header_sent and not profile.silent_stream:
                    fallback_text = "Strategy: Executing orchestrated tasks in progress..."
                    async for event in self.stream.process_chunk(fallback_text, turn):
                        if event["type"] == "content": accumulated_content += event["content"]
                        yield event

                # E. Branch: Tools or Exit?
                if not tool_calls_map:
                    # Final turn: substitute a fallback when the model produced no
                    # visible content (same text in buffered and streaming modes).
                    if not accumulated_content.strip():
                        fallback = "I've completed the requested task."
                        if accumulated_reasoning:
                            fallback = "Analysis finished. Please review the results above."
                        yield {"type": "content", "content": fallback}
                    elif profile.buffer_content:
                        # Buffered (voice) mode: strip configured headers, then emit
                        # the whole buffered answer as one final content event.
                        content_to_yield = accumulated_content
                        for pattern in profile.strip_headers:
                            content_to_yield = re.sub(pattern, "", content_to_yield, flags=re.IGNORECASE)
                        yield {"type": "content", "content": content_to_yield.strip()}

                    # Watchdog Check
                    if safety.should_activate_watchdog(self._get_assistant(tool_service), sync_workspace_id):
                        yield {"type": "status", "content": "Watchdog: tasks remain. continuing..."}
                        messages.append({"role": "user", "content": "WATCHDOG: .ai_todo.md has open items. Please continue until all are marked [COMPLETED]."})
                        continue

                    # Turn duration report (Natural Exit)
                    for event in self._turn_summary(turn, turn_start_time, session_start_time):
                        yield event
                    return # Natural exit

                # F. Execute Tools
                processed_tc = list(tool_calls_map.values())
                if safety.detect_loop(processed_tc):
                    yield {"type": "reasoning", "content": "\n> **🚨 Loop Guard:** Repetitive plan detected. Retrying with warnings.\n"}
                    messages.append({"role": "user", "content": "LOOP GUARD: You are stuck. Change strategy."})
                    continue

                yield {"type": "status", "content": f"Architect analysis complete. Dispatching {len(processed_tc)} tools..."}

                # Append assistant message to history
                messages.append(self._format_assistant_msg(accumulated_content, accumulated_reasoning, processed_tc))

                # Run parallel execution; events carrying a "role" are tool results
                # destined for the LLM history, everything else goes to the UI.
                executor = ToolExecutor(tool_service, user_id, db, sync_workspace_id, session_id)
                async for event in executor.run_tools(processed_tc, safety, mesh_bridge):
                    if "role" in event:
                        messages.append(event)
                    else:
                        yield event

                # Turn duration report (End of Loop)
                for event in self._turn_summary(turn, turn_start_time, session_start_time):
                    yield event

        finally:
            if registry and user_id:
                registry.unsubscribe_user(user_id, mesh_bridge)

    # --- Internal Helpers ---

    def _turn_summary(self, turn: int, turn_start: float, session_start: float) -> List[Dict[str, str]]:
        """Build the two per-turn duration events (reasoning marker + status line)."""
        turn_duration = time.time() - turn_start
        total_duration = time.time() - session_start
        duration_marker = f"\n\n> **⏱️ Turn {turn} Duration:** {turn_duration:.1f}s | **Total:** {total_duration:.1f}s\n"
        return [
            {"type": "reasoning", "content": duration_marker},
            {"type": "status", "content": f"Turn {turn} finished in **{turn_duration:.1f}s**. (Session Total: **{total_duration:.1f}s**)"},
        ]

    def _get_registry(self, tool_service):
        """Resolve the mesh event registry from the tool service, if wired."""
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "registry", None)
        return None

    def _get_assistant(self, tool_service):
        """Resolve the assistant facade from the tool service, if wired."""
        if tool_service and hasattr(tool_service, "_services"):
            orchestrator = getattr(tool_service._services, "orchestrator", None)
            return getattr(orchestrator, "assistant", None)
        return None

    def _update_turn_marker(self, messages, turn: int) -> None:
        """Rewrite the trailing "[System: Current Turn: N]" marker on the system prompt."""
        # Guard against an empty or non-system-led history instead of raising.
        if messages and messages[0].get("role") == "system":
            base = messages[0]["content"].split("[System:")[0].strip()
            messages[0]["content"] = base + f"\n\n[System: Current Turn: {turn}]"

    async def _call_llm(self, llm_provider, messages, tools):
        """Start a streaming completion; returns None (and logs) on provider failure."""
        kwargs = {"stream": True}
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        try:
            return await llm_provider.acompletion(messages=messages, **kwargs)
        except Exception as e:
            # Lazy %-formatting: message is only rendered if the record is emitted.
            logging.error("[Architect] LLM Exception: %s", e)
            return None

    def _accumulate_tool_calls(self, delta, t_map) -> None:
        """
        Merge streamed tool-call deltas into t_map, keyed by call index.

        The first delta for an index is stored as-is; subsequent deltas patch
        the id/name and append argument fragments.
        NOTE(review): assumes `delta` supports attribute access or `.get`, and
        that each tool-call delta carries `.index` and `.function` — confirm
        against the provider's streaming delta shape.
        """
        tc_deltas = getattr(delta, "tool_calls", None) or delta.get("tool_calls")
        if not tc_deltas: return
        for tcd in tc_deltas:
            idx = tcd.index
            if idx not in t_map:
                t_map[idx] = tcd
            else:
                if getattr(tcd, "id", None): t_map[idx].id = tcd.id
                if tcd.function.name: t_map[idx].function.name = tcd.function.name
                if tcd.function.arguments: t_map[idx].function.arguments += tcd.function.arguments

    def _format_assistant_msg(self, content, reasoning, tool_calls):
        """Convert one turn's accumulated output into an assistant history message."""
        if content:
            # Strip system-injected turn headers from history to avoid LLM hallucination
            content = re.sub(r"(?i)---\n### 🛰️ \*\*\[Turn \d+\] .*?\*\*\n", "", content)
            content = content.strip()

        clean_tc = []
        for tc in tool_calls:
            clean_tc.append({
                "id": tc.id, "type": "function",
                "function": {"name": tc.function.name, "arguments": tc.function.arguments}
            })
        msg = {"role": "assistant", "content": content or None, "tool_calls": clean_tc}
        if reasoning: msg["reasoning_content"] = reasoning
        return msg