# cortex-hub / ai-hub / app / core / pipelines / dspy_rag.py
import dspy
import logging
from typing import Callable, List, Optional, Type
from types import SimpleNamespace
from sqlalchemy.orm import Session

from app.db import models
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.base import LLMProvider


class DSPyLLMProvider(dspy.BaseLM):
    """Adapter exposing an application-level LLMProvider to DSPy as a language model."""

    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
        super().__init__(model=model_name)
        self.provider = provider
        # Merge caller overrides into the kwargs dict set up by dspy.BaseLM.
        self.kwargs.update(kwargs)

    async def aforward(self, prompt: str, **kwargs):
        """Generate a completion for *prompt*, returned in an OpenAI-style response shape."""
        if prompt and prompt.strip():
            content = await self.provider.generate_response(prompt)
        else:
            content = "Error: Empty prompt."
        message = SimpleNamespace(content=content)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])


class AnswerWithHistory(dspy.Signature):
    """Given the context and chat history, answer the user's question."""
    # NOTE: this docstring doubles as the runtime prompt instruction —
    # DspyRagPipeline.forward reads signature.__doc__ when building the final
    # prompt, so editing the docstring changes the text sent to the LLM.
    context = dspy.InputField(desc="Relevant document snippets from the knowledge base.")
    chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.")
    question = dspy.InputField()
    answer = dspy.OutputField()


class DspyRagPipeline(dspy.Module):
    """
    A flexible and extensible DSPy-based RAG pipeline with modular stages.

    Each stage (retrieval, context post-processing, history formatting,
    response post-processing) can be swapped out via constructor arguments;
    omitted stages fall back to simple defaults defined on this class.
    """

    def __init__(
        self,
        retrievers: List[Retriever],
        signature_class: Type[dspy.Signature] = AnswerWithHistory,
        context_postprocessor: Optional[Callable[[List[str]], str]] = None,
        history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        """
        Args:
            retrievers: Retrievers queried in order for document context.
            signature_class: DSPy signature class; its docstring supplies the
                prompt instruction used by ``forward``.
            context_postprocessor: Turns retrieved chunks into one context
                string; defaults to double-newline concatenation.
            history_formatter: Renders prior messages as text; defaults to
                "Human:"/"Assistant:" prefixed lines.
            response_postprocessor: Optional hook applied to the raw LLM
                output before it is returned.
        """
        super().__init__()
        self.retrievers = retrievers
        self.generate_answer = dspy.Predict(signature_class)

        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(self, question: str, history: List[models.Message], db: Session) -> str:
        """Run the full RAG pipeline for *question* and return the answer text.

        Raises:
            RuntimeError: If no DSPy LM has been configured via ``dspy.settings``.
        """
        logging.debug("[DspyRagPipeline.forward] Received question: '%s'", question)

        # Step 1: Retrieve document context from every configured retriever.
        context_chunks: List[str] = []
        for retriever in self.retrievers:
            context_chunks.extend(retriever.retrieve_context(question, db))

        context_text = self.context_postprocessor(context_chunks)

        # Step 2: Format conversation history.
        history_text = self.history_formatter(history)

        # Step 3: Build the final prompt. A custom signature_class may have no
        # docstring, in which case __doc__ is None — fall back to a generic
        # instruction rather than crashing on .strip() in _build_prompt.
        instruction = self.generate_answer.signature.__doc__ or "Answer the user's question."
        full_prompt = self._build_prompt(instruction, context_text, history_text, question)

        logging.debug("[DspyRagPipeline.forward] Full Prompt:\n%s", full_prompt)

        # Step 4: Generate a response using the globally configured DSPy LM.
        lm = dspy.settings.lm
        if lm is None:
            raise RuntimeError("DSPy LM not configured.")

        response_obj = await lm.aforward(prompt=full_prompt)
        raw_response = response_obj.choices[0].message.content

        # Step 5: Optional response post-processing hook.
        if self.response_postprocessor:
            return self.response_postprocessor(raw_response)

        return raw_response

    def _default_context_postprocessor(self, contexts: List[str]) -> str:
        """Join chunks with blank lines; placeholder text when nothing was retrieved."""
        return "\n\n".join(contexts) or "No context provided."

    def _default_history_formatter(self, history: List[models.Message]) -> str:
        """Render each message as a 'Human:' or 'Assistant:' prefixed line."""
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )

    def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str:
        """Assemble the final prompt text sent to the LM."""
        return (
            f"{instruction.strip()}\n\n"
            f"---\n\n"
            f"Context:\n{context.strip()}\n\n"
            f"---\n\n"
            f"Chat History:\n{history.strip()}\n\n"
            f"---\n\n"
            f"Human: {question.strip()}\n"
            f"Assistant:"
        )