import dspy
import logging
from typing import List, Callable, Optional, Type
from sqlalchemy.orm import Session
from app.db import models
from app.core.retrievers.base_retriever import Retriever

# --- DSPy Signature Class (No Change) ---
class AnswerWithHistory(dspy.Signature):
    """Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history."""

    context = dspy.InputField(desc="Relevant excerpts from the knowledge base to support the answer.")
    chat_history = dspy.InputField(desc="The ongoing dialogue between the user and the AI, providing conversational context.")
    question = dspy.InputField(desc="The user's current question.")
    answer = dspy.OutputField(desc="A well-formed answer suitable for delivery in an audio play format.")


# --- DSPy RAG Pipeline Class (Updated) ---
class DspyRagPipeline(dspy.Module):
    """
    A flexible and extensible DSPy-based RAG pipeline with modular stages.

    Retrieval now happens outside this module: the caller passes pre-fetched
    context chunks into forward(), and each stage (context postprocessing,
    history formatting, response postprocessing) can be swapped via the constructor.
    """

    def __init__(
        self,
        # retrievers: List[Retriever],
        signature_class: Type[dspy.Signature] = AnswerWithHistory,
        context_postprocessor: Optional[Callable[[List[str]], str]] = None,
        history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        super().__init__()
        # self.retrievers = retrievers
        self.generate_answer = dspy.Predict(signature_class)
        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(self, question: str, history: List[models.Message], context_chunks: List[str]) -> str:
        logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'")

        # Step 1: Prepare the document context from the pre-retrieved chunks
        # context_chunks = []
        # for retriever in self.retrievers:
        #     context_chunks.extend(retriever.retrieve_context(question, db))
        context_text = self.context_postprocessor(context_chunks)

        # Step 2: Format history
        history_text = self.history_formatter(history)

        # Step 3: Generate the response using the LLM
        # With DSPy and LiteLLM, the signature-based generation handles the prompt building,
        # so there is no need to manually build the prompt string.
        prediction = await self.generate_answer.aforward(
            context=context_text,
            chat_history=history_text,
            question=question
        )
        raw_response = prediction.answer

        # Step 4: Optional response postprocessing
        if self.response_postprocessor:
            return self.response_postprocessor(raw_response)
        return raw_response

    # Default context processor: concatenate chunks
    def _default_context_postprocessor(self, contexts: List[str]) -> str:
        return "\n\n".join(contexts) or "No context provided."

    # Default history formatter: simple speaker prefix
    def _default_history_formatter(self, history: List[models.Message]) -> str:
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )


# Note: The _build_prompt method is removed as DSPy handles this automatically.
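

# --- Usage sketch (illustrative only) ---
# A minimal example of wiring the pipeline up. The model id ("openai/gpt-4o-mini") and
# the inline sample data below are assumptions, not part of this application; adjust
# them to the app's actual configuration and a DSPy version with async support.
if __name__ == "__main__":
    import asyncio

    # Assumed model id; any LiteLLM-style identifier accepted by dspy.LM works here.
    dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

    async def _demo() -> None:
        # Optional custom stage: trim whitespace from the generated answer.
        pipeline = DspyRagPipeline(response_postprocessor=lambda text: text.strip())
        answer = await pipeline.forward(
            question="What happens in act two?",
            history=[],  # would normally be a list of models.Message rows
            context_chunks=["Act two opens with the storm scene..."],
        )
        print(answer)

    asyncio.run(_demo())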