import asyncio
from datetime import datetime
from unittest.mock import ANY, AsyncMock, MagicMock

import pytest
from fastapi import FastAPI, Response
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.orm import Session

# Import the dependencies and router factory
from app.api.dependencies import get_db, ServiceContainer
from app.core.services.rag import RAGService
from app.core.services.document import DocumentService
from app.core.services.tts import TTSService
from app.api.routes import create_api_router
from app.db import models

@pytest.fixture
def client():
    """
    Build a synchronous TestClient around a FastAPI app whose services and
    database session are all mocks.

    Yields:
        A ``(TestClient, mock ServiceContainer)`` tuple so a test can both issue
        requests and configure/inspect the mocked services.
    """
    services = MagicMock(spec=ServiceContainer)
    services.rag_service = MagicMock(spec=RAGService)
    services.document_service = MagicMock(spec=DocumentService)
    services.tts_service = MagicMock(spec=TTSService)

    fake_session = MagicMock(spec=Session)

    # Generator-style override mirrors the shape of the real ``get_db`` dependency.
    def _yield_fake_session():
        yield fake_session

    app = FastAPI()
    app.dependency_overrides[get_db] = _yield_fake_session
    app.include_router(create_api_router(services=services))

    yield TestClient(app), services

@pytest.fixture
async def async_client():
    """
    Provide an httpx ``AsyncClient`` wired (via ``ASGITransport``) to a FastAPI
    app whose services and database session are all mocks, for testing async
    endpoints.

    Yields:
        A ``(AsyncClient, mock ServiceContainer)`` tuple.

    NOTE(review): this is registered with plain ``@pytest.fixture``, and the
    tests in this file consume it via ``await anext(async_client)``, i.e. they
    receive the raw async generator (pytest-asyncio strict mode). Closing the
    generator (and thus the AsyncClient) is therefore up to the consumer.
    ``@pytest_asyncio.fixture`` would let pytest-asyncio manage setup/teardown
    instead; confirm the configured asyncio mode before switching.
    """
    services = MagicMock(spec=ServiceContainer)
    services.rag_service = MagicMock(spec=RAGService)
    services.document_service = MagicMock(spec=DocumentService)
    services.tts_service = MagicMock(spec=TTSService)

    fake_session = MagicMock(spec=Session)

    # Generator-style override mirrors the shape of the real ``get_db`` dependency.
    def _yield_fake_session():
        yield fake_session

    app = FastAPI()
    app.dependency_overrides[get_db] = _yield_fake_session
    app.include_router(create_api_router(services=services))

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        yield http, services

# --- General Endpoint ---

def test_read_root(client):
    """The root endpoint reports that the service is up."""
    http, _services = client

    reply = http.get("/")

    assert reply.status_code == 200
    assert reply.json() == {"status": "AI Model Hub is running!"}

# --- Session and Chat Endpoints ---

def test_create_session_success(client):
    """POST /sessions returns the session object produced by the RAG service."""
    http, services = client
    create_session = services.rag_service.create_session
    create_session.return_value = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat",
        created_at=datetime.now(),
    )

    reply = http.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

    assert reply.status_code == 200
    assert reply.json()["id"] == 1
    create_session.assert_called_once()

def test_chat_in_session_success(client):
    """
    Tests sending a message in an existing session without specifying a model
    or retriever. The route should fall back to model 'deepseek' and
    ``load_faiss_retriever=False``.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}

    # ``db`` is the overridden session dependency; the fixture does not expose
    # it, so only its presence as a keyword argument is checked (ANY). Comparing
    # it against call_args itself would be a tautology.
    mock_services.rag_service.chat_with_rag.assert_awaited_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )

def test_chat_in_session_with_model_switch(client):
    """
    Tests sending a message in an existing session while explicitly switching
    the model; the requested model must be forwarded to the RAG service.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}

    # ``db`` is the overridden session dependency, not exposed by the fixture;
    # ANY checks it was passed by keyword without a self-referential comparison.
    mock_services.rag_service.chat_with_rag.assert_awaited_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )

