import asyncio
from typing import List, Optional, Tuple

import dspy
from sqlalchemy.orm import Session, joinedload

from app.core.pipelines.dspy_rag import DspyRagPipeline
from app.core.providers.factory import get_llm_provider
from app.core.retrievers.base_retriever import Retriever
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.db import models
class RAGService:
    """
    Service for orchestrating conversational RAG pipelines.

    Manages chat interactions and message history for a session: persists
    user and assistant messages, applies per-user LLM provider preferences,
    and optionally augments the prompt with FAISS-retrieved context.
    """

    def __init__(self, retrievers: List[Retriever]):
        """
        Args:
            retrievers: Retriever implementations available to this service.
                The first ``FaissDBRetriever`` found (if any) is cached for
                optional context retrieval in :meth:`chat_with_rag`.
        """
        self.retrievers = retrievers
        # Cache the FAISS retriever once so chat_with_rag does not re-scan
        # the retriever list on every call.
        self.faiss_retriever = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        provider_name: str,
        load_faiss_retriever: bool = False
    ) -> Tuple[str, str]:
        """
        Processes a user prompt within a session, saves the chat history, and returns a response.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of the chat session to operate on.
            prompt: The user's message text.
            provider_name: Name of the LLM provider to route the request to.
            load_faiss_retriever: When True, retrieve FAISS context chunks
                for the prompt (if a ``FaissDBRetriever`` is available).

        Returns:
            Tuple of ``(answer_text, provider_name)`` — the assistant's
            response and the provider actually used.

        Raises:
            ValueError: If no session with ``session_id`` exists.
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        # Persist the user message first so it is part of the saved history
        # even if the LLM call below fails.
        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
        db.add(user_message)
        db.commit()
        db.refresh(user_message)

        session_dirty = False
        # Auto-title the session from the very first user message
        if session.title in (None, "New Chat Session", ""):
            session.title = prompt[:60].strip() + ("..." if len(prompt) > 60 else "")
            session_dirty = True
        # Keep provider_name in sync with the model actually being used
        if session.provider_name != provider_name:
            session.provider_name = provider_name
            session_dirty = True
        # Commit once for either metadata change. (Previously the commit was
        # only reached when the provider changed, so a title-only update
        # stayed pending until a later incidental commit.)
        if session_dirty:
            db.commit()

        # Fetch user preferences for overrides. llm_prefs defaults to an
        # empty dict so the kwargs computation below never hits an unbound
        # name when the user has no stored preferences.
        llm_prefs = {}
        api_key_override = None
        model_name_override = ""
        user = session.user
        if user and user.preferences:
            llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(provider_name, {})
            api_key_override = llm_prefs.get("api_key")
            model_name_override = llm_prefs.get("model", "")

        # Get the appropriate LLM provider with all extra prefs passed as kwargs
        extra_kwargs = {k: v for k, v in llm_prefs.items() if k not in ("api_key", "model")}
        llm_provider = get_llm_provider(
            provider_name,
            model_name=model_name_override,
            api_key_override=api_key_override,
            **extra_kwargs
        )

        # Optionally gather FAISS context chunks for the prompt.
        context_chunks = []
        if load_faiss_retriever:
            if self.faiss_retriever:
                context_chunks.extend(self.faiss_retriever.retrieve_context(query=prompt, db=db))
            else:
                print("Warning: FaissDBRetriever requested but not available. Proceeding without it.")

        rag_pipeline = DspyRagPipeline()
        # dspy.context scopes the LM configuration to this call only.
        with dspy.context(lm=llm_provider):
            answer_text = await rag_pipeline.forward(
                question=prompt,
                history=session.messages,
                context_chunks=context_chunks
            )

        # Save assistant's response
        assistant_message = models.Message(session_id=session_id, sender="assistant", content=answer_text)
        db.add(assistant_message)
        db.commit()
        db.refresh(assistant_message)

        return answer_text, provider_name

    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """
        Retrieves all messages for a given session, ordered by creation time.

        Returns:
            Messages sorted by ``created_at``, or ``None`` when the session
            does not exist (annotation corrected — the previous ``List``
            annotation did not reflect the ``None`` return).
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()
        return sorted(session.messages, key=lambda msg: msg.created_at) if session else None