# cortex-hub / ai-hub / app / core / services / rag.py
import asyncio
import logging
from typing import Any, Dict, List, Optional, Tuple

import dspy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, joinedload

from app.core.vector_store.faiss_store import FaissVectorStore
from app.db import models
from app.core.retrievers import Retriever, FaissDBRetriever
from app.core.llm_providers import get_llm_provider
from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline

class RAGService:
    """
    Service class for managing conversational RAG sessions.

    Orchestrates the DSPy RAG pipeline and persists chat sessions and their
    messages through the SQLAlchemy session handed to each method.
    """

    # Class-level logger so the service never writes diagnostics to stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self, retrievers: List[Retriever]):
        """
        Store the available retrievers.

        Args:
            retrievers: All retrievers this service may use.
        """
        self.retrievers = retrievers
        # First FAISS-backed retriever in the list, or None when there is none;
        # chat_with_rag() only uses it when the caller explicitly asks for it.
        self.faiss_retriever = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

    # --- Session Management ---

    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
        """
        Create a new chat session in the database.

        Args:
            db: Active SQLAlchemy session.
            user_id: Identifier stored on the new session row.
            model: Model name stored on the new session row.

        Returns:
            The persisted and refreshed ``models.Session``.

        Raises:
            SQLAlchemyError: Re-raised after rolling back ``db``.
        """
        try:
            new_session = models.Session(
                user_id=user_id, model_name=model, title="New Chat Session"
            )
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            # Leave the session usable for the caller before propagating.
            db.rollback()
            raise

    @staticmethod
    def _save_message(db: Session, session_id: int, sender: str, content: str) -> models.Message:
        """Persist one chat message, rolling back ``db`` if the commit fails."""
        message = models.Message(session_id=session_id, sender=sender, content=content)
        try:
            db.add(message)
            db.commit()
        except SQLAlchemyError:
            db.rollback()
            raise
        return message

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        model: str,
        load_faiss_retriever: bool = False
    ) -> Tuple[str, str]:
        """
        Handle one user message within a session.

        Saves the user's prompt, runs the RAG pipeline, saves the assistant's
        answer, and returns it.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of the chat session the message belongs to.
            prompt: The user's message text.
            model: Name of the LLM to answer with.
            load_faiss_retriever: When True, use the FAISS retriever if one
                was supplied at construction time.

        Returns:
            A ``(answer_text, model)`` tuple.

        Raises:
            ValueError: If no session with ``session_id`` exists.
            SQLAlchemyError: If saving either message fails (``db`` is rolled
                back first).
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()

        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        # The prompt is committed before the LLM call, so it is kept even if
        # generating the answer fails.
        self._save_message(db, session_id, "user", prompt)

        llm_provider = get_llm_provider(model)
        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
        # NOTE(review): dspy.configure sets process-wide state; concurrent
        # requests using different models may interfere with each other.
        # Consider a per-call context if the installed DSPy version has one.
        dspy.configure(lm=dspy_llm)

        current_retrievers = []
        if load_faiss_retriever:
            if self.faiss_retriever:
                current_retrievers.append(self.faiss_retriever)
            else:
                self._logger.warning(
                    "Warning: FaissDBRetriever requested but not available. Proceeding without it."
                )

        rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)

        answer_text = await rag_pipeline.forward(
            question=prompt,
            history=session.messages,
            db=db
        )

        self._save_message(db, session_id, "assistant", answer_text)

        return answer_text, model

    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """
        Retrieve all messages for a given session, oldest first.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of the chat session to look up.

        Returns:
            The session's messages sorted by ``created_at``, or ``None`` when
            no session with ``session_id`` exists.
        """
        session = db.query(models.Session).options(
            joinedload(models.Session.messages)
        ).filter(models.Session.id == session_id).first()

        if not session:
            # Callers tell "unknown session" (None) apart from "no messages" ([]).
            return None
        return sorted(session.messages, key=lambda msg: msg.created_at)