# tests/api/test_dependencies.py
"""Unit tests for the dependencies defined in ``app.api.dependencies``.

Covers the ``get_db`` session generator, the ``get_current_user`` auth
dependency, and the fluent ``ServiceContainer`` builder.
"""
import asyncio
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.tts import TTSService
from app.core.services.stt import STTService
from app.core.services.workspace import WorkspaceService
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.retrievers.base_retriever import Retriever


# --- Fixtures ---

@pytest.fixture
def mock_session():
    """
    Fixture that provides a mock SQLAlchemy session restricted to the
    ``Session`` API via ``spec``.
    """
    mock = MagicMock(spec=Session)
    yield mock


@pytest.fixture
def mock_faiss_vector_store():
    """
    Fixture that provides a mock FaissVectorStore with an ``embedder`` attribute.
    """
    mock = MagicMock(spec=FaissVectorStore)
    # The DocumentService.__init__ method tries to access this. ``spec`` only
    # exposes attributes visible on the FaissVectorStore class, so an
    # instance attribute like this (presumably set in its __init__) has to be
    # added explicitly.
    mock.embedder = MagicMock()
    return mock


# --- Tests for get_db dependency ---

@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(mock_session_local, mock_session):
    """
    Tests that get_db yields the session created by SessionLocal, keeps it
    open while it is in use, and closes it once the generator finishes.
    """
    # Arrange
    mock_session_local.return_value = mock_session

    # Act
    db_generator = get_db()
    db = next(db_generator)

    # Assert: the very session built by SessionLocal is yielded, and it must
    # not be closed while the caller still holds it.
    assert db is mock_session
    mock_session.close.assert_not_called()

    # Act 2: resume the generator so it runs its cleanup and finishes.
    with pytest.raises(StopIteration):
        next(db_generator)

    # Assert: Session closed exactly once
    mock_session.close.assert_called_once()


@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(mock_session_local, mock_session):
    """
    Tests that get_db still closes the session even if an exception occurs,
    and that the exception propagates to the caller instead of being swallowed.
    """

    class _RequestFailed(Exception):
        """Test-only error type, so nothing else can satisfy ``pytest.raises``."""

    mock_session_local.return_value = mock_session

    db_generator = get_db()
    db = next(db_generator)
    assert db is mock_session

    # Do not use ``pytest.raises(Exception)`` here. StopIteration is an
    # Exception subclass, so if get_db swallowed the error the generator
    # would end with StopIteration and a broad check would still pass.
    with pytest.raises(_RequestFailed):
        db_generator.throw(_RequestFailed("Test exception"))

    mock_session.close.assert_called_once()


# --- Tests for get_current_user dependency ---

def test_get_current_user_with_valid_token():
    """
    Tests that get_current_user returns the expected user dictionary for a valid token.
    """
    user = asyncio.run(get_current_user(token="valid_token"))
    assert user == {"email": "user@example.com", "id": 1}


def test_get_current_user_with_no_token():
    """
    Tests that get_current_user raises an HTTPException for a missing token.
    """
    with pytest.raises(HTTPException) as excinfo:
        asyncio.run(get_current_user(token=None))

    assert excinfo.value.status_code == 401
    assert "Unauthorized" in str(excinfo.value.detail)


# --- Tests for ServiceContainer class ---

def test_service_container_with_document_service(mock_faiss_vector_store):
    """
    Tests that the ServiceContainer correctly creates a DocumentService instance
    using the with_document_service method.
    """
    # Act
    container = ServiceContainer().with_document_service(vector_store=mock_faiss_vector_store)

    # Assert
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store is mock_faiss_vector_store


def test_service_container_with_rag_service():
    """
    Tests that the ServiceContainer correctly creates a RAGService instance
    using the with_rag_service method.
    """
    # Arrange: Create mock dependencies
    mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]

    # Act
    container = ServiceContainer().with_rag_service(retrievers=mock_retrievers)

    # Assert
    assert isinstance(container.rag_service, RAGService)
    assert container.rag_service.retrievers == mock_retrievers


def test_service_container_with_service():
    """
    Tests that the ServiceContainer can add a service using the generic
    with_service method and that it is accessible as an attribute.
    """
    # Arrange
    mock_tts_service = MagicMock(spec=TTSService)
    mock_stt_service = MagicMock(spec=STTService)

    # Act
    container = (
        ServiceContainer()
        .with_service("tts_service", mock_tts_service)
        .with_service("stt_service", mock_stt_service)
    )

    # Assert
    assert hasattr(container, "tts_service")
    assert container.tts_service is mock_tts_service
    assert hasattr(container, "stt_service")
    assert container.stt_service is mock_stt_service


def test_service_container_attribute_error():
    """
    Tests that accessing a non-existent service raises an AttributeError.
    """
    # Arrange
    container = ServiceContainer()

    # Act / Assert
    with pytest.raises(AttributeError) as excinfo:
        _ = container.non_existent_service

    assert "object has no service named 'non_existent_service'" in str(excinfo.value)


def test_service_container_chaining(mock_faiss_vector_store):
    """
    Tests that the with_* methods can be chained together and that each
    chained builder still wires its dependencies into the created service.
    """
    # Arrange
    mock_retrievers = [MagicMock(spec=Retriever)]
    mock_tts_service = MagicMock(spec=TTSService)

    # Act
    container = (
        ServiceContainer()
        .with_document_service(vector_store=mock_faiss_vector_store)
        .with_rag_service(retrievers=mock_retrievers)
        .with_service("tts_service", mock_tts_service)
    )

    # Assert
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store is mock_faiss_vector_store
    assert isinstance(container.rag_service, RAGService)
    assert container.rag_service.retrievers == mock_retrievers
    assert container.tts_service is mock_tts_service