import logging
import queue
import time
import traceback
from typing import List, Dict, Any, Optional
import litellm
from app.db import models
from .memory import ContextManager
from .stream import StreamProcessor
from .body import ToolExecutor
from .guards import SafetyGuard
_PROVIDER_ERROR_TYPES = (
litellm.ServiceUnavailableError,
litellm.InternalServerError,
litellm.RateLimitError,
litellm.APIConnectionError,
litellm.APIError,
litellm.Timeout,
litellm.AuthenticationError,
)
def _is_provider_error(e: Exception) -> bool:
"""True for transient or user-actionable provider errors — not code bugs."""
if isinstance(e, _PROVIDER_ERROR_TYPES):
return True
    # Some litellm exceptions (MidStreamFallbackError, ContextWindowExceededError) vary across versions; match by name to avoid import drift.
if type(e).__name__ in ("MidStreamFallbackError", "ContextWindowExceededError"):
return True
msg = str(e).lower()
return any(tok in msg for tok in ("503", "502", "429", "unavailable", "rate limit", "timeout", "overloaded"))
class Architect:
"""
The Master-Architect Orchestrator.
Decomposed successor to RagPipeline. 100% REGEX-FREE.
"""
def __init__(self, context_manager: Optional[ContextManager] = None):
self.memory = context_manager or ContextManager()
self.stream = None # Created during run()
async def run(
self,
question: str,
context_chunks: List[Dict[str, Any]],
history: List[models.Message],
llm_provider,
prompt_service = None,
tool_service = None,
        tools: Optional[List[Dict[str, Any]]] = None,
mesh_context: str = "",
db = None,
user_id: Optional[str] = None,
sync_workspace_id: Optional[str] = None,
session_id: Optional[int] = None,
feature_name: str = "chat",
prompt_slug: str = "rag-pipeline",
session_override: Optional[str] = None
):
"""Dispatches an autonomous orchestration loop with turn-based strategy."""
messages = self.memory.prepare_initial_messages(
question, context_chunks, history, feature_name, mesh_context, sync_workspace_id,
db=db, user_id=user_id, prompt_service=prompt_service, prompt_slug=prompt_slug,
tools=tools, session_override=session_override
)
logging.info(f"[Architect] Starting loop. Prompt Size: {sum(len(m.get('content','') or '') for m in messages)} chars.")
mesh_bridge = queue.Queue()
registry = self._get_registry(tool_service)
if registry and user_id: registry.subscribe_user(user_id, mesh_bridge)
safety = SafetyGuard(db, session_id)
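        # Local import; presumably kept here to avoid a circular import at module load time.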
from .profiles import get_profile
profile = get_profile(feature_name)
self.stream = StreamProcessor(profile=profile)
turn = 0
start_time = time.time()
try:
while turn < profile.autonomous_limit:
turn += 1
self.stream.reset_turn()
if safety.check_cancellation():
yield {"type": "reasoning", "content": "\n> **🛑 User Interruption:** Terminating loop.\n"}
break
messages = await self.memory.compress_history(messages, llm_provider)
self._update_turn_marker(messages, turn)
if profile.show_heartbeat: yield {"type": "status", "content": f"Turn {turn}: architecting next step"}
prediction = await self._call_llm(llm_provider, messages, tools)
if not prediction: break
# A. Handle Stream Turn
content, reasoning, tc_map, finish_reason = "", "", {}, None
async for event in self._process_llm_stream(prediction, turn, profile):
e_type = event.get("type")
if e_type == "content": content += event["content"]
elif e_type == "reasoning": reasoning += event["content"]
elif e_type == "tool_calls_detected": tc_map.update(event["map"])
elif e_type == "finish_reason": finish_reason = event["reason"]
yield event
# B. Decision Branch: Tools or Exit?
if not tc_map:
events = []
should_continue = await self._handle_no_tools_branch(finish_reason, content, reasoning, profile, safety, tool_service, sync_workspace_id, messages, events)
for e in events: yield e
if should_continue:
continue # Watchdog or continuation triggered
break # Natural exit
# C. Execute Tools
processed_tc = list(tc_map.values())
if safety.detect_loop(processed_tc):
yield {"type": "reasoning", "content": "\n> **🚨 Loop Guard:** Loop detected.\n"}
messages.append({"role": "user", "content": "STUCK: Change strategy."})
continue
yield {"type": "status", "content": f"Dispatching {len(processed_tc)} tools"}
messages.append(self._format_assistant_msg(content, reasoning, processed_tc))
executor = ToolExecutor(tool_service, user_id, db, sync_workspace_id, session_id, provider_name=getattr(llm_provider, "provider_name", None))
async for event in executor.run_tools(processed_tc, safety, mesh_bridge):
if "role" in event: messages.append(event)
else: yield event
elapsed = time.time() - start_time
if turn >= profile.autonomous_limit:
yield {"type": "status", "content": f"Autonomous limit reached after {elapsed:.1f}s. Please provide more instructions if needed."}
else:
yield {"type": "status", "content": f"Task complete in {elapsed:.1f}s"}
except Exception as e:
if _is_provider_error(e):
logging.warning(f"[Architect] Provider error (non-fatal): {e}")
model_hint = getattr(llm_provider, 'model_name', 'the AI model')
yield {"type": "status", "content": "AI provider temporarily unavailable"}
yield {"type": "content", "content": (
f"\n\n> ⚠️ **`{model_hint}` is temporarily unavailable.**\n"
f">\n"
f"> The provider returned a service error (likely overloaded or a preview endpoint outage).\n"
f"> **Please try your message again** — or go to **Settings → AI Provider** to switch to a stable model."
)}
else:
logging.error(f"[Architect] CRITICAL FAULT:\n{traceback.format_exc()}")
yield {"type": "status", "content": "Fatal Orchestration Error"}
yield {"type": "content", "content": f"\n\n> **🚨 Core Orchestrator Fault:** `{str(e)}`"}
finally:
if registry and user_id: registry.unsubscribe_user(user_id, mesh_bridge)
# --- M7: Automatic Terminal Task Cancellation ---
if tool_service and hasattr(tool_service, "_services") and sync_workspace_id:
try:
mesh = getattr(tool_service._services, "mesh_service", None)
if mesh: mesh.cancel_session(sync_workspace_id, user_id, db)
                except Exception:
                    # Best-effort cleanup; never let teardown errors mask the run's outcome.
                    logging.debug("[Architect] Mesh session cleanup failed.", exc_info=True)
async def _process_llm_stream(self, prediction, turn, profile):
"""Internal helper for processing raw LLM stream into architectural events."""
tc_map, finish_reason = {}, None
async for chunk in prediction:
if getattr(chunk, "usage", None):
yield {"type": "token_counted", "usage": getattr(chunk, "usage").model_dump() if hasattr(getattr(chunk, "usage"), "model_dump") else getattr(chunk, "usage")}
if not chunk.choices: continue
delta = chunk.choices[0].delta
finish_reason = getattr(chunk.choices[0], "finish_reason", None) or chunk.choices[0].get("finish_reason")
# Native reasoning (O-series)
r = getattr(delta, "reasoning_content", None) or delta.get("reasoning_content")
if r: yield {"type": "reasoning", "content": r}
# Content & Thinking Tags
c = getattr(delta, "content", None) or delta.get("content")
if c:
async for event in self.stream.process_chunk(c, turn):
if not (profile.buffer_content and event["type"] == "content"): yield event
self._accumulate_tool_calls(delta, tc_map)
async for event in self.stream.end_stream(turn):
if not (profile.buffer_content and event["type"] == "content"): yield event
# Standardize tool calls for JSON serialization in the event stream
serializable_tc_map = {}
for idx, tc in tc_map.items():
if hasattr(tc, "model_dump"):
serializable_tc_map[idx] = tc.model_dump()
else:
# Manual fallback for non-Pydantic objects
serializable_tc_map[idx] = {
"id": getattr(tc, "id", None),
"type": "function",
"function": {
"name": getattr(tc.function, "name", ""),
"arguments": getattr(tc.function, "arguments", "")
}
}
yield {"type": "tool_calls_detected", "map": serializable_tc_map}
yield {"type": "finish_reason", "reason": finish_reason}
async def _handle_no_tools_branch(self, finish_reason, content, reasoning, profile, safety, tool_svc, ws_id, messages, events_out: list) -> bool:
"""Determines if a no-tool turn should exit or trigger a continuation/watchdog."""
if finish_reason == "length":
messages.append({"role": "user", "content": "You were cut off. Please continue."})
return True
if not content.strip():
fallback = self.stream._apply_turn_header(reasoning.strip()) if reasoning.strip() else ""
if not fallback.strip(): fallback = "Task complete. Check thought trace for details."
events_out.append({"type": "content", "content": fallback})
elif profile.buffer_content:
events_out.append({"type": "content", "content": self.stream._apply_turn_header(content).strip()})
if safety.should_activate_watchdog(self._get_assistant(tool_svc), ws_id):
messages.append({"role": "user", "content": "WATCHDOG: .ai_todo.md is not empty. Proceed."})
return True
return False
# --- Internal Helpers ---
def _get_registry(self, tool_service):
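        """Best-effort lookup of the tool registry on the shared services container."""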
if tool_service and hasattr(tool_service, "_services"):
orchestrator = getattr(tool_service._services, "orchestrator", None)
return getattr(orchestrator, "registry", None)
return None
def _get_assistant(self, tool_service):
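        """Best-effort lookup of the assistant handle on the shared services container."""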
if tool_service and hasattr(tool_service, "_services"):
orchestrator = getattr(tool_service._services, "orchestrator", None)
return getattr(orchestrator, "assistant", None)
return None
def _update_turn_marker(self, messages, turn):
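        """Rewrite (or append) the '[System: Current Turn: N]' marker at the end of the system prompt."""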
if messages[0]["role"] == "system":
content = messages[0]["content"]
marker_anchor = "[System: Current Turn:"
if marker_anchor in content:
messages[0]["content"] = content.split(marker_anchor)[0].strip() + f"\n\n{marker_anchor} {turn}]"
else:
messages[0]["content"] = content + f"\n\n{marker_anchor} {turn}]"
async def _call_llm(self, llm_provider, messages, tools):
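        """Start a streaming completion; provider errors propagate, unknown failures return None so the loop exits cleanly."""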
kwargs = {"stream": True, "stream_options": {"include_usage": True}, "max_tokens": 4096}
if tools:
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
# Let provider errors propagate so the top-level handler can show a user-friendly message.
# Only swallow true unknowns (e.g. missing provider config) which return None to exit cleanly.
try:
return await llm_provider.acompletion(messages=messages, timeout=60, **kwargs)
except Exception as e:
if _is_provider_error(e):
raise
logging.error(f"[Architect] LLM Exception: {e}")
return None
def _accumulate_tool_calls(self, delta, t_map):
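        """Merge streamed tool-call deltas into t_map, keyed by tool-call index."""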
tc_deltas = getattr(delta, "tool_calls", None) or delta.get("tool_calls")
if not tc_deltas: return
for tcd in tc_deltas:
idx = tcd.index
if idx not in t_map:
t_map[idx] = tcd
else:
if getattr(tcd, "id", None): t_map[idx].id = tcd.id
if tcd.function.name: t_map[idx].function.name = tcd.function.name
if tcd.function.arguments: t_map[idx].function.arguments += tcd.function.arguments
def _format_assistant_msg(self, content, reasoning, tool_calls):
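        """Build the assistant history message, normalizing tool calls to plain dicts."""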
if content:
# Fast string cleaning instead of regex for assistant message formatting
content = self.stream._apply_turn_header(content).strip()
clean_tc = []
for tc in tool_calls:
# Handle both object and dict access (Migration to Serializable Swarm)
tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
tc_func = tc.get("function", {}) if isinstance(tc, dict) else getattr(tc, "function", None)
func_name = tc_func.get("name") if isinstance(tc_func, dict) else getattr(tc_func, "name", "")
func_args = tc_func.get("arguments") if isinstance(tc_func, dict) else getattr(tc_func, "arguments", "")
clean_tc.append({
"id": tc_id, "type": "function",
"function": {"name": func_name, "arguments": func_args}
})
msg = {"role": "assistant", "content": content or None, "tool_calls": clean_tc}
if reasoning: msg["reasoning_content"] = reasoning
return msg
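# --- Illustrative usage (sketch, not part of the orchestrator itself) -------
# A minimal consumer of Architect.run(). It assumes `llm_provider` is any
# object exposing an OpenAI-style async streaming `acompletion(...)`, as the
# loop above expects; the argument values below are placeholders.
async def _example_consume(llm_provider):
    # Hypothetical driver showing how the yielded event stream is read.
    architect = Architect()
    async for event in architect.run(
        question="Summarize the workspace",
        context_chunks=[],
        history=[],
        llm_provider=llm_provider,
    ):
        if event["type"] == "content":
            print(event["content"], end="")
        elif event["type"] == "status":
            print(f"\n[status] {event['content']}")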