# cortex-hub/ai-hub/app/core/pipelines/dspy_rag.py
import dspy
import logging
from typing import List
from types import SimpleNamespace
from sqlalchemy.orm import Session

from app.db import models # Import your SQLAlchemy models
from app.core.retrievers import Retriever
from app.core.llm_providers import LLMProvider

# Adapter that exposes our internal LLMProvider through DSPy's BaseLM
# interface. (This class is unchanged from the previous version.)
class DSPyLLMProvider(dspy.BaseLM):
    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
        super().__init__(model=model_name)
        self.provider = provider
        self.kwargs.update(kwargs)

    async def aforward(self, prompt: str, **kwargs):
        # Guard against empty prompts, then wrap the provider's plain-text
        # response in the OpenAI-style shape DSPy expects
        # (response.choices[0].message.content).
        if not prompt or not prompt.strip():
            return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Empty prompt."))])
        response_text = await self.provider.generate_response(prompt)
        choice = SimpleNamespace(message=SimpleNamespace(content=response_text))
        return SimpleNamespace(choices=[choice])
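
# A minimal wiring sketch (hypothetical provider and model names; any
# LLMProvider implementation works): DSPy needs an LM registered in its
# settings before the pipeline below can run, e.g.
#
#   provider = OllamaProvider(base_url="http://localhost:11434")
#   dspy.settings.configure(lm=DSPyLLMProvider(provider, model_name="llama3"))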

# --- 1. Update the Signature to include Chat History ---
class AnswerWithHistory(dspy.Signature):
    """Given the context and chat history, answer the user's question."""
    
    context = dspy.InputField(desc="Relevant document snippets from the knowledge base.")
    chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.")
    question = dspy.InputField()
    answer = dspy.OutputField()

class DspyRagPipeline(dspy.Module):
    """
    A conversational RAG pipeline that uses document context and chat history.
    """
    def __init__(self, retrievers: List[Retriever]):
        super().__init__()
        self.retrievers = retrievers
        # Use the new signature that includes history
        self.generate_answer = dspy.Predict(AnswerWithHistory)
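        # Note: `forward` below builds the prompt by hand rather than invoking
        # this Predict module; the Predict instance mainly supplies the
        # instruction docstring from its signature.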

    # --- 2. Update the `forward` method to accept history ---
    async def forward(self, question: str, history: List[models.Message], db: Session) -> str:
        """
        Executes the RAG pipeline using the question and the conversation history.
        """
        logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'")
        
        # Retrieve document context based on the current question
        retrieved_contexts = []
        for retriever in self.retrievers:
            context = retriever.retrieve_context(question, db)
            retrieved_contexts.extend(context)

        context_text = "\n\n".join(retrieved_contexts) or "No relevant context was found."

        # --- 3. Format the chat history into a string ---
        history_str = "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        ) or "No history yet."
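        # For a two-turn conversation this yields, e.g.:
        #   Human: What documents do you have?
        #   Assistant: I have access to the uploaded knowledge base.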

        # --- 4. Build the final prompt including history ---
        instruction = self.generate_answer.signature.__doc__
        full_prompt = (
            f"{instruction}\n\n"
            f"---\n\n"
            f"Context: {context_text}\n\n"
            f"---\n\n"
            f"Chat History:\n{history_str}\n\n"
            f"---\n\n"
            f"Human: {question}\n"
            f"Assistant:"
        )
        
        lm = dspy.settings.lm
        if lm is None:
            raise RuntimeError("DSPy LM is not configured; call dspy.settings.configure(lm=...) first.")
            
        response_obj = await lm.aforward(prompt=full_prompt)
        return response_obj.choices[0].message.content
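
# A usage sketch (hypothetical retriever, messages, and session objects;
# assumes an async caller such as a FastAPI route handler):
#
#   pipeline = DspyRagPipeline(retrievers=[chunk_retriever])
#   answer = await pipeline.forward(question=user_question, history=past_messages, db=session)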