# Source path: cortex-hub/ai-hub/app/core/pipelines/dspy_rag.py
# In app/core/pipelines/dspy_rag.py

import logging
from types import SimpleNamespace
from typing import List, Optional

import dspy
from sqlalchemy.orm import Session

from app.core.retrievers import Retriever
from app.core.llm_providers import LLMProvider

class DSPyLLMProvider(dspy.BaseLM):
    """
    A custom wrapper for the LLMProvider to make it compatible with DSPy.

    Only the asynchronous ``aforward`` path is implemented. It delegates to
    ``LLMProvider.generate_response`` and wraps the returned text in an object
    shaped like a chat completion (``choices[0].message.content``, ``usage``,
    ``model``).
    """

    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
        """
        Args:
            provider: The project LLM provider that actually generates text.
            model_name: Model identifier handed to ``dspy.BaseLM`` as ``model``.
            **kwargs: Extra settings merged into ``self.kwargs`` (the dict
                created by the ``dspy.BaseLM`` initializer).
        """
        super().__init__(model=model_name)
        self.provider = provider
        self.kwargs.update(kwargs)
        logging.info("DSPyLLMProvider initialized for model: %s", self.model)

    def _wrap_completion(self, content: str):
        """Wrap ``content`` in a chat-completion-shaped response object.

        Every response (success or error) carries the same attributes, so
        callers can read ``choices``, ``usage`` and ``model`` unconditionally.
        Token usage is not tracked, so all counts are reported as zero.
        """
        choice = SimpleNamespace(message=SimpleNamespace(content=content, tool_calls=None))
        usage = SimpleNamespace(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        return SimpleNamespace(choices=[choice], usage=usage, model=self.model)

    @staticmethod
    def _messages_to_prompt(messages) -> str:
        """Flatten chat-style ``messages`` into one prompt string for a text-only provider.

        Only dict messages with a non-blank string ``content`` are kept. Each
        is rendered as ``"<role>: <content>"`` and messages are separated by
        blank lines.
        """
        parts = []
        for message in messages or []:
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if isinstance(content, str) and content.strip():
                parts.append(f"{message.get('role', 'user')}: {content}")
        return "\n\n".join(parts)

    async def aforward(self, prompt: Optional[str] = None, messages: Optional[list] = None, **kwargs):
        """
        The required asynchronous forward pass for the language model.

        Args:
            prompt: Full prompt text. Takes precedence over ``messages``.
            messages: Chat messages, used only when ``prompt`` is empty. They
                are flattened into a single prompt because the wrapped
                provider accepts plain text.
            **kwargs: Accepted for compatibility with ``dspy.BaseLM`` callers;
                not forwarded to the provider.

        Returns:
            A completion-shaped object whose ``choices[0].message.content``
            holds the provider's text. If the prompt is empty, it holds an
            error string instead.
        """
        if not prompt and messages:
            prompt = self._messages_to_prompt(messages)

        logging.info(
            "[DSPyLLMProvider.aforward] Received prompt of length: %d",
            len(prompt) if prompt else 0,
        )
        if not prompt or not prompt.strip():
            logging.error("[DSPyLLMProvider.aforward] Received a null or empty prompt!")
            return self._wrap_completion("Error: Received an empty prompt.")

        response_text = await self.provider.generate_response(prompt)
        return self._wrap_completion(response_text)

class AnswerWithContext(dspy.Signature):
    """Given the context, answer the user's question."""
    # NOTE: the docstring above is runtime data, not only documentation:
    # DspyRagPipeline reads it via ``signature.__doc__`` and places it at the
    # top of the prompt it sends to the LM, so editing it changes model input.
    # DspyRagPipeline builds its own "Context:/Question:/Answer:" prompt, so
    # the field declarations below are not what formats that prompt.

    # Input: retrieved knowledge-base snippets.
    context = dspy.InputField(desc="Relevant document snippets from the knowledge base.")
    # Input: the user's question.
    question = dspy.InputField()
    # Output: the generated answer.
    answer = dspy.OutputField()

class DspyRagPipeline(dspy.Module):
    """
    A simple RAG pipeline that retrieves context and then generates an answer using DSPy.

    Context is gathered from every configured retriever. A prompt is then built
    by hand from the ``AnswerWithContext`` signature's docstring and sent
    directly to the globally configured DSPy LM through ``lm.aforward``.
    """

    def __init__(self, retrievers: List[Retriever]):
        """
        Args:
            retrievers: Retrievers queried in order; their snippets are concatenated.
        """
        super().__init__()
        self.retrievers = retrievers
        # The predictor is never invoked; it is kept so the signature's
        # instruction text is reachable via ``self.generate_answer.signature``.
        self.generate_answer = dspy.Predict(AnswerWithContext)

    def _collect_context(self, question: str, db: Session) -> str:
        """Query every retriever and join the non-blank snippets with blank lines."""
        snippets = []
        for retriever in self.retrievers:
            # NOTE(review): assumes retrieve_context returns an iterable of
            # strings. It is a synchronous call inside an async pipeline; if it
            # does blocking DB I/O it will stall the event loop — confirm.
            snippets.extend(retriever.retrieve_context(question, db))
        # Drop empty/whitespace-only snippets so they neither add stray
        # separators nor make an effectively empty context look non-empty.
        return "\n\n".join(s for s in snippets if s and s.strip())

    def _build_prompt(self, context_text: str, question: str) -> str:
        """Render the signature instruction, context and question into one prompt."""
        # Fall back to an empty instruction rather than interpolating "None".
        instruction = (self.generate_answer.signature.__doc__ or "").strip()
        # Hand-built template. It does not go through a DSPy adapter, so it is
        # not the format dspy.Predict itself would produce.
        return (
            f"{instruction}\n\n"
            f"---\n\n"
            f"Context: {context_text}\n\n"
            f"Question: {question}\n\n"
            f"Answer:"
        )

    async def forward(self, question: str, db: Session) -> str:
        """
        Executes the RAG pipeline asynchronously.

        Args:
            question: The user's question.
            db: Database session passed to each retriever.

        Returns:
            The ``content`` of the first choice in the LM's response.

        Raises:
            RuntimeError: If no LM has been configured with ``dspy.configure``.
        """
        logging.info("[DspyRagPipeline.forward] Received question: '%s'", question)
        context_text = self._collect_context(question, db)
        if not context_text:
            logging.warning("⚠️ No context retrieved. Falling back to direct QA.")
            context_text = "No context provided."

        lm = dspy.settings.lm
        if lm is None:
            raise RuntimeError("DSPy LM has not been configured. Call dspy.configure(lm=...) first.")

        full_prompt = self._build_prompt(context_text, question)

        # Call the LM directly (not via self.generate_answer) so the
        # hand-built prompt above is sent verbatim.
        response_obj = await lm.aforward(prompt=full_prompt)

        return response_obj.choices[0].message.content