import asyncio
from datetime import datetime
from unittest.mock import ANY, AsyncMock, MagicMock

import pytest
from fastapi import FastAPI, Response
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.orm import Session

# Import the dependencies and router factory
from app.api.dependencies import get_db, ServiceContainer
from app.core.services.rag import RAGService
from app.core.services.document import DocumentService
from app.core.services.tts import TTSService
from app.api.routes import create_api_router
from app.db import models


def _build_mocked_app():
    """Build a FastAPI app wired to mocked services and a mocked DB session.

    Shared by the ``client`` and ``async_client`` fixtures so both exercise
    the router against an identically configured environment.

    Returns:
        A ``(app, mock_services)`` tuple. ``mock_services`` is a
        ``MagicMock(spec=ServiceContainer)`` whose ``rag_service``,
        ``document_service`` and ``tts_service`` attributes are themselves
        spec'd mocks that tests can configure and inspect.
    """
    mock_services = MagicMock(spec=ServiceContainer)
    mock_services.rag_service = MagicMock(spec=RAGService)
    mock_services.document_service = MagicMock(spec=DocumentService)
    mock_services.tts_service = MagicMock(spec=TTSService)

    mock_db_session = MagicMock(spec=Session)

    def override_get_db():
        # Generator form mirrors a yield-style dependency; no real DB is used.
        yield mock_db_session

    test_app = FastAPI()
    test_app.dependency_overrides[get_db] = override_get_db
    test_app.include_router(create_api_router(services=mock_services))
    return test_app, mock_services


@pytest.fixture
def client():
    """
    Pytest fixture to create a TestClient with a fully mocked environment
    for synchronous endpoints.

    Yields:
        A ``(test_client, mock_services)`` tuple.
    """
    test_app, mock_services = _build_mocked_app()
    yield TestClient(test_app), mock_services


@pytest.fixture
async def async_client():
    """
    Pytest fixture to create an AsyncClient for testing async endpoints.

    Yields:
        A ``(client, mock_services)`` tuple, with the client bound to the
        mocked app through an in-process ASGI transport.

    NOTE(review): this is declared with plain ``pytest.fixture``, and the
    async tests in this file obtain the value via ``await anext(async_client)``,
    i.e. they receive the raw async generator. Confirm this matches the
    pytest-asyncio mode configured for the project before changing either side.
    """
    test_app, mock_services = _build_mocked_app()
    transport = ASGITransport(app=test_app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client, mock_services


# --- General Endpoint ---


def test_read_root(client):
    """Tests the root endpoint."""
    test_client, _ = client

    response = test_client.get("/")

    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}


# --- Session and Chat Endpoints ---


def test_create_session_success(client):
    """Tests successfully creating a new chat session."""
    test_client, mock_services = client
    mock_session = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat",
        created_at=datetime.now(),
    )
    mock_services.rag_service.create_session.return_value = mock_session

    response = test_client.post(
        "/sessions", json={"user_id": "test_user", "model": "gemini"}
    )

    assert response.status_code == 200
    assert response.json()["id"] == 1
    mock_services.rag_service.create_session.assert_called_once()


def test_chat_in_session_success(client):
    """
    Tests sending a message in an existing session without specifying a model
    or retriever. It should default to 'deepseek' and 'False'.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response", "deepseek")
    )

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
    # ``db`` only needs to be passed by keyword; its identity is not checked.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )


def test_chat_in_session_with_model_switch(client):
    """
    Tests sending a message in an existing session and explicitly switching
    the model.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Mocked response from Gemini", "gemini")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "Hello there, Gemini!", "model": "gemini"},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Mocked response from Gemini",
        "model_used": "gemini",
    }
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )


