Newer
Older
cortex-hub / ai-hub / tests / api / test_dependencies.py
# tests/api/test_dependencies.py
import pytest
import asyncio
from unittest.mock import MagicMock, patch
from sqlalchemy.orm import Session
from fastapi import HTTPException

# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.tts import TTSService
from app.core.services.stt import STTService # Added this import
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.retrievers.base_retriever import Retriever

@pytest.fixture
def mock_session():
    """
    Supply a stand-in SQLAlchemy session for the get_db tests.

    The mock is built with ``spec=Session``, so accessing an attribute that
    the real ``Session`` class does not define raises ``AttributeError``.
    """
    yield MagicMock(spec=Session)

# --- Tests for get_db dependency ---

@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(mock_session_local, mock_session):
    """
    Tests that get_db yields a database session, keeps it open while it is
    in use, and closes it once the dependency generator is exhausted.
    """
    # Arrange: make SessionLocal() hand back our mock session.
    mock_session_local.return_value = mock_session

    # Act 1: advance the dependency generator to its yield point.
    db_generator = get_db()
    db = next(db_generator)

    # Assert 1: the exact session object was yielded (identity, not equality),
    # and it has not been closed yet while the caller is still using it.
    assert db is mock_session
    mock_session.close.assert_not_called()

    # Act 2: resume the generator past its yield; it has nothing more to
    # produce, so it finishes by raising StopIteration.
    with pytest.raises(StopIteration):
        next(db_generator)

    # Assert 2: finishing the generator closed the session exactly once.
    mock_session.close.assert_called_once()

@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(mock_session_local, mock_session):
    """
    Tests that get_db still closes the session when an exception is raised
    into it, and that the exception propagates back to the caller.
    """
    # Arrange: make SessionLocal() hand back our mock session.
    mock_session_local.return_value = mock_session

    db_generator = get_db()
    assert next(db_generator) is mock_session

    # Act & Assert: raise an error at the yield point, simulating a failure
    # in the code that was using the session. Expect the specific type and
    # message rather than bare ``Exception``: ``StopIteration`` is itself an
    # ``Exception``, so a broad check would also pass if get_db silently
    # swallowed the error and simply returned.
    with pytest.raises(RuntimeError, match="Test exception"):
        db_generator.throw(RuntimeError("Test exception"))

    # Assert: the session was closed even though an exception occurred.
    mock_session.close.assert_called_once()


# --- Tests for get_current_user dependency ---

def test_get_current_user_with_valid_token():
    """
    Checks that a valid token resolves to the expected user dictionary.
    """
    expected_user = {"email": "user@example.com", "id": 1}

    resolved_user = asyncio.run(get_current_user(token="valid_token"))

    assert resolved_user == expected_user

def test_get_current_user_with_no_token():
    """
    Checks that a missing token is rejected with an HTTPException carrying
    status 401 and a detail that mentions "Unauthorized".
    """
    # Act: awaiting the dependency without a token must raise.
    with pytest.raises(HTTPException) as excinfo:
        asyncio.run(get_current_user(token=None))

    # Assert: inspect the raised exception.
    raised = excinfo.value
    assert raised.status_code == 401
    assert "Unauthorized" in raised.detail

# --- Tests for ServiceContainer class ---

def test_service_container_initialization():
    """
    Tests that ServiceContainer builds DocumentService and RAGService from
    the injected dependencies and exposes the injected TTS and STT services
    unchanged.
    """
    # Arrange: mock every dependency the container receives.
    mock_vector_store = MagicMock(spec=FaissVectorStore)
    # The DocumentService constructor needs a .embedder attribute on the vector_store
    mock_vector_store.embedder = MagicMock()
    mock_retrievers = [MagicMock(spec=Retriever), MagicMock(spec=Retriever)]
    mock_tts_service = MagicMock(spec=TTSService)
    mock_stt_service = MagicMock(spec=STTService)

    # Act
    container = ServiceContainer(
        vector_store=mock_vector_store,
        retrievers=mock_retrievers,
        tts_service=mock_tts_service,
        stt_service=mock_stt_service,
    )

    # Assert: the services the container builds are wired to the mocks.
    assert isinstance(container.document_service, DocumentService)
    assert container.document_service.vector_store is mock_vector_store

    assert isinstance(container.rag_service, RAGService)
    # Compared by value rather than identity so the test does not depend on
    # whether RAGService stores the given list itself or a copy of it.
    assert container.rag_service.retrievers == mock_retrievers

    # Assert: the speech services are the injected objects themselves.
    # An isinstance() check would prove nothing here, because a MagicMock
    # created with spec=X already passes isinstance(mock, X).
    assert container.tts_service is mock_tts_service
    assert container.stt_service is mock_stt_service