# tests/api/test_dependencies.py
"""Unit tests for the dependencies defined in ``app.api.dependencies``.

Covers the ``get_db`` session generator, the ``get_current_user`` auth
dependency, and the service wiring performed by ``ServiceContainer``.
"""
import asyncio
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.retrievers.base_retriever import Retriever
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.stt import STTService
from app.core.services.tts import TTSService
from app.core.services.workspace import WorkspaceService
from app.core.vector_store.faiss_store import FaissVectorStore


class _SentinelError(Exception):
    """Test-only error injected into ``get_db``.

    A dedicated class (rather than bare ``Exception``) lets the test tell the
    injected error apart from anything else, in particular from the
    ``StopIteration`` that ``generator.throw`` raises when the generator
    swallows the error and finishes.
    """


# --- Fixtures ---

@pytest.fixture
def mock_session():
    """
    Fixture that provides a mock SQLAlchemy session.
    """
    mock = MagicMock(spec=Session)
    yield mock


# --- Tests for get_db dependency ---

@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(mock_session_local, mock_session):
    """
    Tests that get_db yields a database session and ensures it's closed correctly.
    """
    # Arrange
    mock_session_local.return_value = mock_session

    # Act
    db_generator = get_db()
    db = next(db_generator)

    # Assert: the exact session built by SessionLocal is yielded, and it is
    # still open while a request handler would be using it.
    assert db is mock_session
    mock_session.close.assert_not_called()

    # Act 2: resume the generator; it must finish without yielding again
    with pytest.raises(StopIteration):
        next(db_generator)

    # Assert: Session closed exactly once during cleanup
    mock_session.close.assert_called_once()


@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(mock_session_local, mock_session):
    """
    Tests that get_db still closes the session even if an exception occurs,
    and that the exception is propagated rather than swallowed.
    """
    mock_session_local.return_value = mock_session

    db_generator = get_db()
    db = next(db_generator)
    assert db is mock_session

    # Simulate the request handler failing while the session is in use.
    # A swallowed error would surface here as StopIteration, which a bare
    # ``pytest.raises(Exception)`` would wrongly accept.
    with pytest.raises(_SentinelError, match="Test exception"):
        db_generator.throw(_SentinelError("Test exception"))

    mock_session.close.assert_called_once()


# --- Tests for get_current_user dependency ---

def test_get_current_user_with_valid_token():
    """
    Tests that get_current_user returns the expected user dictionary for a valid token.
    """
    user = asyncio.run(get_current_user(token="valid_token"))
    assert user == {"email": "user@example.com", "id": 1}


def test_get_current_user_with_no_token():
    """
    Tests that get_current_user raises an HTTPException for a missing token.
    """
    with pytest.raises(HTTPException) as excinfo:
        asyncio.run(get_current_user(token=None))

    assert excinfo.value.status_code == 401
    assert "Unauthorized" in str(excinfo.value.detail)


# --- Tests for ServiceContainer class ---

def test_service_container_initialization():
    """
    Tests that ServiceContainer initializes DocumentService, RAGService, and
    other services with the correct dependencies.
    """
    # Arrange: Create mock dependencies
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    mock_vector_store.embedder = MagicMock()
    mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]
    mock_tts_service = MagicMock(spec=TTSService)
    mock_stt_service = MagicMock(spec=STTService)
    mock_workspace_service = MagicMock(spec=WorkspaceService)

    # Act
    container = ServiceContainer(
        vector_store=mock_vector_store,
        retrievers=mock_retrievers,
        tts_service=mock_tts_service,
        stt_service=mock_stt_service,
        workspace_service=mock_workspace_service,
    )

    # Assert: DocumentService is built around the injected vector store
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store is mock_vector_store

    # Assert: RAGService receives the retrievers (list equality, so a
    # defensive copy inside RAGService is still accepted)
    assert isinstance(container.rag_service, RAGService)
    assert container.rag_service.retrievers == mock_retrievers

    # Assert: pass-through services are stored as the exact injected objects
    # (spec'd mocks also satisfy isinstance against their spec class)
    assert isinstance(container.tts_service, TTSService)
    assert container.tts_service is mock_tts_service

    assert isinstance(container.stt_service, STTService)
    assert container.stt_service is mock_stt_service

    assert isinstance(container.workspace_service, WorkspaceService)
    assert container.workspace_service is mock_workspace_service