import logging
from types import SimpleNamespace
from typing import Callable, List, Optional

import dspy
from sqlalchemy.orm import Session

from app.db import models
from app.core.retrievers.base_retriever import Retriever
from app.core.providers.base import LLMProvider


class DSPyLLMProvider(dspy.BaseLM):
    def __init__(self, provider: LLMProvider, model_name: str, **kwargs):
        super().__init__(model=model_name)
        self.provider = provider
        self.kwargs.update(kwargs)

    async def aforward(self, prompt: str, **kwargs):
        if not prompt or not prompt.strip():
            return SimpleNamespace(
                choices=[SimpleNamespace(message=SimpleNamespace(content="Error: Empty prompt."))]
            )
        response_text = await self.provider.generate_response(prompt)
        choice = SimpleNamespace(message=SimpleNamespace(content=response_text))
        return SimpleNamespace(choices=[choice])


class AnswerWithHistory(dspy.Signature):
    """Given the context and chat history, answer the user's question."""

    context = dspy.InputField(desc="Relevant document snippets from the knowledge base.")
    chat_history = dspy.InputField(desc="The ongoing conversation between the user and the AI.")
    question = dspy.InputField()
    answer = dspy.OutputField()


class DspyRagPipeline(dspy.Module):
    """A flexible and extensible DSPy-based RAG pipeline with modular stages."""

    def __init__(
        self,
        retrievers: List[Retriever],
        signature_class: dspy.Signature = AnswerWithHistory,
        context_postprocessor: Optional[Callable[[List[str]], str]] = None,
        history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        super().__init__()
        self.retrievers = retrievers
        self.generate_answer = dspy.Predict(signature_class)
        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(self, question: str, history: List[models.Message], db: Session) -> str:
        logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'")

        # Step 1: Retrieve all document contexts
        context_chunks = []
        for retriever in self.retrievers:
            context_chunks.extend(retriever.retrieve_context(question, db))
        context_text = self.context_postprocessor(context_chunks)

        # Step 2: Format history
        history_text = self.history_formatter(history)

        # Step 3: Build final prompt
        instruction = self.generate_answer.signature.__doc__
        full_prompt = self._build_prompt(instruction, context_text, history_text, question)
        logging.debug(f"[DspyRagPipeline.forward] Full Prompt:\n{full_prompt}")

        # Step 4: Generate response using LLM
        lm = dspy.settings.lm
        if lm is None:
            raise RuntimeError("DSPy LM not configured.")
        response_obj = await lm.aforward(prompt=full_prompt)
        raw_response = response_obj.choices[0].message.content

        # Step 5: Optional response postprocessing
        if self.response_postprocessor:
            return self.response_postprocessor(raw_response)
        return raw_response

    # Default context processor: concatenate chunks
    def _default_context_postprocessor(self, contexts: List[str]) -> str:
        return "\n\n".join(contexts) or "No context provided."

    # Default history formatter: simple speaker prefix
    def _default_history_formatter(self, history: List[models.Message]) -> str:
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )

    # Prompt builder
    def _build_prompt(self, instruction: str, context: str, history: str, question: str) -> str:
        return (
            f"{instruction.strip()}\n\n"
            f"---\n\n"
            f"Context:\n{context.strip()}\n\n"
            f"---\n\n"
            f"Chat History:\n{history.strip()}\n\n"
            f"---\n\n"
            f"Human: {question.strip()}\n"
            f"Assistant:"
        )
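# --- Usage sketch (illustrative, not part of the pipeline) -------------------
# A minimal example of wiring the pipeline together, kept commented out so the
# module stays import-safe. The stub provider/retriever classes and the
# "demo-model" name below are assumptions for illustration only; in the
# application you would pass real LLMProvider and Retriever implementations
# and a live SQLAlchemy session instead of None.
#
# import asyncio
#
# class _EchoProvider:
#     async def generate_response(self, prompt: str) -> str:
#         return f"(stub) prompt had {len(prompt)} characters"
#
# class _StaticRetriever:
#     def retrieve_context(self, question: str, db) -> List[str]:
#         return ["Example context chunk from the knowledge base."]
#
# async def _demo() -> None:
#     # Register the adapter as the active DSPy LM, then run one question
#     # through the pipeline with an empty chat history.
#     dspy.settings.configure(lm=DSPyLLMProvider(_EchoProvider(), model_name="demo-model"))
#     pipeline = DspyRagPipeline(retrievers=[_StaticRetriever()])
#     answer = await pipeline.forward(question="What is DSPy?", history=[], db=None)
#     print(answer)
#
# asyncio.run(_demo())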