Newer
Older
cortex-hub / ai-hub / app / core / services.py
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError
import dspy

from app.core.vector_store import FaissVectorStore
from app.db import models
from app.core.retrievers import Retriever
from app.core.llm_providers import get_llm_provider
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider


class RAGService:
    """
    Service class for managing documents and conversational RAG sessions.

    Holds a vector store used to index document text and a list of retrievers
    that are handed to the DSPy RAG pipeline on every chat turn. Database
    sessions are supplied per call by the caller; this class does not own them.
    """
    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
        self.vector_store = vector_store
        self.retrievers = retrievers

    # --- Session Management ---

    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
        """
        Creates a new chat session in the database.

        Args:
            db: Active SQLAlchemy session.
            user_id: Identifier of the user who owns the chat session.
            model: Name of the LLM model the chat session will use.

        Returns:
            The committed and refreshed ``models.Session`` row.

        Raises:
            SQLAlchemyError: If the insert fails; the transaction is rolled back first.
        """
        try:
            # Create a default title; this could be updated later by the AI
            new_session = models.Session(
                user_id=user_id,
                model_name=model,
                title="New Chat Session",
            )
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            db.rollback()
            raise

    def _save_message(self, db: Session, session_id: int, sender: str, content: str) -> None:
        """Persists one chat message, rolling back the transaction on database errors."""
        try:
            db.add(models.Message(session_id=session_id, sender=sender, content=content))
            db.commit()
        except SQLAlchemyError:
            db.rollback()
            raise

    async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]:
        """
        Handles a message within a session, including saving history and getting a response.

        The user's message is committed before the pipeline runs, so it is kept
        even if answer generation subsequently fails.

        Args:
            db: Active SQLAlchemy session (also passed to the RAG pipeline).
            session_id: Primary key of the chat session.
            prompt: The user's message; must contain non-whitespace text.

        Returns:
            A tuple of (answer_text, model_name).

        Raises:
            ValueError: If the prompt is empty/whitespace or the session does not exist.
            SQLAlchemyError: If saving a message fails; the transaction is rolled back first.
        """
        if not prompt or not prompt.strip():
            raise ValueError("Prompt cannot be empty.")

        # 1. Find the session. Its message history is not eagerly loaded because the
        #    pipeline does not use it yet; add
        #    `.options(joinedload(models.Session.messages))` back when it does.
        session = db.query(models.Session).filter(models.Session.id == session_id).first()
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")
        # Read once up front: with SQLAlchemy's default expire_on_commit, each later
        # access after a commit would otherwise issue a reload query.
        model_name = session.model_name

        # 2. Save the user's new message to the database
        self._save_message(db, session_id, "user", prompt)

        # 3. Configure DSPy with the session's model and execute the pipeline
        llm_provider = get_llm_provider(model_name)
        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model_name)
        # NOTE(review): dspy.configure sets a process-wide LM. Concurrent requests for
        # sessions with different models may observe each other's LM across the await
        # below; consider dspy.context(lm=...) if the installed DSPy version supports it.
        dspy.configure(lm=dspy_llm)

        rag_pipeline = DspyRagPipeline(retrievers=self.retrievers)
        # (Optional) You could pass `session.messages` to the pipeline for context
        answer_text = await rag_pipeline.forward(question=prompt, db=db)

        # 4. Save the assistant's response to the database
        self._save_message(db, session_id, "assistant", answer_text)

        return answer_text, model_name

    # --- Document Management ---

    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
        """
        Adds a document to the database and vector store.

        The document row and its vector metadata are committed in a single
        transaction, so a failure while indexing no longer leaves a committed
        document without vector metadata.

        Args:
            db: Active SQLAlchemy session.
            doc_data: Column values passed as keyword arguments to ``models.Document``.

        Returns:
            The new document's primary key.

        Raises:
            Any error raised by the database or the vector store is re-raised after
            the transaction is rolled back.
        """
        try:
            document_db = models.Document(**doc_data)
            db.add(document_db)
            # Flush (not commit) so the primary key is assigned while the transaction
            # stays open until the vector metadata row has been added too.
            db.flush()
            faiss_index = self.vector_store.add_document(document_db.text)
            vector_metadata = models.VectorMetadata(
                document_id=document_db.id,
                faiss_index=faiss_index,
                # NOTE(review): hard-coded embedder name; confirm it matches the
                # embedder actually used by the vector store.
                embedding_model="mock_embedder",
            )
            db.add(vector_metadata)
            document_id = document_db.id
            db.commit()
            return document_id
        except Exception:
            # NOTE(review): if the commit fails after the vector store add succeeded,
            # the vector stays in the index without a matching DB row.
            db.rollback()
            raise

    def get_all_documents(self, db: Session) -> List[models.Document]:
        """Retrieves all documents from the database, newest first (by ``created_at``)."""
        return db.query(models.Document).order_by(models.Document.created_at.desc()).all()

    def delete_document(self, db: Session, document_id: int) -> Optional[int]:
        """
        Deletes a document from the database.

        Args:
            db: Active SQLAlchemy session.
            document_id: Primary key of the document to delete.

        Returns:
            ``document_id`` if a document was deleted, or ``None`` if no document matched.

        Raises:
            SQLAlchemyError: If the delete fails; the transaction is rolled back first.
        """
        try:
            doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first()
            if doc_to_delete is None:
                return None
            # NOTE(review): only the DB row is deleted here; whether VectorMetadata rows
            # cascade and whether the vector-store entry is removed is not visible in
            # this file — confirm.
            db.delete(doc_to_delete)
            db.commit()
            return document_id
        except SQLAlchemyError:
            db.rollback()
            raise