import logging
import time
from typing import List, Optional, Dict, Any, AsyncGenerator, Tuple
from sqlalchemy.orm import Session, joinedload
from app.db import models
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.factory import get_llm_provider
from app.core.orchestration import Architect
from app.core.orchestration.profiles import get_profile
from app.core._regex import ANSI_ESCAPE
from app.db.session import async_db_op
logger = logging.getLogger(__name__)
class RAGService:
"""
Orchestrates conversational RAG pipelines.
Decomposed into manageable components for maintainability.
"""
def __init__(self, retrievers: List[Retriever], prompt_service=None, tool_service=None, node_registry_service=None, services=None):
self.retrievers = retrievers
self.prompt_service = prompt_service
self.tool_service = tool_service
self.node_registry_service = node_registry_service
self.services = services
self.faiss_retriever = next((r for r in retrievers if isinstance(r, FaissDBRetriever)), None)
    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        provider_name: str,
        load_faiss_retriever: bool = False,
        user_service=None,
        user_id: Optional[str] = None,
        save_prompt: bool = True
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Entry point for the RAG pipeline.

        Persists the user message, resolves the LLM provider, gathers optional
        FAISS and agent-mesh context, then streams Architect events while
        accumulating the answer/reasoning/tool/usage state. Ends by persisting
        the assistant message and yielding a final ``finish`` event.

        Args:
            db: Active SQLAlchemy session.
            session_id: Primary key of the chat session.
            prompt: The user's new message.
            provider_name: Requested LLM provider identifier.
            load_faiss_retriever: When True and a FAISS retriever was found at
                construction time, prepend vector-store context chunks.
            user_service: Optional service consulted for the node registry.
            user_id: Overrides the session's user id when given.
            save_prompt: Forwarded to ``_resolve_session`` to control whether
                the user message row is persisted.

        Yields:
            Dict events streamed from the Architect, followed by one
            ``finish`` event carrying message id, provider, answer, tool
            counts and token totals.

        Raises:
            ValueError: Propagated from ``_resolve_session`` when the session
                id does not exist.
        """
        session = self._resolve_session(db, session_id, prompt, save_prompt=save_prompt)
        llm_provider, resolved_provider_name = self._resolve_provider(db, session, provider_name, user_service)
        context_chunks = []
        if load_faiss_retriever and self.faiss_retriever:
            context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db))
        mesh_context = self._gather_mesh_context(db, session, user_service)
        tools = self.tool_service.get_available_tools(db, session.user_id, feature=session.feature_name, session_id=session.id) if self.tool_service else []
        profile = get_profile(session.feature_name)
        # Accumulators shared with _process_event / _finalize_assistant_message.
        state = {
            "answer": "", "reasoning": "", "tool_counts": {},
            "usage": {"input": 0, "output": 0}, "msg": None
        }
        architect = Architect()
        # Stream orchestrator events: fold each into `state` (and incrementally
        # into the DB) before forwarding it to the caller.
        async for event in architect.run(
            question=prompt, history=session.messages, context_chunks=context_chunks,
            llm_provider=llm_provider, prompt_service=self.prompt_service, tool_service=self.tool_service,
            tools=tools, mesh_context=mesh_context, db=db, user_id=user_id or session.user_id,
            sync_workspace_id=session.sync_workspace_id or str(session_id), session_id=session_id,
            feature_name=session.feature_name, prompt_slug=profile.default_prompt_slug,
            session_override=session.system_prompt_override
        ):
            await self._process_event(db, session_id, event, state)
            yield event
        # Final persistence of the complete assistant message.
        assistant_msg = await self._finalize_assistant_message(db, session_id, state)
        yield {
            "type": "finish", "message_id": assistant_msg.id, "provider": resolved_provider_name,
            "full_answer": state["answer"], "tool_counts": state["tool_counts"],
            "input_tokens": state["usage"]["input"], "output_tokens": state["usage"]["output"]
        }
def _resolve_session(self, db: Session, session_id: int, prompt: str) -> models.Session:
"""Fetches and initializes the session state."""
session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first()
if not session: raise ValueError(f"Session {session_id} not found.")
# Save user message
db.add(models.Message(session_id=session_id, sender="user", content=prompt))
if session.title in (None, "New Chat Session", ""):
session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "")
db.commit()
return session
def _resolve_provider(self, db: Session, session: models.Session, provider_name: str, user_service) -> Tuple[Any, str]:
"""Resolves LLM provider with user-preference and system-level fallbacks."""
pref_svc = getattr(self.services, "preference_service", None) if self.services else None
if not pref_svc:
from app.core.providers.factory import get_llm_provider
return get_llm_provider(provider_name, model_name=session.model_name), provider_name
return pref_svc.resolve_llm_provider(
db, session.user, provider_name, model_name=session.model_name
)
def _gather_mesh_context(self, db: Session, session: models.Session, user_service) -> str:
"""Aggregates technical infrastructure context from attached agent nodes."""
profile = get_profile(session.feature_name)
if not session.attached_node_ids or not profile.include_mesh_context:
return ""
nodes = db.query(models.AgentNode).filter(models.AgentNode.node_id.in_(session.attached_node_ids)).all()
ctx = "Attached Agent Nodes (Infrastructure):\n"
for node in nodes:
ctx += f"- Node ID: {node.node_id}\n Name: {node.display_name}\n"
ctx += f" Status: {node.last_status}\n"
caps = node.capabilities or {}
if caps.get("arch"): ctx += f" Arch: {caps['arch']} ({caps.get('os', 'unknown')})\n"
# Privilege inference
is_root, has_sudo = caps.get("is_root") == "true", caps.get("has_sudo") == "true"
ctx += f" Privilege: {'root' if is_root else 'sudo-user' if has_sudo else 'standard'}\n"
# Sandbox status
sb = (node.skill_config or {}).get("shell", {}).get("sandbox", {})
if sb: ctx += f" Sandbox: {sb.get('mode', 'PERMISSIVE')}\n"
# Live terminal tailing
registry = self.node_registry_service or (user_service.node_registry_service if user_service else None)
if registry:
ctx += self._render_node_history(registry, node.node_id)
return ctx
def _render_node_history(self, registry, node_id: str) -> str:
"""Extracts and cleans the recent terminal history for a specific node."""
live = registry.get_node(node_id)
if not live or not live.terminal_history: return ""
chunks, total_len = [], 0
for chunk in reversed(list(live.terminal_history)[-40:]):
c_str = chunk if isinstance(chunk, str) else chunk.get("output", str(chunk)) if isinstance(chunk, dict) else str(chunk)
chunks.insert(0, c_str)
total_len += len(c_str)
if total_len > 4000: break
clean = ANSI_ESCAPE.sub('', "".join(chunks))
if len(clean) > 2000: clean = "...[truncated]...\n" + clean[-2000:]
return f" Recent Terminal Output:\n ```\n {clean}\n ```\n"
async def _process_event(self, db, session_id, event, state):
"""Updates internal state and DB progress based on pipeline events."""
e_type = event["type"]
if e_type == "content": state["answer"] += event["content"]
elif e_type == "reasoning": state["reasoning"] += event["content"]
elif e_type == "tool_start":
name = event.get("name")
if name: state["tool_counts"][name] = state["tool_counts"].get(name, {"calls":0, "successes":0, "failures":0}); state["tool_counts"][name]["calls"] += 1
elif e_type == "tool_result":
name, res = event.get("name"), event.get("result")
if name and name in state["tool_counts"]:
if res and (not isinstance(res, dict) or res.get("success") is False): state["tool_counts"][name]["failures"] += 1
else: state["tool_counts"][name]["successes"] += 1
elif e_type == "token_counted":
u = event.get("usage", {})
state["usage"]["input"] += u.get("prompt_tokens", 0); state["usage"]["output"] += u.get("completion_tokens", 0)
# Persistent UI Observability: Commit assistant chunks occasionally
if e_type in ("content", "reasoning"):
await self._update_assistant_db(db, session_id, event, state)
async def _update_assistant_db(self, db, session_id, event, state):
"""Incrementally saves the assistant's response to the DB for real-time frontend visibility."""
if not state["msg"]:
state["msg"] = models.Message(session_id=session_id, sender="assistant", content="")
db.add(state["msg"])
await async_db_op(db.commit)
if event["type"] == "content": state["msg"].content += event["content"]
elif event["type"] == "reasoning" and hasattr(state["msg"], "reasoning_content"):
state["msg"].reasoning_content = (state["msg"].reasoning_content or "") + event["content"]
if (state["usage"]["input"] + state["usage"]["output"]) % 50 == 0:
try: await async_db_op(db.commit)
except: await async_db_op(db.rollback)
async def _finalize_assistant_message(self, db, session_id, state) -> models.Message:
"""Ensures the final assistant message is correctly persisted and closed."""
msg = state["msg"] or models.Message(session_id=session_id, sender="assistant", content="")
msg.content = state["answer"]
if hasattr(msg, "reasoning_content"): msg.reasoning_content = state["reasoning"]
if not state["msg"]: db.add(msg)
await async_db_op(db.commit)
return msg
    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """Retrieves and sorts the conversational history for a session.

        Returns the session's messages ordered by ``created_at``, or ``None``
        when no session exists for ``session_id`` (annotation widened to
        ``Optional`` to match the actual return).
        """
        session = db.query(models.Session).options(joinedload(models.Session.messages)).filter(models.Session.id == session_id).first()
        return sorted(session.messages, key=lambda m: m.created_at) if session else None