def test_chat_in_session_with_faiss_retriever(client):
    """
    Tests sending a message with the FAISS retriever explicitly enabled; the
    flag must be forwarded while the model still defaults to 'deepseek'.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True}
    )

    assert response.status_code == 200
    assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}

    # ``db`` is the overridden session dependency, not exposed by the fixture;
    # ANY checks it was passed by keyword without a self-referential comparison.
    mock_services.rag_service.chat_with_rag.assert_awaited_once_with(
        db=ANY,
        session_id=42,
        prompt="What is RAG?",
        model="deepseek",
        load_faiss_retriever=True,
    )

def test_get_session_messages_success(client):
    """Tests retrieving the message history for a session."""
    test_client, mock_services = client
    mock_history = [
        models.Message(sender="user", content="Hello", created_at=datetime.now()),
        models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
    ]
    mock_services.rag_service.get_message_history.return_value = mock_history

    response = test_client.get("/sessions/123/messages")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    # ``db`` is the overridden session dependency, not exposed by the fixture;
    # ANY checks it was passed by keyword without a self-referential comparison.
    mock_services.rag_service.get_message_history.assert_called_once_with(
        db=ANY,
        session_id=123,
    )

def test_get_session_messages_not_found(client):
    """GET /sessions/{id}/messages yields 404 when the service finds no history."""
    http, services = client
    services.rag_service.get_message_history.return_value = None

    reply = http.get("/sessions/999/messages")

    assert reply.status_code == 404
    assert reply.json()["detail"] == "Session with ID 999 not found."

# --- Document Endpoints ---

def test_add_document_success(client):
    """POST /documents stores the document via the service and echoes its new ID."""
    test_client, mock_services = client
    mock_services.document_service.add_document.return_value = 123
    doc_payload = {"title": "Test Doc", "text": "Content here"}

    response = test_client.post("/documents", json=doc_payload)

    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123"
    # The response only reports success; confirm the service was actually invoked.
    mock_services.document_service.add_document.assert_called_once()

def test_get_documents_success(client):
    """GET /documents lists every document the service returns."""
    http, services = client
    services.document_service.get_all_documents.return_value = [
        models.Document(id=doc_id, title=title, status=status, created_at=datetime.now())
        for doc_id, title, status in (
            (1, "Doc One", "ready"),
            (2, "Doc Two", "processing"),
        )
    ]

    reply = http.get("/documents")

    assert reply.status_code == 200
    assert len(reply.json()["documents"]) == 2

def test_delete_document_success(client):
    """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
    test_client, mock_services = client
    mock_services.document_service.delete_document.return_value = 42

    response = test_client.delete("/documents/42")

    assert response.status_code == 200
    assert response.json()["document_id"] == 42
    # ``document_id`` could be echoed from the URL path alone, so also confirm
    # the service was actually asked to perform the deletion.
    mock_services.document_service.delete_document.assert_called_once()

def test_delete_document_not_found(client):
    """DELETE /documents/{document_id} answers 404 when the service reports nothing deleted."""
    http, services = client
    services.document_service.delete_document.return_value = None

    assert http.delete("/documents/999").status_code == 404

@pytest.mark.asyncio
async def test_create_speech_response(async_client):
    """
    POST /speech (non-streaming) returns the bytes produced by
    ``create_speech_non_stream`` with an ``audio/wav`` content type.
    """
    # The fixture arrives as a raw async generator (hence ``anext``). It must be
    # closed explicitly, otherwise the AsyncClient's ``async with`` never exits.
    test_client, mock_services = await anext(async_client)
    try:
        mock_audio_bytes = b"fake wav audio bytes"

        # The route handler calls `create_speech_non_stream`, not `create_speech_stream`.
        # It's an async function, so we must use AsyncMock.
        mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)

        response = await test_client.post("/speech", json={"text": "Hello, this is a test"})

        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/wav"
        assert response.content == mock_audio_bytes
        # "awaited", not merely "called": the coroutine must actually have been awaited.
        mock_services.tts_service.create_speech_non_stream.assert_awaited_once_with(
            text="Hello, this is a test"
        )
    finally:
        await async_client.aclose()

@pytest.mark.asyncio
async def test_create_speech_stream_response(async_client):
    """
    POST /speech?stream=true returns a streaming ``audio/wav`` body whose bytes
    are exactly the chunks yielded by ``create_speech_stream``.
    """
    # The fixture arrives as a raw async generator (hence ``anext``). It must be
    # closed explicitly, otherwise the AsyncClient's ``async with`` never exits.
    test_client, mock_services = await anext(async_client)
    try:
        mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]

        async def mock_async_generator():
            for chunk in mock_audio_bytes_chunks:
                yield chunk

        # A MagicMock (not AsyncMock) is used because the route is expected to
        # call ``create_speech_stream`` without awaiting it and iterate the
        # returned async generator.
        mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())

        # Use the client's streaming API so the body is consumed incrementally
        # instead of being fully pre-read by ``post()``.
        async with test_client.stream(
            "POST", "/speech?stream=true", json={"text": "Hello, this is a test"}
        ) as response:
            assert response.status_code == 200
            assert response.headers["content-type"] == "audio/wav"
            streamed_content = b"".join([chunk async for chunk in response.aiter_bytes()])

        assert streamed_content == b"".join(mock_audio_bytes_chunks)
        mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test")
    finally:
        await async_client.aclose()