# tests/app/test_app.py
"""Endpoint tests for the FastAPI app built by ``app.app.create_app``.

Each test builds a fresh app with its startup collaborators (ServiceContainer,
DB table creation, config printing, embedder lookup and FAISS index loading)
patched, and drives it through Starlette's ``TestClient``.
"""
import asyncio
import os
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

# Import the factory function directly and dependencies
from app.app import create_app
from app.api.dependencies import get_db, ServiceContainer
from app.db import models
from app.core.retrievers.base_retriever import Retriever
from app.core.vector_store.faiss_store import FaissVectorStore

# Define a constant for the dimension to ensure consistency
TEST_DIMENSION = 768

# Targets patched for every test so create_app() runs without real startup work.
# Listed outermost-first, exactly like a hand-written stack of @patch decorators;
# the resulting mocks are therefore passed to the test in REVERSE order.
_STARTUP_PATCH_TARGETS = (
    'app.app.ServiceContainer',
    'app.app.create_db_and_tables',
    'app.app.print_config',
    'app.core.vector_store.embedder.factory.get_embedder_from_config',
    'faiss.read_index',
)

# --- Dependency Override for Testing ---
# Shared mock database session injected through the get_db override. Its call
# history accumulates across tests; assertions only inspect the per-test
# service mocks, which are created fresh in every test.
mock_db = MagicMock(spec=Session)


def override_get_db():
    """Yield the shared mock database session in place of a real one."""
    yield mock_db


def _patch_startup(test_func):
    """Apply one ``unittest.mock.patch`` per ``_STARTUP_PATCH_TARGETS`` entry.

    Equivalent to stacking the corresponding ``@patch`` decorators by hand, so
    the decorated test receives, in order: mock_read_index, mock_get_embedder,
    mock_print_config, mock_create_db, mock_service_container. Further
    ``@patch`` decorators placed above this one append their mocks after those.
    """
    # Innermost patch first: it supplies the first mock argument.
    for target in reversed(_STARTUP_PATCH_TARGETS):
        test_func = patch(target)(test_func)
    return test_func


def _install_mock_services(mock_read_index, mock_get_embedder, mock_service_container):
    """Configure the startup mocks and return the mocked ServiceContainer instance.

    Call this before ``create_app()``: the app instantiates the (patched)
    ServiceContainer class there, and the returned mock is what it receives.
    """
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()
    mock_services = MagicMock()
    mock_service_container.return_value = mock_services
    return mock_services


def _make_client(*, override_db=True):
    """Build a fresh app and wrap it in a TestClient.

    Args:
        override_db: when true, ``get_db`` is overridden to yield ``mock_db``.
    """
    app = create_app()
    if override_db:
        app.dependency_overrides[get_db] = override_get_db
    return TestClient(app)


def _sample_document_payload():
    """Return a fresh request body for ``POST /documents``."""
    return {
        "title": "Test Document",
        "text": "This is a test document.",
        "source_url": "http://example.com/test",
    }


@_patch_startup
def test_read_root(mock_read_index, mock_get_embedder, mock_print_config,
                   mock_create_db, mock_service_container):
    """The root endpoint reports that the service is running."""
    # Arrange
    _install_mock_services(mock_read_index, mock_get_embedder, mock_service_container)
    client = _make_client(override_db=False)

    # Act
    response = client.get("/")

    # Assert
    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}


@_patch_startup
def test_create_session_success(mock_read_index, mock_get_embedder, mock_print_config,
                                mock_create_db, mock_service_container):
    """POST /sessions creates a chat session through the RAG service."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_services.rag_service.create_session.return_value = models.Session(
        id=1,
        user_id="test_user",
        model_name="gemini",
        title="New Chat Session",
        created_at=datetime.now(),
    )
    client = _make_client()

    # Act
    response = client.post("/sessions", json={"user_id": "test_user", "model": "gemini"})

    # Assert
    assert response.status_code == 200
    response_data = response.json()
    assert response_data["id"] == 1
    assert response_data["user_id"] == "test_user"
    mock_services.rag_service.create_session.assert_called_once_with(
        db=mock_db, user_id="test_user", model="gemini"
    )


@_patch_startup
def test_chat_in_session_success(mock_read_index, mock_get_embedder, mock_print_config,
                                 mock_create_db, mock_service_container):
    """Session chat returns the RAG answer and defaults to 'deepseek' with no model."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    # chat_with_rag is a coroutine function, so it needs an AsyncMock.
    mock_chat_with_rag = AsyncMock(return_value=("This is a mock response.", "deepseek"))
    mock_services.rag_service.chat_with_rag = mock_chat_with_rag
    client = _make_client()

    # Act
    response = client.post("/sessions/123/chat", json={"prompt": "Hello there"})

    # Assert
    assert response.status_code == 200
    body = response.json()
    assert body["answer"] == "This is a mock response."
    assert body["model_used"] == "deepseek"
    # assert_awaited_* also proves the endpoint awaited the coroutine.
    mock_chat_with_rag.assert_awaited_once_with(
        db=mock_db,
        session_id=123,
        prompt="Hello there",
        model="deepseek",
        load_faiss_retriever=False,
    )


