# tests/api/test_dependencies.py
import pytest
import asyncio
from unittest.mock import MagicMock, patch
from sqlalchemy.orm import Session
from fastapi import HTTPException
# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.tts import TTSService
from app.core.services.stt import STTService
from app.core.services.workspace import WorkspaceService
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.retrievers.base_retriever import Retriever
# --- Fixtures ---
@pytest.fixture
def mock_session():
    """
    Provide a MagicMock restricted to the SQLAlchemy ``Session`` interface.

    Using ``spec=Session`` means calls to methods that Session does not define
    fail loudly instead of silently returning a new mock.
    """
    yield MagicMock(spec=Session)
@pytest.fixture
def mock_faiss_vector_store():
    """
    Provide a ``FaissVectorStore`` mock that also carries an ``embedder`` attribute.

    Returns:
        A MagicMock specced on FaissVectorStore, with ``embedder`` set to a
        plain MagicMock.
    """
    store = MagicMock(spec=FaissVectorStore)
    # DocumentService.__init__ reads ``embedder``, so set it explicitly.
    # NOTE(review): presumably this is needed because embedder is an instance
    # attribute that a class-based ``spec`` does not expose — confirm.
    store.embedder = MagicMock()
    return store
# --- Tests for get_db dependency ---
@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(mock_session_local, mock_session):
    """
    Tests that get_db yields the session produced by SessionLocal, leaves it
    open while the caller is using it, and closes it exactly once when the
    generator finishes.
    """
    # Arrange: SessionLocal() hands back our mock session.
    mock_session_local.return_value = mock_session
    # Act: advance the dependency to its yield point.
    db_generator = get_db()
    db = next(db_generator)
    # Assert: the very same session object is yielded (identity, not equality).
    assert db is mock_session
    # While suspended at the yield, the session must still be usable.
    mock_session.close.assert_not_called()
    # Act 2: resume the generator so its cleanup runs; it is then exhausted.
    with pytest.raises(StopIteration):
        next(db_generator)
    # Assert: session closed exactly once.
    mock_session.close.assert_called_once()
@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(mock_session_local, mock_session):
    """
    Tests that get_db still closes the session when an exception is thrown
    into the generator at its yield point, and that the exception propagates
    instead of being swallowed.
    """
    # Arrange
    mock_session_local.return_value = mock_session
    db_generator = get_db()
    # The yielded session is the mock and has not been closed yet.
    assert next(db_generator) is mock_session
    mock_session.close.assert_not_called()
    # Act / Assert: inject a failure at the yield point. Matching the message
    # ensures we see *our* exception, not an unrelated error raised during cleanup.
    with pytest.raises(Exception, match="Test exception"):
        db_generator.throw(Exception("Test exception"))
    # Assert: cleanup still ran exactly once.
    mock_session.close.assert_called_once()
# --- Tests for get_current_user dependency ---
def test_get_current_user_with_valid_token():
    """
    A valid token resolves, via the async get_current_user dependency, to the
    expected user dictionary.
    """
    expected_user = {"email": "user@example.com", "id": 1}
    pending_lookup = get_current_user(token="valid_token")
    assert asyncio.run(pending_lookup) == expected_user
def test_get_current_user_with_no_token():
    """
    A missing (None) token is rejected with a 401 HTTPException whose detail
    mentions "Unauthorized".
    """
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user(token=None))
    error = exc_info.value
    assert error.status_code == 401
    assert "Unauthorized" in str(error.detail)
# --- Tests for ServiceContainer class ---
def test_service_container_with_document_service(mock_faiss_vector_store):
    """
    with_document_service builds a DocumentService that is bound to the
    vector store it was given.
    """
    document_service = (
        ServiceContainer()
        .with_document_service(vector_store=mock_faiss_vector_store)
        .document_service
    )
    assert isinstance(document_service, DocumentService)
    assert document_service.vector_store == mock_faiss_vector_store
def test_service_container_with_rag_service():
    """
    with_rag_service builds a RAGService wired to the supplied list of
    retrievers.
    """
    retrievers = [MagicMock(spec=Retriever) for _ in range(2)]
    rag_service = ServiceContainer().with_rag_service(retrievers=retrievers).rag_service
    assert isinstance(rag_service, RAGService)
    assert rag_service.retrievers == retrievers
def test_service_container_with_service():
    """
    The generic with_service method registers arbitrary services by name, and
    each one is then readable as an attribute of the container.
    """
    expected = {
        "tts_service": MagicMock(spec=TTSService),
        "stt_service": MagicMock(spec=STTService),
    }
    container = (
        ServiceContainer()
        .with_service("tts_service", expected["tts_service"])
        .with_service("stt_service", expected["stt_service"])
    )
    # Dict order is insertion order, so tts is checked before stt.
    for name, service in expected.items():
        assert hasattr(container, name)
        assert getattr(container, name) == service
def test_service_container_attribute_error():
    """
    Looking up a service that was never registered raises AttributeError with
    a message naming the missing service.
    """
    empty_container = ServiceContainer()
    with pytest.raises(AttributeError) as exc_info:
        _ = empty_container.non_existent_service
    message = str(exc_info.value)
    assert "object has no service named 'non_existent_service'" in message
def test_service_container_chaining(mock_faiss_vector_store):
    """
    Tests that the with_* methods can be chained, and that every service
    registered along the chain is still present and wired to its
    dependencies after later links run.
    """
    # Arrange
    mock_retrievers = [MagicMock(spec=Retriever)]
    mock_tts_service = MagicMock(spec=TTSService)
    # Act: parenthesized chain instead of backslash continuations (PEP 8).
    container = (
        ServiceContainer()
        .with_document_service(vector_store=mock_faiss_vector_store)
        .with_rag_service(retrievers=mock_retrievers)
        .with_service("tts_service", mock_tts_service)
    )
    # Assert: later links must not drop or rewire services added earlier.
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store == mock_faiss_vector_store
    assert isinstance(container.rag_service, RAGService)
    assert container.rag_service.retrievers == mock_retrievers
    assert container.tts_service == mock_tts_service