import asyncio
from typing import Any, Dict, List, Optional, Tuple

import dspy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, joinedload

from app.core.llm_providers import get_llm_provider
from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline
from app.core.retrievers import Retriever
from app.core.vector_store import FaissVectorStore
from app.db import models


class RAGService:
    """
    Service class for managing documents and conversational RAG sessions.
    """

    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
        # Vector store used for document embeddings; retrievers feed the RAG pipeline.
        self.vector_store = vector_store
        self.retrievers = retrievers

    # --- Session Management ---

    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
        """Creates a new chat session in the database.

        Args:
            db: Active SQLAlchemy session.
            user_id: Owner of the new chat session.
            model: LLM model name to record on the session.

        Returns:
            The freshly persisted ``models.Session`` (refreshed, so its
            generated primary key is populated).

        Raises:
            SQLAlchemyError: Re-raised after rolling back the transaction.
        """
        try:
            # Plain string literal: the title contains no interpolated values.
            new_session = models.Session(
                user_id=user_id, model_name=model, title="New Chat Session"
            )
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            db.rollback()
            raise

    async def chat_with_rag(
        self, db: Session, session_id: int, prompt: str, model: str
    ) -> Tuple[str, str]:
        """
        Handles a message within a session, including saving history and
        getting a response. Allows switching the LLM model for the current
        chat turn.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of the existing chat session.
            prompt: The user's message for this turn.
            model: LLM model name to use for this turn (may differ from the
                model stored on the session).

        Returns:
            A ``(answer_text, model)`` tuple — the assistant's reply and the
            model actually used for this turn.

        Raises:
            ValueError: If no session with ``session_id`` exists.
        """
        # Eagerly load the message history in a single query for efficiency.
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        # Save the new user message to the database.
        user_message = models.Message(
            session_id=session_id, sender="user", content=prompt
        )
        db.add(user_message)
        db.commit()

        # Use the 'model' parameter passed to this method for the current
        # chat turn.
        llm_provider = get_llm_provider(model)
        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
        # NOTE(review): dspy.configure sets process-global state; concurrent
        # chats with different models could race here — confirm intended.
        dspy.configure(lm=dspy_llm)

        rag_pipeline = DspyRagPipeline(retrievers=self.retrievers)

        # Pass the full message history to the pipeline's forward method.
        # Note: The history is passed, but the current RAGPipeline
        # implementation might not fully utilize it for conversational
        # context unless explicitly designed to. This is a placeholder for
        # future conversational RAG.
        answer_text = await rag_pipeline.forward(
            question=prompt,
            history=session.messages,  # Pass the existing history
            db=db,
        )

        # Save the assistant's response.
        assistant_message = models.Message(
            session_id=session_id, sender="assistant", content=answer_text
        )
        db.add(assistant_message)
        db.commit()

        # Return the answer text and the model that was actually used for
        # this turn.
        return answer_text, model

    def get_message_history(
        self, db: Session, session_id: int
    ) -> Optional[List[models.Message]]:
        """
        Retrieves all messages for a given session, or None if the session
        doesn't exist.

        Returns:
            The session's messages sorted chronologically by ``created_at``,
            or ``None`` when the session is not found. (Annotation fixed to
            ``Optional`` — the original claimed a bare list but returned
            ``None`` on a missing session.)
        """
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        # Return messages sorted by created_at to ensure chronological order.
        return (
            sorted(session.messages, key=lambda msg: msg.created_at)
            if session
            else None
        )

    # --- Document Management (Unchanged) ---

    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
        """Persists a document, indexes its text in the vector store, and
        records the FAISS index position as vector metadata.

        Args:
            db: Active SQLAlchemy session.
            doc_data: Keyword arguments for the ``models.Document``
                constructor; must include a ``text`` field for embedding.

        Returns:
            The new document's primary key.

        Raises:
            SQLAlchemyError: Re-raised after rolling back the transaction.
        """
        try:
            document_db = models.Document(**doc_data)
            db.add(document_db)
            db.commit()
            db.refresh(document_db)
            # Index the document text and remember where FAISS stored it.
            faiss_index = self.vector_store.add_document(document_db.text)
            vector_metadata = models.VectorMetadata(
                document_id=document_db.id,
                faiss_index=faiss_index,
                embedding_model="mock_embedder",
            )
            db.add(vector_metadata)
            db.commit()
            return document_db.id
        except SQLAlchemyError:
            db.rollback()
            raise

    def get_all_documents(self, db: Session) -> List[models.Document]:
        """Returns all documents, newest first (by ``created_at``)."""
        return (
            db.query(models.Document)
            .order_by(models.Document.created_at.desc())
            .all()
        )

    def delete_document(self, db: Session, document_id: int) -> Optional[int]:
        """Deletes a document by ID.

        Returns:
            The deleted document's ID, or ``None`` if it did not exist.
            (Annotation fixed to ``Optional`` — the original claimed ``int``
            but returned ``None`` for a missing document.)

        Raises:
            SQLAlchemyError: Re-raised after rolling back the transaction.
        """
        try:
            doc_to_delete = (
                db.query(models.Document)
                .filter(models.Document.id == document_id)
                .first()
            )
            if not doc_to_delete:
                return None
            db.delete(doc_to_delete)
            db.commit()
            return document_id
        except SQLAlchemyError:
            db.rollback()
            raise