# # NOTE(review): every line of this test module is commented out, so pytest
# # collects no tests from it. The disabled code is kept below, one statement
# # per line, so it can be read and re-enabled. `Response` and `asyncio` are
# # imported but appear unused in this module.
# import pytest
# from unittest.mock import MagicMock, AsyncMock
# from fastapi import FastAPI, Response
# from fastapi.testclient import TestClient
# from sqlalchemy.orm import Session
# from datetime import datetime
# from httpx import AsyncClient, ASGITransport
# import asyncio

# # Import the dependencies and router factory
# from app.api.dependencies import get_db, ServiceContainer
# from app.core.services.rag import RAGService
# from app.core.services.document import DocumentService
# from app.core.services.tts import TTSService
# from app.api.routes import create_api_router
# from app.db import models


# @pytest.fixture
# def client():
#     """
#     Pytest fixture to create a TestClient with a fully mocked environment
#     for synchronous endpoints.
#     """
#     test_app = FastAPI()
#     mock_rag_service = MagicMock(spec=RAGService)
#     mock_document_service = MagicMock(spec=DocumentService)
#     mock_tts_service = MagicMock(spec=TTSService)
#     mock_services = MagicMock(spec=ServiceContainer)
#     mock_services.rag_service = mock_rag_service
#     mock_services.document_service = mock_document_service
#     mock_services.tts_service = mock_tts_service
#     mock_db_session = MagicMock(spec=Session)
#     def override_get_db():
#         yield mock_db_session
#     api_router = create_api_router(services=mock_services)
#     test_app.dependency_overrides[get_db] = override_get_db
#     test_app.include_router(api_router)
#     test_client = TestClient(test_app)
#     yield test_client, mock_services


# # NOTE(review): this is an async-generator fixture registered with plain
# # `@pytest.fixture`. The async tests that use it call
# # `await anext(async_client)`, which suggests they receive the raw generator
# # object rather than the yielded value; the generator is then never resumed,
# # so the `async with` block below is not exited by the test. Presumably this
# # should use `pytest_asyncio.fixture` (or asyncio "auto" mode) — confirm
# # against the project's pytest-asyncio configuration before re-enabling.
# @pytest.fixture
# async def async_client():
#     """
#     Pytest fixture to create an AsyncClient for testing async endpoints.
#     """
#     test_app = FastAPI()
#     mock_rag_service = MagicMock(spec=RAGService)
#     mock_document_service = MagicMock(spec=DocumentService)
#     mock_tts_service = MagicMock(spec=TTSService)
#     mock_services = MagicMock(spec=ServiceContainer)
#     mock_services.rag_service = mock_rag_service
#     mock_services.document_service = mock_document_service
#     mock_services.tts_service = mock_tts_service
#     mock_db_session = MagicMock(spec=Session)
#     def override_get_db():
#         yield mock_db_session
#     api_router = create_api_router(services=mock_services)
#     test_app.dependency_overrides[get_db] = override_get_db
#     test_app.include_router(api_router)
#     async with AsyncClient(transport=ASGITransport(app=test_app), base_url="http://test") as client:
#         yield client, mock_services


# # --- General Endpoint ---
# def test_read_root(client):
#     """Tests the root endpoint."""
#     test_client, _ = client
#     response = test_client.get("/")
#     assert response.status_code == 200
#     assert response.json() == {"status": "AI Model Hub is running!"}


# # --- Session and Chat Endpoints ---
# def test_create_session_success(client):
#     """Tests successfully creating a new chat session."""
#     test_client, mock_services = client
#     mock_session = models.Session(id=1, user_id="test_user", model_name="gemini", title="New Chat", created_at=datetime.now())
#     mock_services.rag_service.create_session.return_value = mock_session
#     response = test_client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})
#     assert response.status_code == 200
#     assert response.json()["id"] == 1
#     mock_services.rag_service.create_session.assert_called_once()


# # NOTE(review): in the `assert_called_once_with` checks of the chat and
# # message-history tests below, the `db=` value is read back from the mock's
# # own `call_args`, so that argument is compared with itself. These checks
# # only confirm that a `db` keyword was passed, not which object was injected.
# def test_chat_in_session_success(client):
#     """
#     Tests sending a message in an existing session without specifying a model
#     or retriever. It should default to 'deepseek' and 'False'.
#     """
#     test_client, mock_services = client
#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response", "deepseek"))
#     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there"})
#     assert response.status_code == 200
#     assert response.json() == {"answer": "Mocked response", "model_used": "deepseek"}
#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
#         session_id=42,
#         prompt="Hello there",
#         model="deepseek",
#         load_faiss_retriever=False
#     )


# def test_chat_in_session_with_model_switch(client):
#     """
#     Tests sending a message in an existing session and explicitly switching the model.
#     """
#     test_client, mock_services = client
#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
#     response = test_client.post("/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"})
#     assert response.status_code == 200
#     assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
#         session_id=42,
#         prompt="Hello there, Gemini!",
#         model="gemini",
#         load_faiss_retriever=False
#     )


# def test_chat_in_session_with_faiss_retriever(client):
#     """
#     Tests sending a message and explicitly enabling the FAISS retriever.
#     """
#     test_client, mock_services = client
#     mock_services.rag_service.chat_with_rag = AsyncMock(return_value=("Response with context", "deepseek"))
#     response = test_client.post(
#         "/sessions/42/chat",
#         json={"prompt": "What is RAG?", "load_faiss_retriever": True}
#     )
#     assert response.status_code == 200
#     assert response.json() == {"answer": "Response with context", "model_used": "deepseek"}
#     mock_services.rag_service.chat_with_rag.assert_called_once_with(
#         db=mock_services.rag_service.chat_with_rag.call_args.kwargs['db'],
#         session_id=42,
#         prompt="What is RAG?",
#         model="deepseek",
#         load_faiss_retriever=True
#     )


