cortex-hub / ai-hub / app / core / rag_service.py
import asyncio
from typing import List, Dict, Any
from types import SimpleNamespace
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
import dspy
import logging

from app.core.vector_store import FaissVectorStore
from app.db import models
from app.core.retrievers import Retriever
from app.core.llm_providers import LLMProvider, get_llm_provider

# --- DSPy Components for RAG ---

class DSPyLLMProvider(dspy.BaseLM):
    """
    A custom wrapper for the LLMProvider to make it compatible with DSPy.
    """
    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
        super().__init__(model=model_name)
        self.provider = provider
        self.kwargs.update(kwargs)
        print(f"DSPyLLMProvider initialized for model: {self.model}")

    async def aforward(self, prompt: str, **kwargs):
        """
        The required asynchronous forward pass for the language model.
        """
        logging.info(f"[DSPyLLMProvider.aforward] Received prompt of length: {len(prompt) if prompt else 0}")

        # Guard: never call the provider with a null or empty prompt.
        if not prompt or not prompt.strip():
            logging.error("[DSPyLLMProvider.aforward] Received a null or empty prompt!")
            # Return a default, safe response instead of calling the API with null.
            return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Received an empty prompt."))])

        # Call the async provider directly using the existing event loop
        response_text = await self.provider.generate_response(prompt)

        # Create a mock response object that mimics the OpenAI API structure
        mock_choice = SimpleNamespace(
            message=SimpleNamespace(content=response_text, tool_calls=None)
        )
        mock_response = SimpleNamespace(
            choices=[mock_choice],
            usage=SimpleNamespace(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            model=self.model
        )
        return mock_response
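
# Usage sketch (illustrative, mirroring RAGService.chat_with_rag below): the wrapper is
# registered as DSPy's process-wide LM before a pipeline runs. "example-model" is a
# placeholder, not a real provider key.
#
#   provider = get_llm_provider("example-model")
#   dspy.configure(lm=DSPyLLMProvider(provider=provider, model_name="example-model"))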

class AnswerWithContext(dspy.Signature):
    """
    Signature for our RAG task: input is a context and question, output is an answer.
    """
    context = dspy.InputField(desc="Relevant document snippets from the knowledge base.")
    question = dspy.InputField()
    answer = dspy.OutputField()
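
# Note (sketch): in typical DSPy usage this signature would be driven through dspy.Predict,
# e.g. `dspy.Predict(AnswerWithContext)(context=..., question=...).answer`. The pipeline
# below instead renders the prompt itself and awaits the LM's aforward() directly, since
# DSPyLLMProvider only implements the asynchronous path.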

class RAGPipeline(dspy.Module):
    """
    A simple RAG pipeline that retrieves context and then generates an answer.
    """
    def __init__(self, retrievers: List[Retriever]):
        super().__init__()
        self.retrievers = retrievers
        # We only need the signature here to generate the prompt text.
        self.generate_answer = dspy.Predict(AnswerWithContext)

    async def forward(self, question: str, db: Session) -> str:
        """
        Executes the RAG pipeline asynchronously.
        """
        logging.info(f"[RAGPipeline.forward] Received question: '{question}'")
        retrieved_contexts = []
        for retriever in self.retrievers:
            context = retriever.retrieve_context(question, db)
            retrieved_contexts.extend(context)

        context_text = "\n\n".join(retrieved_contexts)
        if not context_text:
            print("⚠️ No context retrieved. Falling back to direct QA.")
            context_text = "No context provided."

        # Get the language model that was registered via dspy.configure(lm=...).
        lm = dspy.settings.lm
        if lm is None:
            raise RuntimeError("DSPy LM has not been configured. Call dspy.configure(lm=...) first.")

        # Render the prompt manually: the signature's instructions (its docstring) contain
        # no placeholders, so the retrieved context and the question are appended
        # explicitly before calling our custom LM's async forward pass directly.
        signature = self.generate_answer.signature
        full_prompt = (
            f"{signature.instructions}\n\n"
            f"Context:\n{context_text}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )

        response_obj = await lm.aforward(prompt=full_prompt)

        return response_obj.choices[0].message.content


# --- Main RAG Service Class ---
class RAGService:
    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
        self.vector_store = vector_store
        self.retrievers = retrievers

    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
        try:
            document_db = models.Document(
                title=doc_data["title"],
                text=doc_data["text"],
                source_url=doc_data["source_url"]
            )
            db.add(document_db)
            db.commit()
            db.refresh(document_db)
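            # Embed the document text and add it to the FAISS store; the returned index
            # position links this DB row to its vector via the VectorMetadata table.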
            faiss_index = self.vector_store.add_document(document_db.text)
            vector_metadata = models.VectorMetadata(
                document_id=document_db.id,
                faiss_index=faiss_index,
                embedding_model="mock_embedder"
            )
            db.add(vector_metadata)
            db.commit()
            print(f"Document with ID {document_db.id} successfully added.")
            return document_db.id
        except SQLAlchemyError as e:
            db.rollback()
            print(f"Database error while adding document: {e}")
            raise
        except Exception as e:
            db.rollback()
            print(f"An unexpected error occurred: {e}")
            raise
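
    # Example call (sketch; field values are placeholders):
    #   doc_id = rag_service.add_document(db, {
    #       "title": "FAISS quickstart",
    #       "text": "FAISS is a library for efficient similarity search over dense vectors.",
    #       "source_url": "https://example.com/faiss-quickstart",
    #   })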

    async def chat_with_rag(self, db: Session, prompt: str, model: str) -> str:
        print(f"Received Prompt: {prompt}")
        if not prompt or not prompt.strip():
            raise ValueError("The prompt cannot be null, empty, or contain only whitespace.")
        
        llm_provider_instance = get_llm_provider(model)
        dspy_llm_provider = DSPyLLMProvider(provider=llm_provider_instance, model_name=model)
        
        # Configure dspy's global settings with our custom LM
        dspy.configure(lm=dspy_llm_provider)
        
        rag_pipeline = RAGPipeline(retrievers=self.retrievers)
        answer = await rag_pipeline.forward(question=prompt, db=db)
        return answer
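
# End-to-end usage sketch (illustrative, not part of this module): the service is
# expected to be constructed at application startup with a FaissVectorStore and one
# or more Retriever implementations. `build_retrievers`, the `session` object, and the
# model name below are placeholders.
#
#   vector_store = FaissVectorStore(...)   # constructor arguments omitted
#   rag_service = RAGService(vector_store=vector_store,
#                            retrievers=build_retrievers(vector_store))
#   answer = await rag_service.chat_with_rag(db=session, prompt="What is FAISS?",
#                                            model="example-model")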