@_patch_startup
def test_chat_in_session_with_model_switch(mock_read_index, mock_get_embedder, mock_print_config,
                                           mock_create_db, mock_service_container):
    """Session chat forwards an explicitly requested model to the RAG service."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_chat_with_rag = AsyncMock(return_value=("Mocked response from Gemini", "gemini"))
    mock_services.rag_service.chat_with_rag = mock_chat_with_rag
    client = _make_client()

    # Act
    response = client.post(
        "/sessions/42/chat", json={"prompt": "Hello there, Gemini!", "model": "gemini"}
    )

    # Assert
    assert response.status_code == 200
    assert response.json() == {"answer": "Mocked response from Gemini", "model_used": "gemini"}
    mock_chat_with_rag.assert_awaited_once_with(
        db=mock_db,
        session_id=42,
        prompt="Hello there, Gemini!",
        model="gemini",
        load_faiss_retriever=False,
    )


@_patch_startup
def test_add_document_success(mock_read_index, mock_get_embedder, mock_print_config,
                              mock_create_db, mock_service_container):
    """POST /documents stores the document and reports its new ID."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_services.document_service.add_document.return_value = 1
    client = _make_client()
    doc_data = _sample_document_payload()

    # Act
    response = client.post("/documents", json=doc_data)

    # Assert
    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1"
    # The endpoint fills in the optional fields before calling the service.
    expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
    mock_services.document_service.add_document.assert_called_once_with(
        db=mock_db, doc_data=expected_doc_data
    )


@_patch_startup
def test_add_document_api_failure(mock_read_index, mock_get_embedder, mock_print_config,
                                  mock_create_db, mock_service_container):
    """POST /documents maps a service exception to HTTP 500 with its message."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_services.document_service.add_document.side_effect = Exception("Service failed")
    client = _make_client()
    doc_data = _sample_document_payload()

    # Act
    response = client.post("/documents", json=doc_data)

    # Assert
    assert response.status_code == 500
    assert "An error occurred: Service failed" in response.json()["detail"]
    expected_doc_data = {**doc_data, "author": None, "user_id": "default_user"}
    mock_services.document_service.add_document.assert_called_once_with(
        db=mock_db, doc_data=expected_doc_data
    )


@_patch_startup
def test_get_documents_success(mock_read_index, mock_get_embedder, mock_print_config,
                               mock_create_db, mock_service_container):
    """GET /documents lists every document returned by the service."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_services.document_service.get_all_documents.return_value = [
        models.Document(id=1, title="Doc One", status="ready", created_at=datetime.now()),
        models.Document(id=2, title="Doc Two", status="processing", created_at=datetime.now()),
    ]
    client = _make_client()

    # Act
    response = client.get("/documents")

    # Assert
    assert response.status_code == 200
    documents = response.json()["documents"]
    assert len(documents) == 2
    assert documents[0]["title"] == "Doc One"
    mock_services.document_service.get_all_documents.assert_called_once_with(db=mock_db)


@_patch_startup
def test_delete_document_success(mock_read_index, mock_get_embedder, mock_print_config,
                                 mock_create_db, mock_service_container):
    """DELETE /documents/{document_id} confirms deletion of an existing document."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_services.document_service.delete_document.return_value = 42
    client = _make_client()

    # Act
    response = client.delete("/documents/42")

    # Assert
    assert response.status_code == 200
    assert response.json()["message"] == "Document deleted successfully"
    assert response.json()["document_id"] == 42
    mock_services.document_service.delete_document.assert_called_once_with(
        db=mock_db, document_id=42
    )


@_patch_startup
def test_delete_document_not_found(mock_read_index, mock_get_embedder, mock_print_config,
                                   mock_create_db, mock_service_container):
    """DELETE /documents/{document_id} returns 404 when the service returns None."""
    # Arrange
    mock_services = _install_mock_services(
        mock_read_index, mock_get_embedder, mock_service_container
    )
    mock_services.document_service.delete_document.return_value = None
    client = _make_client()

    # Act
    response = client.delete("/documents/999")

    # Assert
    assert response.status_code == 404
    assert response.json()["detail"] == "Document with ID 999 not found."
    mock_services.document_service.delete_document.assert_called_once_with(
        db=mock_db, document_id=999
    )


# Checks the application's shutdown behaviour (lifespan exit).
@patch('app.core.vector_store.faiss_store.FaissVectorStore.save_index')
@_patch_startup
def test_shutdown_saves_index(mock_read_index, mock_get_embedder, mock_print_config,
                              mock_create_db, mock_service_container, mock_save_index):
    """The FAISS index is saved when the application shuts down."""
    # Arrange
    mock_read_index.return_value = MagicMock()
    mock_get_embedder.return_value = MagicMock()
    app = create_app()

    # Act: using TestClient as a context manager runs the lifespan startup on
    # entry and the shutdown handlers on exit.
    with TestClient(app):
        pass

    # Assert
    # NOTE(review): ServiceContainer is patched here, so this only passes if
    # shutdown reaches FaissVectorStore.save_index through an instance that is
    # not owned by the mocked container - confirm against app.app.
    mock_save_index.assert_called_once()