# tests/api/test_routes.py
"""Route-level tests for the API router built by ``create_api_router``.

Every service and the database session are replaced by mocks, so these tests
exercise only request parsing, dependency wiring, and response shaping.
"""
from datetime import datetime
from unittest.mock import ANY, AsyncMock, MagicMock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient
from sqlalchemy.orm import Session

# Import the dependencies and router factory
from app.api.dependencies import get_db, ServiceContainer
from app.core.services.rag import RAGService
from app.core.services.document import DocumentService
from app.core.services.tts import TTSService
from app.api.routes import create_api_router
from app.db import models  # SQLAlchemy models used to build mock return values


@pytest.fixture
def client():
    """
    Pytest fixture to create a TestClient with a fully mocked environment,
    including a mock ServiceContainer.

    Yields:
        tuple: ``(TestClient, mock ServiceContainer)``. The container's
        ``rag_service``, ``document_service`` and ``tts_service`` are
        ``MagicMock(spec=...)`` objects, so tests can set return values on
        them and assert how the routes called them.
    """
    test_app = FastAPI()

    # Mock individual services. When a MagicMock is given a class spec, any
    # ``async def`` methods on that class are exposed as AsyncMock attributes
    # automatically, so no explicit AsyncMock is needed here.
    mock_rag_service = MagicMock(spec=RAGService)
    mock_document_service = MagicMock(spec=DocumentService)
    mock_tts_service = MagicMock(spec=TTSService)

    # Create a mock ServiceContainer that holds the mocked services
    mock_services = MagicMock(spec=ServiceContainer)
    mock_services.rag_service = mock_rag_service
    mock_services.document_service = mock_document_service
    mock_services.tts_service = mock_tts_service

    # Mock the database session; routes receive it through the get_db override.
    mock_db_session = MagicMock(spec=Session)

    def override_get_db():
        yield mock_db_session

    # Pass the mock ServiceContainer to the router factory
    api_router = create_api_router(services=mock_services)
    test_app.dependency_overrides[get_db] = override_get_db
    test_app.include_router(api_router)

    # Return the test client and the mock services for assertion
    yield TestClient(test_app), mock_services


# --- General Endpoint ---

def test_read_root(client):
    """Tests the root endpoint."""
    test_client, _ = client

    response = test_client.get("/")

    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}


# --- Session and Chat Endpoints ---

def test_create_session_success(client):
    """Tests successfully creating a new chat session."""
    test_client, mock_services = client
    mock_session = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat",
        created_at=datetime.now(),
    )
    mock_services.rag_service.create_session.return_value = mock_session

    response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

    assert response.status_code == 200
    assert response.json()["id"] == 1
    mock_services.rag_service.create_session.assert_called_once()


def test_chat_in_session_success(client):
    """
    Tests sending a message in an existing session without specifying a model
    or retriever flag. The route should default to model "deepseek" and
    load_faiss_retriever=False.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
    # ``db`` is the overridden session dependency; its identity is not under test.
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )


def test_chat_in_session_with_model_switch(client):
    """
    Tests sending a message in an existing session and explicitly switching
    the model.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))

    response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})

    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )


def test_chat_in_session_with_faiss_retriever(client):
    """
    Tests sending a message and explicitly enabling the FAISS retriever.
    """
    test_client, mock_services = client
    mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))

    response = test_client.post(
        "/sessions/42/chat",
        json={"prompt": "What is RAG?", "load_faiss_retriever": True},
    )

    assert response.status_code == 200
    assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
    mock_services.rag_service.chat_with_rag.assert_called_once_with(
        db=ANY,
        session_id=42,
        prompt="What is RAG?",
        model="deepseek",
        load_faiss_retriever=True,
    )


def test_get_session_messages_success(client):
    """Tests retrieving the message history for a session."""
    test_client, mock_services = client
    mock_history = [
        models.Message(sender="user", content="Hello", created_at=datetime.now()),
        models.Message(sender="assistant", content="Hi there!", created_at=datetime.now()),
    ]
    mock_services.rag_service.get_message_history.return_value = mock_history

    response = test_client.get("/sessions/123/messages")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["session_id"] == 123
    assert len(response_data["messages"]) == 2
    assert response_data["messages"][0]["sender"] == "user"
    assert response_data["messages"][1]["content"] == "Hi there!"
    mock_services.rag_service.get_message_history.assert_called_once_with(
        db=ANY,
        session_id=123,
    )


def test_get_session_messages_not_found(client):
    """Tests retrieving messages for a session that does not exist."""
    test_client, mock_services = client
    # A None history is the service's signal for "no such session".
    mock_services.rag_service.get_message_history.return_value = None

    response = test_client.get("/sessions/999/messages")

    assert response.status_code == 404
    assert response.json()["detail"] == "Session with ID 999 not found."


# --- Document Endpoints ---

def test_add_document_success(client):
    """Tests the /documents endpoint for adding a new document."""
    test_client, mock_services = client
    mock_services.document_service.add_document.return_value = 123
    doc_payload = {"title": "Test Doc", "text": "Content here"}

    response = test_client.post("/documents", json=doc_payload)

    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123"


def test_get_documents_success(client):
    """Tests the /documents endpoint for retrieving all documents."""
    test_client, mock_services = client
    mock_docs = [
        models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
        models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()),
    ]
    mock_services.document_service.get_all_documents.return_value = mock_docs

    response = test_client.get("/documents")

    assert response.status_code == 200
    assert len(response.json()["documents"]) == 2


def test_delete_document_success(client):
    """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
    test_client, mock_services = client
    mock_services.document_service.delete_document.return_value = 42

    response = test_client.delete("/documents/42")

    assert response.status_code == 200
    assert response.json()["document_id"] == 42


def test_delete_document_not_found(client):
    """Tests the DELETE /documents/{document_id} endpoint when the document is not found."""
    test_client, mock_services = client
    # A None result is the service's signal for "nothing was deleted".
    mock_services.document_service.delete_document.return_value = None

    response = test_client.delete("/documents/999")

    assert response.status_code == 404


# --- TTS Endpoint ---

@pytest.mark.anyio
async def test_create_speech_stream_success(client):
    """
    Tests the /speech endpoint to ensure it can successfully generate an
    audio stream.
    """
    test_client, mock_services = client
    app = test_client.app  # The FastAPI app wrapped by the TestClient

    # Arrange: define the text to convert and mock the service's response.
    text_to_speak = "Hello, world!"

    async def mock_audio_generator():
        yield b"chunk1"
        yield b"chunk2"
        yield b"chunk3"

    # Replace the method with a plain callable that returns a fresh async
    # generator (assumes the route passes exactly one ``text`` argument —
    # TODO confirm against the route implementation).
    mock_services.tts_service.create_speech_stream = lambda text: mock_audio_generator()

    # Drive the ASGI app through httpx so the request runs on this test's event loop.
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
        response = await ac.post("/speech", json={"text": text_to_speak})

    # Assert: check status code and the concatenated streamed content
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/wav"
    assert response.content == b"chunk1chunk2chunk3"