import dspy
import logging
from typing import Callable, List, Optional, Type

from app.db import models


# --- DSPy Signature Class ---
class AnswerWithHistory(dspy.Signature):
    """Generate a natural and context-aware answer to the user's question using the provided knowledge and conversation history."""

    context = dspy.InputField(desc="Relevant excerpts from the knowledge base to support the answer.")
    chat_history = dspy.InputField(desc="The ongoing dialogue between the user and the AI, providing conversational context.")
    question = dspy.InputField(desc="The user's current question.")
    answer = dspy.OutputField(desc="A well-formed answer suitable for delivery in an audio play format.")


# --- DSPy RAG Pipeline Class ---
class DspyRagPipeline(dspy.Module):
    """
    A flexible, extensible DSPy-based RAG pipeline with modular stages.

    Retrieval happens upstream: callers pass pre-retrieved context chunks into
    `forward`, and the pipeline formats the context and history, generates an
    answer, and optionally post-processes the response.
    """

    def __init__(
        self,
        signature_class: Type[dspy.Signature] = AnswerWithHistory,
        context_postprocessor: Optional[Callable[[List[str]], str]] = None,
        history_formatter: Optional[Callable[[List[models.Message]], str]] = None,
        response_postprocessor: Optional[Callable[[str], str]] = None,
    ):
        super().__init__()
        self.generate_answer = dspy.Predict(signature_class)

        self.context_postprocessor = context_postprocessor or self._default_context_postprocessor
        self.history_formatter = history_formatter or self._default_history_formatter
        self.response_postprocessor = response_postprocessor

    async def forward(self, question: str, history: List[models.Message], context_chunks: List[str]) -> str:
        logging.debug(f"[DspyRagPipeline.forward] Received question: '{question}'")

        # Step 1: Combine the pre-retrieved context chunks (retrieval happens upstream)

        context_text = self.context_postprocessor(context_chunks)

        # Step 2: Format history
        history_text = self.history_formatter(history)

        # Step 3: Generate the response with the LLM.
        # DSPy builds the prompt from the signature fields, so no manual prompt
        # string assembly is needed here.
        prediction = await self.generate_answer.aforward(
            context=context_text,
            chat_history=history_text,
            question=question
        )

        raw_response = prediction.answer

        # Step 4: Optional response postprocessing
        if self.response_postprocessor:
            return self.response_postprocessor(raw_response)

        return raw_response

    # Default context postprocessor: join chunks, with a fallback message when empty
    def _default_context_postprocessor(self, contexts: List[str]) -> str:
        return "\n\n".join(contexts) or "No context provided."

    # Default history formatter: simple speaker prefix
    def _default_history_formatter(self, history: List[models.Message]) -> str:
        return "\n".join(
            f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
            for msg in history
        )
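

# Example of a pluggable history formatter (illustrative sketch): keep only the
# most recent turns so long conversations do not dominate the prompt. The window
# size of 6 is an arbitrary assumption, not a project default.
def last_n_messages_formatter(history: List[models.Message], n: int = 6) -> str:
    recent = history[-n:]
    return "\n".join(
        f"{'Human' if msg.sender == 'user' else 'Assistant'}: {msg.content}"
        for msg in recent
    )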

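
# --- Example usage (illustrative sketch) ---
# A minimal, hedged example of wiring the pipeline together. The model id below
# is an assumption; configure whatever LM your deployment actually uses. History
# is left empty here to avoid hand-building app.db.models.Message rows.
if __name__ == "__main__":
    import asyncio

    # Configure the global DSPy language model (LiteLLM-style model id).
    dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

    def strip_whitespace(answer: str) -> str:
        # Example response postprocessor: trim stray whitespace from the answer.
        return answer.strip()

    pipeline = DspyRagPipeline(
        history_formatter=last_n_messages_formatter,
        response_postprocessor=strip_whitespace,
    )

    chunks = [
        "DSPy programs are built from signatures and modules rather than "
        "hand-written prompt strings.",
    ]

    # Retrieval happens upstream; pre-retrieved chunks are passed in directly.
    answer = asyncio.run(
        pipeline.forward(question="What is DSPy?", history=[], context_chunks=chunks)
    )
    print(answer)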