from typing import List, Tuple, Optional
from sqlalchemy.orm import Session, joinedload
from app.db import models
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.factory import get_llm_provider
from app.core.orchestration import Architect
class RAGService:
    """
    Service for orchestrating conversational RAG pipelines.

    Coordinates one chat turn end-to-end: persists the user message,
    resolves LLM provider credentials (user preferences with a system-level
    fallback), gathers optional FAISS context and attached agent-node
    context, streams generation events from the Architect, and saves the
    assistant's reply.
    """

    def __init__(
        self,
        retrievers: List["Retriever"],
        prompt_service=None,
        tool_service=None,
        node_registry_service=None,
    ):
        """
        Args:
            retrievers: Retriever implementations available to the service.
            prompt_service: Optional service for prompt-template lookup.
            tool_service: Optional service exposing tools to the agent.
            node_registry_service: Optional registry of live agent nodes
                (used to surface recent terminal output in mesh context).
        """
        self.retrievers = retrievers
        self.prompt_service = prompt_service
        self.tool_service = tool_service
        self.node_registry_service = node_registry_service
        # Cache the FAISS-backed retriever (if one was supplied) so
        # chat_with_rag can opt into semantic retrieval without rescanning.
        self.faiss_retriever = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

    def _resolve_llm_provider(self, db: "Session", session, provider_name: str, user_service):
        """
        Build an LLM provider instance for this session.

        Reads the user's per-provider preferences; when the user has no
        usable API key (missing, empty, or masked with '*' characters),
        falls back to system-level settings, letting truthy user values
        override system ones.

        Returns:
            The provider object from get_llm_provider().
        """
        llm_prefs = {}
        user = session.user
        if user and user.preferences:
            llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {})
        # Masked keys contain '*' (UI redaction) and are unusable as-is.
        if (not llm_prefs or not llm_prefs.get("api_key") or "*" in str(llm_prefs.get("api_key"))) and user_service:
            system_prefs = user_service.get_system_settings(db)
            system_provider_prefs = system_prefs.get("llm", {}).get("providers", {}).get(provider_name, {})
            if system_provider_prefs:
                merged = system_provider_prefs.copy()
                if llm_prefs:
                    # Only truthy user entries may override system values.
                    merged.update({k: v for k, v in llm_prefs.items() if v})
                llm_prefs = merged
        api_key_override = llm_prefs.get("api_key")
        model_name_override = llm_prefs.get("model", "")
        # Everything除 api_key/model is forwarded verbatim as provider kwargs.
        kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]}
        return get_llm_provider(
            provider_name,
            model_name=model_name_override,
            api_key_override=api_key_override,
            **kwargs,
        )

    def _build_mesh_context(self, db: "Session", session, user_service=None) -> str:
        """
        Render the session's attached agent nodes as a text block for the LLM.

        Includes identity, capabilities, privilege level, shell-sandbox
        policy, and (when a live registry entry exists) recent terminal
        output. Returns "" when no nodes are attached or none resolve.
        """
        if not session.attached_node_ids:
            return ""
        nodes = db.query(models.AgentNode).filter(
            models.AgentNode.node_id.in_(session.attached_node_ids)
        ).all()
        if not nodes:
            return ""
        mesh_context = "Attached Agent Nodes (Infrastructure):\n"
        for node in nodes:
            mesh_context += f"- Node ID: {node.node_id}\n"
            mesh_context += f"  Name: {node.display_name}\n"
            mesh_context += f"  Description: {node.description or 'No description provided.'}\n"
            mesh_context += f"  Status: {node.last_status}\n"
            caps = node.capabilities or {}
            if caps.get("local_ip"):
                mesh_context += f"  Local IP: {caps.get('local_ip')}\n"
            if caps.get("arch"):
                mesh_context += f"  Architecture: {caps['arch']} ({caps.get('os', 'unknown')})\n"
            if caps.get("gpu") and caps["gpu"] != "none":
                mesh_context += f"  GPU: {caps['gpu']}\n"
            # Privilege level — critical for knowing whether to use sudo.
            # Values are stored as strings ("true"/"false") due to the
            # protobuf map<string, string> transport; accept booleans too.
            is_root = caps.get("is_root")
            has_sudo = caps.get("has_sudo")
            if is_root == "true" or is_root is True:
                mesh_context += "  Privilege Level: root (skip sudo — run all commands directly)\n"
            elif has_sudo == "true" or has_sudo is True:
                mesh_context += "  Privilege Level: standard user with passwordless sudo\n"
            elif is_root == "false" or is_root is False:
                mesh_context += "  Privilege Level: standard user (sudo NOT available — avoid privileged ops)\n"
            # If neither field exists yet (old node version), omit to avoid confusion.
            shell_config = (node.skill_config or {}).get("shell", {})
            if shell_config.get("enabled"):
                sandbox = shell_config.get("sandbox", {})
                mode = sandbox.get("mode", "PERMISSIVE")
                allowed = sandbox.get("allowed_commands", [])
                denied = sandbox.get("denied_commands", [])
                mesh_context += f"  Terminal Sandbox Mode: {mode}\n"
                if mode == "STRICT":
                    mesh_context += f"    AI Permitted Commands (Allow-list): {', '.join(allowed) if allowed else 'None'}\n"
                elif mode == "PERMISSIVE":
                    mesh_context += f"    AI Restricted Commands (Blacklist): {', '.join(denied) if denied else 'None'}\n"
                if mode == "STRICT" and not allowed:
                    mesh_context += "    ⚠️ Warning: All shell commands are currently blocked by sandbox policy.\n"
            # AI visibility: surface recent terminal history when a live
            # registry entry is reachable (directly or via user_service).
            registry = getattr(self, "node_registry_service", None)
            if not registry and user_service:
                registry = getattr(user_service, "node_registry_service", None)
            if registry:
                live = registry.get_node(node.node_id)
                if live and live.terminal_history:
                    history = live.terminal_history[-40:]
                    mesh_context += "  Recent Terminal Output:\n"
                    mesh_context += "  ```\n"
                    for line in history:
                        mesh_context += f"    {line}"
                    # History lines carry their own newlines; pad only the last.
                    if not history[-1].endswith('\n'):
                        mesh_context += "\n"
                    mesh_context += "  ```\n"
            mesh_context += "\n"
        return mesh_context

    async def chat_with_rag(
        self,
        db: "Session",
        session_id: int,
        prompt: str,
        provider_name: str,
        load_faiss_retriever: bool = False,
        user_service=None,
        user_id: Optional[str] = None,
    ):
        """
        Process a user prompt within a session and stream events in real time.

        Persists the user message, resolves the LLM provider, assembles
        retrieval and mesh context, forwards every Architect event to the
        caller, and finally saves the assistant's reply before yielding a
        terminal "finish" event.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of the chat session to operate on.
            prompt: The user's message text.
            provider_name: Name of the LLM provider to use.
            load_faiss_retriever: When True and a FAISS retriever is
                configured, prepend semantic context chunks.
            user_service: Optional service used for system-settings fallback
                and node-registry lookup.
            user_id: Optional override for the acting user's ID.

        Yields:
            Event dicts from the Architect (e.g. "content", "reasoning"),
            then a {"type": "finish", ...} event with persisted-message
            metadata.

        Raises:
            ValueError: If the session does not exist.
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        # Persist the user's message first so it survives a downstream failure.
        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
        db.add(user_message)
        db.commit()
        db.refresh(user_message)

        # Auto-title the session from the very first user message.
        if session.title in (None, "New Chat Session", ""):
            session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "")
        # Keep provider_name in sync with the client's current choice.
        if session.provider_name != provider_name:
            session.provider_name = provider_name
        # Persist title/provider updates (no-op when nothing changed).
        db.commit()

        llm_provider = self._resolve_llm_provider(db, session, provider_name, user_service)

        context_chunks = []
        if load_faiss_retriever and self.faiss_retriever:
            context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db))

        architect = Architect()

        tools = []
        if self.tool_service:
            tools = self.tool_service.get_available_tools(db, session.user_id, feature=session.feature_name)

        mesh_context = self._build_mesh_context(db, session, user_service)

        # Accumulators for the DB save at the end of the stream.
        full_answer = ""
        full_reasoning = ""

        # NOTE(review): session.messages already includes the user message
        # saved above, so the Architect receives the current prompt both as
        # `question` and as the last history entry — confirm this is the
        # intended contract.
        async for event in architect.run(
            question=prompt,
            history=session.messages,
            context_chunks=context_chunks,
            llm_provider=llm_provider,
            prompt_service=self.prompt_service,
            tool_service=self.tool_service,
            tools=tools,
            mesh_context=mesh_context,
            db=db,
            user_id=user_id or session.user_id,
            sync_workspace_id=session.sync_workspace_id,
            session_id=session_id,
            feature_name=session.feature_name,
            prompt_slug="rag-pipeline",
        ):
            if event["type"] == "content":
                full_answer += event["content"]
            elif event["type"] == "reasoning":
                full_reasoning += event["content"]
            # Forward every event to the API stream unchanged.
            yield event

        # Save the assistant's response to the DB.
        assistant_message = models.Message(
            session_id=session_id,
            sender="assistant",
            content=full_answer,
        )
        # Persist reasoning only when the model actually has the column.
        if full_reasoning and hasattr(assistant_message, "reasoning_content"):
            assistant_message.reasoning_content = full_reasoning
        db.add(assistant_message)
        db.commit()
        db.refresh(assistant_message)

        # Final event carries metadata about the persisted reply.
        yield {
            "type": "finish",
            "message_id": assistant_message.id,
            "provider": provider_name,
            "full_answer": full_answer
        }

    def get_message_history(self, db: "Session", session_id: int) -> "List[models.Message]":
        """
        Retrieve all messages for a session, ordered by creation time.

        Returns an empty list when the session does not exist — previously
        this returned None, contradicting the declared return type and
        crashing callers that iterate the result.
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        if not session:
            return []
        return sorted(session.messages, key=lambda msg: msg.created_at)