import os  # NOTE(review): not used in the visible tests; kept in case other code relies on it.
from unittest.mock import AsyncMock, MagicMock, patch

from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

# Import the factory function directly to get a fresh app instance for each test.
from app.app import create_app
# get_db lives in app/db_setup.py; it is the dependency overridden below.
from app.db_setup import get_db


# --- Dependency Override for Testing ---

# One mock database session shared by every test. The RAG-service assertions
# check that the service was called with exactly this object, which proves the
# overridden dependency was injected into the endpoint.
mock_db = MagicMock(spec=Session)


def override_get_db():
    """Yield the shared mock database session in place of a real one.

    Written as a generator dependency (the form FastAPI accepts for
    yield-based dependencies); there is nothing to clean up afterwards.
    """
    yield mock_db


# Request payload shared by the /document tests.
_DOC_DATA = {
    "title": "Test Document",
    "text": "This is a test document.",
    "source_url": "http://example.com/test",
}


def _make_client():
    """Build a fresh app with ``get_db`` overridden and wrap it in a TestClient.

    Call this only after ``app.app.RAGService`` has been patched: the service
    instance is created inside ``create_app()``, so the patch must already be
    active for the app to pick up the mock.
    """
    app = create_app()
    app.dependency_overrides[get_db] = override_get_db
    return TestClient(app)


def _expected_doc_data():
    """Return the document dict the /document endpoint should forward to the service.

    This is ``_DOC_DATA`` plus the default values the request model fills in
    for the omitted ``author`` and ``user_id`` fields.
    """
    return {**_DOC_DATA, "author": None, "user_id": "default_user"}


# --- API Endpoint Tests ---
# Tests that exercise the RAG service patch the RAGService class itself. The
# instance is created inside create_app(), so each of those tests builds its
# app (via _make_client) only after the patch is in effect.


def test_read_root():
    """Test the root endpoint to ensure it's running."""
    # Build an unpatched app here so that no mocking interferes with this check.
    client = TestClient(create_app())

    response = client.get("/")

    assert response.status_code == 200
    assert response.json() == {"status": "AI Model Hub is running!"}


@patch("app.app.RAGService")
def test_chat_handler_success(mock_rag_service_class):
    """Test /chat when the mocked RAG service returns a controlled response."""
    # The instance create_app() receives when it instantiates the patched class.
    mock_service = mock_rag_service_class.return_value
    mock_service.chat_with_rag = AsyncMock(
        return_value="This is a mock response from the RAG service."
    )
    client = _make_client()

    response = client.post("/chat", json={"prompt": "Hello there"})

    assert response.status_code == 200
    assert response.json()["response"] == "This is a mock response from the RAG service."
    mock_service.chat_with_rag.assert_called_once_with(
        db=mock_db, prompt="Hello there", model="deepseek"
    )


@patch("app.app.RAGService")
def test_chat_handler_api_failure(mock_rag_service_class):
    """Test /chat when the RAG service's chat_with_rag raises an exception."""
    mock_service = mock_rag_service_class.return_value
    mock_service.chat_with_rag = AsyncMock(side_effect=Exception("API connection error"))
    client = _make_client()

    response = client.post("/chat", json={"prompt": "This request will fail"})

    # The endpoint is expected to translate the failure into an HTTP 500.
    assert response.status_code == 500
    assert "An error occurred with the deepseek API" in response.json()["detail"]
    mock_service.chat_with_rag.assert_called_once_with(
        db=mock_db, prompt="This request will fail", model="deepseek"
    )


@patch("app.app.RAGService")
def test_add_document_success(mock_rag_service_class):
    """Test /document when the mocked RAG service stores the document as ID 1."""
    mock_service = mock_rag_service_class.return_value
    mock_service.add_document.return_value = 1
    client = _make_client()

    response = client.post("/document", json=_DOC_DATA)

    assert response.status_code == 200
    assert response.json()["message"] == "Document 'Test Document' added successfully with ID 1"
    # The forwarded dict must include the model's default values as well.
    mock_service.add_document.assert_called_once_with(
        db=mock_db, doc_data=_expected_doc_data()
    )


@patch("app.app.RAGService")
def test_add_document_api_failure(mock_rag_service_class):
    """Test /document when the RAG service's add_document raises an exception."""
    mock_service = mock_rag_service_class.return_value
    mock_service.add_document.side_effect = Exception("Service failed")
    client = _make_client()

    response = client.post("/document", json=_DOC_DATA)

    assert response.status_code == 500
    assert "An error occurred: Service failed" in response.json()["detail"]
    # The service must still have been called with the fully-defaulted payload.
    mock_service.add_document.assert_called_once_with(
        db=mock_db, doc_data=_expected_doc_data()
    )