# def test_get_session_messages_success(client):
#     """Tests retrieving the message history for a session."""
#     test_client, mock_services = client
#     mock_history = [
#         models.Message(sender="user", content="Hello", created_at=datetime.now()),
#         models.Message(sender="assistant", content="Hi there!", created_at=datetime.now())
#     ]
#     mock_services.rag_service.get_message_history.return_value = mock_history
#     response = test_client.get("/sessions/123/messages")
#     assert response.status_code == 200
#     response_data = response.json()
#     assert response_data["session_id"] == 123
#     assert len(response_data["messages"]) == 2
#     assert response_data["messages"][0]["sender"] == "user"
#     assert response_data["messages"][1]["content"] == "Hi there!"
#     mock_services.rag_service.get_message_history.assert_called_once_with(
#         db=mock_services.rag_service.get_message_history.call_args.kwargs['db'],
#         session_id=123
#     )


# def test_get_session_messages_not_found(client):
#     """Tests retrieving messages for a session that does not exist."""
#     test_client, mock_services = client
#     mock_services.rag_service.get_message_history.return_value = None
#     response = test_client.get("/sessions/999/messages")
#     assert response.status_code == 404
#     assert response.json()["detail"] == "Session with ID 999 not found."
# # NOTE(review): the document and speech tests below are commented out
# # (disabled); pytest collects nothing from them. Every line stays behind a
# # `#` so the module remains importable.


# # --- Document Endpoints ---
# def test_add_document_success(client):
#     """Tests the /documents endpoint for adding a new document."""
#     test_client, mock_services = client
#     mock_services.document_service.add_document.return_value = 123
#     doc_payload = {"title": "Test Doc", "text": "Content here"}
#     response = test_client.post("/documents", json=doc_payload)
#     assert response.status_code == 200
#     assert response.json()["message"] == "Document 'Test Doc' added successfully with ID 123"


# def test_get_documents_success(client):
#     """Tests the /documents endpoint for retrieving all documents."""
#     test_client, mock_services = client
#     mock_docs = [
#         models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
#         models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now())
#     ]
#     mock_services.document_service.get_all_documents.return_value = mock_docs
#     response = test_client.get("/documents")
#     assert response.status_code == 200
#     assert len(response.json()["documents"]) == 2


# def test_delete_document_success(client):
#     """Tests the DELETE /documents/{document_id} endpoint for successful deletion."""
#     test_client, mock_services = client
#     mock_services.document_service.delete_document.return_value = 42
#     response = test_client.delete("/documents/42")
#     assert response.status_code == 200
#     assert response.json()["document_id"] == 42


# def test_delete_document_not_found(client):
#     """Tests the DELETE /documents/{document_id} endpoint when the document is not found."""
#     test_client, mock_services = client
#     mock_services.document_service.delete_document.return_value = None
#     response = test_client.delete("/documents/999")
#     assert response.status_code == 404


# # NOTE(review): both async tests below obtain the client with
# # `await anext(async_client)`, i.e. they advance the fixture's async generator
# # by hand and never resume it afterwards, so its `async with AsyncClient(...)`
# # is not exited by the test. Presumably the fixture should be an async fixture
# # (e.g. `pytest_asyncio.fixture`) and the tests should unpack `async_client`
# # directly — confirm against the project's pytest-asyncio setup.
# @pytest.mark.asyncio
# async def test_create_speech_response(async_client):
#     """Test the /speech endpoint returns audio bytes."""
#     test_client, mock_services = await anext(async_client)
#     mock_audio_bytes = b"fake wav audio bytes"
#     # The route handler calls `create_speech_non_stream`, not `create_speech_stream`
#     # It's an async function, so we must use AsyncMock
#     mock_services.tts_service.create_speech_non_stream = AsyncMock(return_value=mock_audio_bytes)
#     response = await test_client.post("/speech", json={"text": "Hello, this is a test"})
#     assert response.status_code == 200
#     assert response.headers["content-type"] == "audio/wav"
#     assert response.content == mock_audio_bytes
#     mock_services.tts_service.create_speech_non_stream.assert_called_once_with(text="Hello, this is a test")


# @pytest.mark.asyncio
# async def test_create_speech_stream_response(async_client):
#     """Test the consolidated /speech endpoint with stream=true returns a streaming response."""
#     test_client, mock_services = await anext(async_client)
#     mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]
#     # This async generator mock correctly simulates the streaming service
#     async def mock_async_generator():
#         for chunk in mock_audio_bytes_chunks:
#             yield chunk
#     # We mock `create_speech_stream` with a MagicMock returning the async generator
#     mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())
#     # Correct the endpoint URL to use the consolidated /speech endpoint with the stream query parameter
#     response = await test_client.post("/speech?stream=true", json={"text": "Hello, this is a test"})
#     assert response.status_code == 200
#     assert response.headers["content-type"] == "audio/wav"
#     # Read the streamed content and verify it matches the mocked chunks
#     streamed_content = b""
#     async for chunk in response.aiter_bytes():
#         streamed_content += chunk
#     assert streamed_content == b"".join(mock_audio_bytes_chunks)
#     mock_services.tts_service.create_speech_stream.assert_called_once_with(text="Hello, this is a test")