# cortex-hub / ai-hub / app / core / services.py
# NOTE(review): this module was committed fully commented out; restored below
# with minimal fixes (unused exception bindings dropped, placeholder-free
# f-string made a plain literal, Optional return annotations corrected).

import asyncio
from typing import Any, Dict, List, Optional, Tuple

import dspy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, joinedload

from app.core.llm_providers import get_llm_provider
from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline
from app.core.retrievers import FaissDBRetriever, Retriever
from app.core.vector_store.embedder.mock import MockEmbedder
from app.core.vector_store.faiss_store import FaissVectorStore
from app.db import models


class RAGService:
    """Service for managing documents and conversational RAG sessions.

    Handles both real and mock embedders by inspecting the embedder attached
    to the injected vector store, and conditionally wires a FAISS retriever
    into the DSPy RAG pipeline per request.
    """

    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
        self.vector_store = vector_store
        self.retrievers = retrievers

        # Pick out the FAISS-backed retriever if one was supplied; None otherwise.
        # NOTE(review): a dict of named retrievers would scale better than
        # scanning by isinstance — keep in mind if more retrievers are added.
        self.faiss_retriever: Optional[FaissDBRetriever] = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

        # Reused in add_document() to derive the embedding-model name.
        self.embedder = self.vector_store.embedder

    # --- Session Management ---

    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
        """Create and persist a new chat session.

        Args:
            db: Active SQLAlchemy session.
            user_id: Owner of the new session.
            model: LLM model name to associate with the session.

        Returns:
            The refreshed, persisted Session row.

        Raises:
            SQLAlchemyError: re-raised after rolling back the transaction.
        """
        try:
            # Plain literal: the original f-string had no placeholders.
            new_session = models.Session(
                user_id=user_id, model_name=model, title="New Chat Session"
            )
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            db.rollback()
            raise

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        model: str,
        load_faiss_retriever: bool = False,
    ) -> Tuple[str, str]:
        """Handle one user message in a session and return the reply.

        Persists the user message, runs the DSPy RAG pipeline (optionally with
        the FAISS retriever), persists the assistant reply, and returns
        ``(answer_text, model)``.

        Raises:
            ValueError: if ``session_id`` does not exist.
        """
        # Eager-load history so the pipeline sees prior turns in one query.
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        user_message = models.Message(session_id=session_id, sender="user", content=prompt)
        db.add(user_message)
        db.commit()

        # NOTE(review): dspy.configure sets process-global state; concurrent
        # requests with different models may race — confirm acceptable.
        llm_provider = get_llm_provider(model)
        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
        dspy.configure(lm=dspy_llm)

        # Only attach the FAISS retriever when asked for AND available;
        # degrade gracefully (best-effort) when it is missing.
        current_retrievers: List[Retriever] = []
        if load_faiss_retriever:
            if self.faiss_retriever:
                current_retrievers.append(self.faiss_retriever)
            else:
                print("Warning: FaissDBRetriever requested but not available. Proceeding without it.")

        rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)
        answer_text = await rag_pipeline.forward(
            question=prompt,
            history=session.messages,
            db=db,
        )

        assistant_message = models.Message(
            session_id=session_id, sender="assistant", content=answer_text
        )
        db.add(assistant_message)
        db.commit()

        return answer_text, model

    def get_message_history(
        self, db: Session, session_id: int
    ) -> Optional[List[models.Message]]:
        """Return a session's messages ordered by creation time.

        Returns None when the session does not exist (annotation fixed to
        Optional — the original promised a bare List).
        """
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        return sorted(session.messages, key=lambda msg: msg.created_at) if session else None

    # --- Document Management ---

    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
        """Persist a document, index its text in FAISS, and record metadata.

        Args:
            db: Active SQLAlchemy session.
            doc_data: Column values for a ``models.Document`` row; must include
                the ``text`` field that gets embedded.

        Returns:
            The new document's primary-key id.

        Raises:
            SQLAlchemyError: re-raised after rolling back on DB failure.
        """
        try:
            document_db = models.Document(**doc_data)
            db.add(document_db)
            db.commit()
            db.refresh(document_db)

            # Record which embedder produced the vector so mixed indexes
            # remain auditable.
            embedding_model_name = (
                "mock_embedder" if isinstance(self.embedder, MockEmbedder) else "GenAIEmbedder"
            )

            faiss_index = self.vector_store.add_document(document_db.text)
            vector_metadata = models.VectorMetadata(
                document_id=document_db.id,
                faiss_index=faiss_index,
                embedding_model=embedding_model_name,
            )
            db.add(vector_metadata)
            db.commit()
            return document_db.id
        except SQLAlchemyError:
            db.rollback()
            raise

    def get_all_documents(self, db: Session) -> List[models.Document]:
        """Return all documents, newest first."""
        return db.query(models.Document).order_by(models.Document.created_at.desc()).all()

    def delete_document(self, db: Session, document_id: int) -> Optional[int]:
        """Delete a document by id.

        Returns the deleted id, or None when no such document exists
        (annotation fixed to Optional — the original promised a bare int).

        Raises:
            SQLAlchemyError: re-raised after rolling back on DB failure.
        """
        try:
            doc_to_delete = (
                db.query(models.Document).filter(models.Document.id == document_id).first()
            )
            if not doc_to_delete:
                return None
            db.delete(doc_to_delete)
            db.commit()
            return document_id
        except SQLAlchemyError:
            db.rollback()
            raise