import asyncio
import logging
from typing import List, Optional, Tuple

from sqlalchemy.orm import Session, joinedload
import dspy

from app.db import models
from app.core.retrievers.faiss_db_retriever import FaissDBRetriever
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.factory import get_llm_provider
from app.core.pipelines.dspy_rag import DspyRagPipeline

logger = logging.getLogger(__name__)


class RAGService:
    """
    Service for orchestrating conversational RAG pipelines.

    Manages chat interactions and message history for a session.
    """

    def __init__(self, retrievers: List[Retriever]):
        """
        Args:
            retrievers: Retrievers available to the service. The first
                ``FaissDBRetriever`` among them (if any) is remembered as
                ``self.faiss_retriever``; otherwise that attribute is None.
        """
        self.retrievers = retrievers
        self.faiss_retriever = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

    @staticmethod
    def _load_session(db: Session, session_id: int) -> Optional[models.Session]:
        """Return the session with its messages eagerly joined, or None if absent."""
        return (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )

    @staticmethod
    def _save_message(
        db: Session, session_id: int, sender: str, content: str
    ) -> models.Message:
        """Persist one message (add, commit, refresh) and return the stored row."""
        message = models.Message(session_id=session_id, sender=sender, content=content)
        db.add(message)
        db.commit()
        db.refresh(message)
        return message

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        provider_name: str,
        load_faiss_retriever: bool = False,
    ) -> Tuple[str, str]:
        """
        Process a user prompt within a session, save the chat history, and
        return a response.

        The user prompt is committed before the pipeline runs. The assistant
        answer is committed after the pipeline returns.

        Args:
            db: Database session used for all reads and writes.
            session_id: ID of the ``models.Session`` to chat in.
            prompt: The user's message.
            provider_name: Name passed to ``get_llm_provider``; it is also
                echoed back in the result.
            load_faiss_retriever: When True, include the service's
                ``FaissDBRetriever`` in the pipeline if one was supplied.

        Returns:
            ``(answer_text, provider_name)``.

        Raises:
            ValueError: If no session with ``session_id`` exists.
        """
        # NOTE(review): these are synchronous SQLAlchemy calls inside a
        # coroutine, so they block the event loop while they run.
        session = self._load_session(db, session_id)
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        self._save_message(db, session_id, "user", prompt)

        llm_provider = get_llm_provider(provider_name)

        # NOTE(review): only the FAISS retriever is ever passed to the
        # pipeline; other entries in self.retrievers are not used here.
        current_retrievers = []
        if load_faiss_retriever:
            if self.faiss_retriever:
                current_retrievers.append(self.faiss_retriever)
            else:
                logger.warning(
                    "FaissDBRetriever requested but not available. "
                    "Proceeding without it."
                )

        rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)

        # Scope the chosen LLM to this call only, via dspy.context.
        # NOTE(review): session.messages is read after the commit above.
        # Under SQLAlchemy's default expire_on_commit this reloads the
        # relationship and may therefore include the just-saved prompt.
        # Confirm that this is what the pipeline expects.
        with dspy.context(lm=llm_provider):
            answer_text = await rag_pipeline.forward(
                question=prompt,
                history=session.messages,
                db=db,
            )

        self._save_message(db, session_id, "assistant", answer_text)

        return answer_text, provider_name

    def get_message_history(
        self, db: Session, session_id: int
    ) -> Optional[List[models.Message]]:
        """
        Retrieve all messages for a given session, ordered by creation time.

        Returns:
            Messages sorted by ``created_at`` ascending, or None when no
            session with ``session_id`` exists.
        """
        session = self._load_session(db, session_id)
        if session is None:
            return None
        return sorted(session.messages, key=lambda msg: msg.created_at)