from typing import Any, Dict, List, Optional, Tuple

import dspy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, joinedload

from app.core.vector_store import FaissVectorStore
from app.db import models
from app.core.retrievers import Retriever
from app.core.llm_providers import get_llm_provider
from app.core.pipelines.dspy_rag import DspyRagPipeline, DSPyLLMProvider


class RAGService:
    """
    Service class for managing documents and conversational RAG sessions.

    Holds a FAISS-backed vector store (fed when documents are added) and a list
    of retrievers (handed to ``DspyRagPipeline`` on every chat turn). All
    persistence goes through the SQLAlchemy ``Session`` passed into each method;
    the service itself keeps no database state.
    """

    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
        """
        Args:
            vector_store: Store that receives the text of every added document.
            retrievers: Retrievers passed to ``DspyRagPipeline`` for each chat turn.
        """
        self.vector_store = vector_store
        self.retrievers = retrievers

    # --- Session Management ---

    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
        """Create and persist a new chat session.

        Args:
            db: Active SQLAlchemy session.
            user_id: Owner of the new chat session.
            model: Model name stored on the session; ``chat_with_rag`` later uses
                it to select the LLM provider.

        Returns:
            The committed and refreshed ``models.Session`` row.

        Raises:
            SQLAlchemyError: Re-raised after rolling back if the insert fails.
        """
        try:
            new_session = models.Session(user_id=user_id, model_name=model, title="New Chat Session")
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            db.rollback()
            raise

    @staticmethod
    def _save_message(db: Session, session_id: int, sender: str, content: str) -> models.Message:
        """Insert and commit one chat message, rolling back on database errors."""
        message = models.Message(session_id=session_id, sender=sender, content=content)
        try:
            db.add(message)
            db.commit()
        except SQLAlchemyError:
            db.rollback()
            raise
        return message

    async def chat_with_rag(self, db: Session, session_id: int, prompt: str) -> Tuple[str, str]:
        """Run one conversational RAG turn within an existing session.

        Persists the user's prompt and generates an answer with ``DspyRagPipeline``
        using the session's model. Persists the answer, then returns it.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of the chat session to append to.
            prompt: The user's message.

        Returns:
            ``(answer_text, model_name)``.

        Raises:
            ValueError: If no session with ``session_id`` exists.
            SQLAlchemyError: Re-raised after rollback if saving a message fails.
        """
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        if not session:
            raise ValueError(f"Session with ID {session_id} not found.")

        # Read the model name once, before the commits below expire the instance
        # and force extra round-trips to reload it.
        model_name = session.model_name

        self._save_message(db, session_id, "user", prompt)

        llm_provider = get_llm_provider(model_name)
        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model_name)
        # NOTE(review): dspy.configure changes dspy's global settings on every
        # request. Concurrent turns for sessions using different models may race
        # here. Consider a scoped ``dspy.context(lm=...)`` if the installed dspy
        # version supports it.
        dspy.configure(lm=dspy_llm)

        rag_pipeline = DspyRagPipeline(retrievers=self.retrievers)

        # NOTE(review): with SQLAlchemy's default expire_on_commit=True, the
        # commit in _save_message expires ``session``. ``session.messages`` is
        # therefore reloaded here and includes the user message just saved.
        # Confirm the pipeline expects the current prompt in ``history`` as well
        # as in ``question``.
        answer_text = await rag_pipeline.forward(
            question=prompt,
            history=session.messages,
            db=db,
        )

        self._save_message(db, session_id, "assistant", answer_text)

        return answer_text, model_name

    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """Return all messages for a session, or ``None`` if the session doesn't exist.

        The messages are eagerly loaded together with the session in one query.
        NOTE(review): their order is whatever the ``Session.messages``
        relationship defines (not visible here). Confirm it has an ``order_by``.
        """
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        return session.messages if session else None

    # --- Document Management ---

    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
        """Persist a document, index its text, and record the vector mapping.

        The document row and its ``VectorMetadata`` row are committed in a single
        transaction. If indexing or saving the metadata fails, no document is
        left behind without vector metadata.

        Args:
            db: Active SQLAlchemy session.
            doc_data: Keyword arguments for ``models.Document``. The resulting
                ``document_db.text`` is what gets sent to the vector store.

        Returns:
            The new document's ID.

        Raises:
            Any exception from the database or the vector store, re-raised after
            rolling back the transaction.
        """
        try:
            document_db = models.Document(**doc_data)
            db.add(document_db)
            # Flush (not commit) so the INSERT runs and ``document_db.id`` is
            # assigned while the transaction is still open.
            db.flush()
            document_id = document_db.id

            faiss_index = self.vector_store.add_document(document_db.text)
            db.add(
                models.VectorMetadata(
                    document_id=document_id,
                    faiss_index=faiss_index,
                    embedding_model="mock_embedder",
                )
            )
            db.commit()
            return document_id
        except Exception:
            # NOTE(review): if the vector store accepted the text but the commit
            # then fails, the FAISS entry is left orphaned. The store API needed
            # to undo it is not visible here.
            db.rollback()
            raise

    def get_all_documents(self, db: Session) -> List[models.Document]:
        """Return every document, newest first (by ``created_at``)."""
        return db.query(models.Document).order_by(models.Document.created_at.desc()).all()

    def delete_document(self, db: Session, document_id: int) -> Optional[int]:
        """Delete a document by ID.

        Returns:
            ``document_id`` if the document was deleted, or ``None`` if no such
            document exists.

        Raises:
            SQLAlchemyError: Re-raised after rolling back if the delete fails.
        """
        # NOTE(review): only the database row is removed here; the vector in
        # ``self.vector_store`` is not touched. Whether the VectorMetadata row is
        # removed depends on the model's cascade settings. Verify.
        try:
            doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first()
            if not doc_to_delete:
                return None
            db.delete(doc_to_delete)
            db.commit()
            return document_id
        except SQLAlchemyError:
            db.rollback()
            raise