import asyncio
import logging
from typing import List, Dict, Any, Optional, Tuple

from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import SQLAlchemyError
import dspy

from app.core.vector_store import FaissVectorStore
from app.db import models
from app.core.retrievers import Retriever, FaissDBRetriever # Assuming FaissDBRetriever is available
from app.core.llm_providers import get_llm_provider
from app.core.pipelines.dspy_rag import DSPyLLMProvider, DspyRagPipeline

class RAGService:
    """
    Service class for managing documents and conversational RAG sessions.

    Database work goes through the SQLAlchemy ``Session`` passed to each
    method; the service never creates or closes that session itself.
    """

    def __init__(self, vector_store: FaissVectorStore, retrievers: List[Retriever]):
        """
        Args:
            vector_store: Store used by ``add_document`` to index document text.
            retrievers: Available retrievers. The first ``FaissDBRetriever`` in
                the list (if any) is kept as ``self.faiss_retriever`` so
                ``chat_with_rag`` can opt into it.
        """
        self.vector_store = vector_store
        self.retrievers = retrievers
        # NOTE(review): a mapping of named retrievers would be more explicit
        # than picking the first FaissDBRetriever by type.
        self.faiss_retriever: Optional[FaissDBRetriever] = next(
            (r for r in retrievers if isinstance(r, FaissDBRetriever)), None
        )

    # --- Session Management ---

    def create_session(self, db: Session, user_id: str, model: str) -> models.Session:
        """Create and persist a new chat session.

        Args:
            db: Active SQLAlchemy session.
            user_id: Owner of the new chat session.
            model: Model name stored on the session as ``model_name``.

        Returns:
            The committed and refreshed ``models.Session`` row.

        Raises:
            SQLAlchemyError: If the insert fails; the transaction is rolled back first.
        """
        try:
            new_session = models.Session(user_id=user_id, model_name=model, title="New Chat Session")
            db.add(new_session)
            db.commit()
            db.refresh(new_session)
            return new_session
        except SQLAlchemyError:
            db.rollback()
            raise

    def _save_message(self, db: Session, session_id: int, sender: str, content: str) -> models.Message:
        """Add and commit one chat message; roll back and re-raise on DB errors."""
        try:
            message = models.Message(session_id=session_id, sender=sender, content=content)
            db.add(message)
            db.commit()
            return message
        except SQLAlchemyError:
            db.rollback()
            raise

    async def chat_with_rag(
        self,
        db: Session,
        session_id: int,
        prompt: str,
        model: str,
        load_faiss_retriever: bool = False,
    ) -> Tuple[str, str]:
        """
        Handle one user turn: persist the prompt, run the RAG pipeline, and
        persist and return the assistant's answer.

        Args:
            db: Active SQLAlchemy session.
            session_id: ID of an existing ``models.Session``.
            prompt: The user's message for this turn.
            model: LLM model name, resolved via ``get_llm_provider`` and set as
                the DSPy LM before the pipeline runs.
            load_faiss_retriever: When True, include ``self.faiss_retriever`` in
                the pipeline; a warning is logged if none was configured.

        Returns:
            ``(answer_text, model)``.

        Raises:
            ValueError: If no session with ``session_id`` exists.
            SQLAlchemyError: If saving a message fails (after rollback).
        """
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        if session is None:
            raise ValueError(f"Session with ID {session_id} not found.")

        # Snapshot the prior history *before* committing the new user message.
        # Reading ``session.messages`` after a commit depends on the session's
        # ``expire_on_commit`` setting (it may reload and then include the
        # current prompt, which is already passed separately as ``question``).
        # Sorted chronologically to match ``get_message_history``.
        history = sorted(session.messages, key=lambda msg: msg.created_at)

        self._save_message(db, session_id, "user", prompt)

        llm_provider = get_llm_provider(model)
        dspy_llm = DSPyLLMProvider(provider=llm_provider, model_name=model)
        # NOTE(review): dspy.configure mutates global DSPy settings, so concurrent
        # requests using different models may interfere; consider a scoped
        # ``dspy.context(lm=...)`` — verify against the installed DSPy version.
        dspy.configure(lm=dspy_llm)

        # Only the FAISS retriever is optional today; with none selected the
        # pipeline runs with an empty retriever list.
        current_retrievers: List[Retriever] = []
        if load_faiss_retriever:
            if self.faiss_retriever is not None:
                current_retrievers.append(self.faiss_retriever)
            else:
                logging.getLogger(__name__).warning(
                    "FaissDBRetriever requested but not available. Proceeding without it."
                )

        rag_pipeline = DspyRagPipeline(retrievers=current_retrievers)

        answer_text = await rag_pipeline.forward(
            question=prompt,
            history=history,
            db=db,
        )

        self._save_message(db, session_id, "assistant", answer_text)

        return answer_text, model

    def get_message_history(self, db: Session, session_id: int) -> Optional[List[models.Message]]:
        """
        Return a session's messages ordered by ``created_at``, or None if the
        session doesn't exist.
        """
        session = (
            db.query(models.Session)
            .options(joinedload(models.Session.messages))
            .filter(models.Session.id == session_id)
            .first()
        )
        if session is None:
            return None
        return sorted(session.messages, key=lambda msg: msg.created_at)

    # --- Document Management ---

    def add_document(self, db: Session, doc_data: Dict[str, Any]) -> int:
        """
        Persist a document, index its text in the vector store, and record the
        resulting vector metadata in a single transaction.

        Args:
            db: Active SQLAlchemy session.
            doc_data: Keyword arguments passed to ``models.Document``.

        Returns:
            The new document's ID.

        Raises:
            Any exception from the database or the vector store, re-raised
            after the transaction is rolled back.
        """
        try:
            document_db = models.Document(**doc_data)
            db.add(document_db)
            # Flush (not commit) so the primary key is assigned while the
            # Document row and its VectorMetadata still commit atomically.
            db.flush()
            faiss_index = self.vector_store.add_document(document_db.text)
            vector_metadata = models.VectorMetadata(
                document_id=document_db.id,
                faiss_index=faiss_index,
                embedding_model="mock_embedder",
            )
            db.add(vector_metadata)
            document_id = document_db.id
            db.commit()
        except Exception:
            # Roll back on any failure (including vector-store errors) so no
            # document is left committed without its vector metadata.
            # NOTE(review): the vector-store write itself is not undone if the
            # commit fails; the store may retain an orphan entry.
            db.rollback()
            raise
        return document_id

    def get_all_documents(self, db: Session) -> List[models.Document]:
        """Return all documents, newest ``created_at`` first."""
        return db.query(models.Document).order_by(models.Document.created_at.desc()).all()

    def delete_document(self, db: Session, document_id: int) -> Optional[int]:
        """
        Delete a document by ID.

        Returns:
            ``document_id`` if it was deleted, or None if no such document exists.

        Raises:
            SQLAlchemyError: If the delete fails; the transaction is rolled back first.
        """
        try:
            doc_to_delete = db.query(models.Document).filter(models.Document.id == document_id).first()
            if doc_to_delete is None:
                return None
            db.delete(doc_to_delete)
            db.commit()
        except SQLAlchemyError:
            db.rollback()
            raise
        # NOTE(review): the document's vector in ``self.vector_store`` is not
        # removed here; whether VectorMetadata rows cascade depends on the model
        # definitions — verify.
        return document_id