# tests/api/test_dependencies.py
import pytest
import asyncio
from unittest.mock import MagicMock, patch
from sqlalchemy.orm import Session
from fastapi import HTTPException
# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.tts import TTSService # Added this import
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.retrievers.base_retriever import Retriever
@pytest.fixture
def mock_session():
    """
    Supply a MagicMock whose attribute surface is limited to SQLAlchemy's Session.

    No teardown is required, so the mock is returned directly.
    """
    return MagicMock(spec=Session)
# --- Tests for get_db dependency ---
@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(mock_session_local, mock_session):
    """
    Tests that get_db yields the session created by SessionLocal, keeps it open
    while the dependency is in use, and closes it once the generator finishes.
    """
    # Arrange: the patched SessionLocal factory hands out our mock session
    mock_session_local.return_value = mock_session

    # Act 1: advance the dependency generator to its yield point
    db_generator = get_db()
    db = next(db_generator)

    # Assert 1: the exact session object was yielded, and it is not yet closed
    assert db is mock_session
    mock_session_local.assert_called_once()
    mock_session.close.assert_not_called()

    # Act 2: resume the generator; it should run its cleanup and finish
    with pytest.raises(StopIteration):
        next(db_generator)

    # Assert 2: cleanup closed the session exactly once
    mock_session.close.assert_called_once()
@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(mock_session_local, mock_session):
    """
    Tests that get_db still closes the session when an error is raised while
    the yielded session is in use.

    The error is thrown into the generator at its yield point and is expected
    to propagate out of get_db unchanged.
    """

    class SimulatedHandlerError(Exception):
        """Dedicated error type so pytest.raises cannot match an unrelated failure."""

    # Arrange: the patched SessionLocal factory hands out our mock session
    mock_session_local.return_value = mock_session

    db_generator = get_db()
    db = next(db_generator)
    assert db is mock_session

    # Act & Assert: throw at the yield point; the same error must surface.
    # A narrow exception type keeps a stray AttributeError/TypeError from
    # silently satisfying this check.
    with pytest.raises(SimulatedHandlerError, match="Test exception"):
        db_generator.throw(SimulatedHandlerError("Test exception"))

    # Assert: the session was closed even though the error propagated
    mock_session.close.assert_called_once()
# --- Tests for get_current_user dependency ---
def test_get_current_user_with_valid_token():
    """
    Checks that a valid token resolves to the expected user record.
    """
    expected_user = {"email": "user@example.com", "id": 1}

    resolved_user = asyncio.run(get_current_user(token="valid_token"))

    assert resolved_user == expected_user
def test_get_current_user_with_no_token():
    """
    Checks that a missing token is rejected with a 401 HTTPException whose
    detail mentions "Unauthorized".
    """
    with pytest.raises(HTTPException) as raised:
        asyncio.run(get_current_user(token=None))

    # Inspect the captured error only after the with-block has exited
    http_error = raised.value
    assert http_error.status_code == 401
    assert "Unauthorized" in http_error.detail
# --- Tests for ServiceContainer class ---
def test_service_container_initialization():
    """
    Tests that ServiceContainer builds DocumentService and RAGService from the
    injected dependencies and exposes the injected TTS service unchanged.
    """
    # Arrange: create mock dependencies
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    # The DocumentService constructor reads a .embedder attribute on the vector
    # store. It is set explicitly because it may not be part of the class spec.
    mock_vector_store.embedder = MagicMock()
    mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]
    mock_tts_service = MagicMock(spec=TTSService)

    # Act: instantiate the ServiceContainer
    container = ServiceContainer(
        vector_store=mock_vector_store,
        retrievers=mock_retrievers,
        tts_service=mock_tts_service,
    )

    # Assert: real services were constructed around the injected collaborators
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store is mock_vector_store
    assert isinstance(container.rag_service, RAGService)
    # Compare retriever lists by content: the container may legitimately copy the list
    assert container.rag_service.retrievers == mock_retrievers

    # Assert: the TTS service is passed through as-is. An isinstance check
    # against a spec'd MagicMock would always pass, so identity is the
    # meaningful assertion here.
    assert container.tts_service is mock_tts_service