import logging
import os
import shutil
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Response
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.db import models
from app.core.pipelines.validator import Validator

def create_sessions_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/sessions`` router: chat sessions, their messages, and message audio.

    Every handler is a closure over ``services`` so it can reach the session,
    RAG and user services held by the container.

    Error handling is uniform across handlers:
      * an ``HTTPException`` raised by a handler, or by a service it calls,
        propagates unchanged with its own status code;
      * any other exception rolls back the request's DB session, is logged
        with its traceback, and is returned as a 500 whose ``detail`` is
        ``"<context>: <error>"``.
    """
    router = APIRouter(prefix="/sessions", tags=["Sessions"])
    logger = logging.getLogger(__name__)
    # Uploaded audio is stored here as message_<message_id>.wav.
    audio_dir = "/app/data/audio"

    def _server_error(db: Session, context: str, exc: Exception) -> HTTPException:
        """Roll back *db*, log *exc*, and build the 500 response for *context*."""
        try:
            # A failed flush/commit leaves the session unusable until rolled back.
            db.rollback()
        except Exception:
            # Never let cleanup hide the original failure.
            logger.exception("Rollback failed while handling: %s", context)
        logger.error("%s", context, exc_info=exc)
        return HTTPException(status_code=500, detail=f"{context}: {exc}")

    @router.post("/", response_model=schemas.Session, summary="Create a New Chat Session")
    def create_session(
        request: schemas.SessionCreate,
        db: Session = Depends(get_db)
    ):
        """Create a session for ``user_id`` with the requested providers."""
        if request.user_id is None or request.provider_name is None:
            raise HTTPException(status_code=400, detail="user_id and provider_name are required to create a session.")
        try:
            return services.session_service.create_session(
                db=db,
                user_id=request.user_id,
                provider_name=request.provider_name,
                feature_name=request.feature_name,
                stt_provider_name=request.stt_provider_name,
                tts_provider_name=request.tts_provider_name,
            )
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "Failed to create session", e) from e

    @router.post("/{session_id}/chat", response_model=schemas.ChatResponse, summary="Send a Message in a Session")
    async def chat_in_session(
        session_id: int,
        request: schemas.ChatRequest,
        db: Session = Depends(get_db)
    ):
        """Send ``request.prompt`` through the RAG service and return its answer."""
        try:
            response_text, provider_used, message_id = await services.rag_service.chat_with_rag(
                db=db,
                session_id=session_id,
                prompt=request.prompt,
                provider_name=request.provider_name,
                load_faiss_retriever=request.load_faiss_retriever,
                user_service=services.user_service,
            )
            return schemas.ChatResponse(answer=response_text, provider_used=provider_used, message_id=message_id)
        except HTTPException:
            # Keep the service's own status code (e.g. 404) instead of masking it as 500.
            raise
        except Exception as e:
            raise _server_error(db, "An error occurred during chat", e) from e

    @router.get("/{session_id}/messages", response_model=schemas.MessageHistoryResponse, summary="Get Session Chat History")
    def get_session_messages(session_id: int, db: Session = Depends(get_db)):
        """Return the session's messages; messages with a stored audio file get an ``audio_url``."""
        try:
            messages = services.rag_service.get_message_history(db=db, session_id=session_id)
            if messages is None:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            enhanced_messages = []
            for m in messages:
                msg_dict = schemas.Message.model_validate(m).model_dump()
                # Advertise audio only when the file really exists on disk.
                if m.audio_path and os.path.isfile(m.audio_path):
                    msg_dict["has_audio"] = True
                    msg_dict["audio_url"] = f"{router.prefix}/messages/{m.id}/audio"
                enhanced_messages.append(msg_dict)

            return schemas.MessageHistoryResponse(session_id=session_id, messages=enhanced_messages)
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "An error occurred", e) from e

    @router.get("/{session_id}/tokens", response_model=schemas.SessionTokenUsageResponse, summary="Get Session Token Usage")
    def get_session_token_usage(session_id: int, db: Session = Depends(get_db)):
        """Count the tokens in the whole history and compare them with the ``Validator`` limit."""
        try:
            messages = services.rag_service.get_message_history(db=db, session_id=session_id)
            if messages is None:
                raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found.")

            # A message with no content contributes nothing rather than crashing the join.
            combined_text = " ".join(m.content or "" for m in messages)
            validator = Validator()
            token_count = len(validator.encoding.encode(combined_text))
            token_limit = validator.token_limit
            percentage = round((token_count / token_limit) * 100, 2) if token_limit > 0 else 0.0

            return schemas.SessionTokenUsageResponse(
                token_count=token_count,
                token_limit=token_limit,
                percentage=percentage,
            )
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "An error occurred", e) from e

    @router.get("/", response_model=List[schemas.Session], summary="Get All Chat Sessions")
    def get_sessions(
        user_id: str,
        feature_name: str = "default",
        db: Session = Depends(get_db)
    ):
        """List the user's non-archived sessions for ``feature_name``, newest first."""
        try:
            return db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name,
                models.Session.is_archived == False,  # noqa: E712 -- SQLAlchemy column expression
            ).order_by(models.Session.created_at.desc()).all()
        except Exception as e:
            raise _server_error(db, "Failed to fetch sessions", e) from e

    @router.get("/{session_id}", response_model=schemas.Session, summary="Get a Single Session")
    def get_session(session_id: int, db: Session = Depends(get_db)):
        """Return one non-archived session; archived sessions are reported as 404."""
        try:
            session = db.query(models.Session).filter(
                models.Session.id == session_id,
                models.Session.is_archived == False,  # noqa: E712
            ).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            return session
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "Failed to fetch session", e) from e

    @router.patch("/{session_id}", response_model=schemas.Session, summary="Update a Chat Session")
    def update_session(session_id: int, session_update: schemas.SessionUpdate, db: Session = Depends(get_db)):
        """Apply the non-None fields of ``session_update`` to a non-archived session."""
        try:
            session = db.query(models.Session).filter(
                models.Session.id == session_id,
                models.Session.is_archived == False,  # noqa: E712
            ).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")

            if session_update.title is not None:
                session.title = session_update.title
            if session_update.provider_name is not None:
                session.provider_name = session_update.provider_name
            if session_update.stt_provider_name is not None:
                session.stt_provider_name = session_update.stt_provider_name
            if session_update.tts_provider_name is not None:
                session.tts_provider_name = session_update.tts_provider_name

            db.commit()
            db.refresh(session)
            return session
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "Failed to update session", e) from e

    @router.delete("/{session_id}", summary="Delete a Chat Session")
    def delete_session(session_id: int, db: Session = Depends(get_db)):
        """Soft-delete a session by setting ``is_archived``; the row itself is kept."""
        try:
            session = db.query(models.Session).filter(models.Session.id == session_id).first()
            if not session:
                raise HTTPException(status_code=404, detail="Session not found.")
            session.is_archived = True
            db.commit()
            return {"message": "Session deleted successfully."}
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "Failed to delete session", e) from e

    @router.delete("/", summary="Delete All Sessions for Feature")
    def delete_all_sessions(user_id: str, feature_name: str = "default", db: Session = Depends(get_db)):
        """Soft-delete (archive) every session the user has for ``feature_name``."""
        try:
            # Rows that are already archived need no change, so don't load them.
            sessions = db.query(models.Session).filter(
                models.Session.user_id == user_id,
                models.Session.feature_name == feature_name,
                models.Session.is_archived == False,  # noqa: E712
            ).all()
            for session in sessions:
                session.is_archived = True
            db.commit()
            return {"message": "All sessions deleted successfully."}
        except Exception as e:
            raise _server_error(db, "Failed to delete all sessions", e) from e

    # Plain ``def`` (not ``async def``): the file copy and the ORM calls block,
    # so FastAPI should run this handler in its threadpool, not on the event loop.
    @router.post("/messages/{message_id}/audio", summary="Upload audio for a specific message")
    def upload_message_audio(message_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)):
        """Store the uploaded file as the message's audio and record its path on the message."""
        try:
            message = db.query(models.Message).filter(models.Message.id == message_id).first()
            if not message:
                raise HTTPException(status_code=404, detail="Message not found.")

            os.makedirs(audio_dir, exist_ok=True)
            file_path = f"{audio_dir}/message_{message_id}.wav"
            # Write to a temp file, then swap it in atomically, so a failed upload
            # never leaves a truncated file at the path readers use.
            tmp_path = f"{file_path}.part"
            try:
                with open(tmp_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
                os.replace(tmp_path, file_path)
            except Exception:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; the write error is re-raised below
                raise

            message.audio_path = file_path
            db.commit()

            return {"message": "Audio uploaded successfully.", "audio_path": file_path}
        except HTTPException:
            # Previously the 404 above was swallowed and reported as a 500.
            raise
        except Exception as e:
            raise _server_error(db, "Failed to upload audio", e) from e

    # Plain ``def`` for the same reason as the upload handler: blocking ORM and filesystem calls.
    @router.get("/messages/{message_id}/audio", summary="Get audio for a specific message")
    def get_message_audio(message_id: int, db: Session = Depends(get_db)):
        """Stream the message's stored audio file as ``audio/wav``."""
        try:
            message = db.query(models.Message).filter(models.Message.id == message_id).first()
            if not message or not message.audio_path:
                raise HTTPException(status_code=404, detail="Audio not found for this message.")

            if not os.path.isfile(message.audio_path):
                raise HTTPException(status_code=404, detail="Audio file missing on disk.")

            return FileResponse(message.audio_path, media_type="audio/wav")
        except HTTPException:
            raise
        except Exception as e:
            raise _server_error(db, "Failed to get audio", e) from e

    return router