def test_chat_in_session_with_faiss_retriever(client):
    """
    Tests sending a message and explicitly enabling the FAISS retriever.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(
        return_value=("Response with context", "deepseek")
    )

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True},
    )

    assert response.status_code == 200
    assert response.json() == {
        "answer": "Response with context",
        "model_used": "deepseek",
    }
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="What is RAG?",
        model="deepseek",
        load_faiss_retriever=True,
    )


def test_get_session_messages_success(client):
    """Tests retrieving the message history for a session."""
    test_client, mock_services = client
    mock_history = [
        models.Message(sender="user", content="Hello", created_at=datetime.now()),
        models.Message(
            sender="assistant", content="Hi there!", created_at=datetime.now()
        ),
    ]
    mock_services.rag_service.get_message_history.return_value = mock_history

    response = test_client.get("/sessions/123/messages")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    mock_services.rag_service.get_message_history.assert_called_once_with(
        db=ANY, session_id=123
    )


def test_get_session_messages_not_found(client):
    """Tests retrieving messages for a session that does not exist."""
    test_client, mock_services = client
    # The test expects a ``None`` history to be reported as HTTP 404.
    mock_services.rag_service.get_message_history.return_value = None

    response = test_client.get("/sessions/999/messages")

    assert response.status_code == 404
    assert response.json()["detail"] == "Session with ID 999 not found."


# --- Document Endpoints ---


def test_add_document_success(client):
    """POST /documents echoes the title and the ID returned by the service."""
    api, services = client
    services.document_service.add_document.return_value = 123
    payload = {"title": "Test Doc", "text": "Content here"}

    reply = api.post("/documents", json=payload)

    assert reply.status_code == 200
    expected_message = "Document 'Test Doc' added successfully with ID 123"
    assert reply.json()["message"] == expected_message


def test_get_documents_success(client):
    """GET /documents lists every document the service returns."""
    api, services = client
    services.document_service.get_all_documents.return_value = [
        models.Document(
            id=1, title="Doc One", status="ready", created_at=datetime.now()
        ),
        models.Document(
            id=2, title="Doc Two", status="processing", created_at=datetime.now()
        ),
    ]

    reply = api.get("/documents")

    assert reply.status_code == 200
    listed = reply.json()["documents"]
    assert len(listed) == 2


def test_delete_document_success(client):
    """DELETE /documents/{document_id} reports the ID of the removed document."""
    api, services = client
    services.document_service.delete_document.return_value = 42

    reply = api.delete("/documents/42")

    assert reply.status_code == 200
    assert reply.json()["document_id"] == 42


def test_delete_document_not_found(client):
    """DELETE /documents/{document_id} yields 404 when the service returns None."""
    api, services = client
    services.document_service.delete_document.return_value = None

    reply = api.delete("/documents/999")

    assert reply.status_code == 404


@pytest.mark.asyncio
async def test_create_speech_response(async_client):
    """POST /speech returns the synthesized audio bytes as ``audio/wav``."""
    # The fixture value is advanced manually to obtain (client, services).
    api, services = await anext(async_client)
    wav_payload = b"fake wav audio bytes"
    request_body = {"text": "Hello, this is a test"}

    # This test expects the handler to call `create_speech_non_stream` (not
    # `create_speech_stream`); AsyncMock makes the stubbed call awaitable.
    services.tts_service.create_speech_non_stream = AsyncMock(
        return_value=wav_payload
    )

    reply = await api.post("/speech", json=request_body)

    assert reply.status_code == 200
    assert reply.headers["content-type"] == "audio/wav"
    assert reply.content == wav_payload
    services.tts_service.create_speech_non_stream.assert_called_once_with(
        text="Hello, this is a test"
    )


# Covers the streaming endpoint
@pytest.mark.asyncio
async def test_create_speech_stream_response(async_client):
    """POST /speech/stream relays every chunk produced by the TTS service."""
    # The fixture value is advanced manually to obtain (client, services).
    api, services = await anext(async_client)
    audio_chunks = [b"chunk1", b"chunk2", b"chunk3"]
    request_body = {"text": "Hello, this is a test"}

    async def fake_audio_stream():
        # Stand-in for the service's async generator of audio chunks.
        for piece in audio_chunks:
            yield piece

    # A plain MagicMock is used so the call returns the async generator
    # directly rather than a coroutine wrapping it.
    services.tts_service.create_speech_stream = MagicMock(
        return_value=fake_audio_stream()
    )

    reply = await api.post("/speech/stream", json=request_body)

    assert reply.status_code == 200
    assert reply.headers["content-type"] == "audio/wav"

    # Drain the response body and compare it with the concatenated chunks.
    received = b"".join([part async for part in reply.aiter_bytes()])
    assert received == b"".join(audio_chunks)
    services.tts_service.create_speech_stream.assert_called_once_with(
        text="Hello, this is a test"
    )