# tests/api/test_dependencies.py
"""Unit tests for the FastAPI dependencies defined in ``app.api.dependencies``.

Covers the ``get_db`` session generator, the ``get_current_user`` coroutine
and the ``ServiceContainer`` wiring class. All collaborators are replaced by
mocks, so no database or vector store is touched.
"""

import asyncio
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.retrievers import Retriever


@pytest.fixture
def mock_session() -> MagicMock:
    """Provide a mock constrained to the SQLAlchemy ``Session`` interface."""
    # No teardown is required, so a plain return is enough.
    return MagicMock(spec=Session)


# --- Tests for get_db dependency ---


@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(
    mock_session_local: MagicMock, mock_session: MagicMock
) -> None:
    """Check that get_db yields the session and closes it once exhausted."""
    # Arrange: make the patched SessionLocal factory hand back our mock.
    mock_session_local.return_value = mock_session

    # Act 1: advance the generator to its yield.
    db_generator = get_db()
    db = next(db_generator)

    # Assert 1: the very same session object was yielded.
    assert db is mock_session

    # Act 2: resume the generator so it runs to completion.
    with pytest.raises(StopIteration):
        next(db_generator)

    # Assert 2: the session was closed exactly once.
    mock_session.close.assert_called_once()


@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(
    mock_session_local: MagicMock, mock_session: MagicMock
) -> None:
    """Check that get_db closes the session when an exception is thrown in."""
    # Arrange: make the patched SessionLocal factory hand back our mock.
    mock_session_local.return_value = mock_session

    # Act: advance to the yield, then inject an exception at that point.
    db_generator = get_db()
    next(db_generator)

    # `match` pins the injected error so that an unrelated failure inside
    # get_db cannot satisfy this check by accident.
    with pytest.raises(Exception, match="Test exception"):
        db_generator.throw(Exception("Test exception"))

    # Assert: the session was still closed exactly once.
    mock_session.close.assert_called_once()


# --- Tests for get_current_user dependency ---


def test_get_current_user_with_valid_token() -> None:
    """Check the user dictionary returned for a valid token."""
    # Act
    user = asyncio.run(get_current_user(token="valid_token"))

    # Assert
    assert user == {"email": "user@example.com", "id": 1}


def test_get_current_user_with_no_token() -> None:
    """Check that a missing token raises an HTTPException with status 401."""
    # Act
    with pytest.raises(HTTPException) as excinfo:
        asyncio.run(get_current_user(token=None))

    # Assert: these checks must sit after the `with` block; placed inside
    # it they would be skipped as soon as the exception is raised.
    assert excinfo.value.status_code == 401
    assert "Unauthorized" in excinfo.value.detail


# --- Tests for ServiceContainer class ---


def test_service_container_initialization() -> None:
    """Check that ServiceContainer builds both services from its arguments."""
    # Arrange: create mock dependencies.
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    # The DocumentService constructor needs a .embedder attribute on the
    # vector_store.
    mock_vector_store.embedder = MagicMock()
    mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]

    # Act: instantiate the container.
    container = ServiceContainer(
        vector_store=mock_vector_store,
        retrievers=mock_retrievers,
    )

    # Assert: both services exist and hold the objects that were passed in.
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store is mock_vector_store
    assert isinstance(container.rag_service, RAGService)
    assert container.rag_service.retrievers == mock_retrievers