Newer
Older
cortex-hub / ai-hub / tests / api / test_dependencies.py
# tests/api/test_dependencies.py

import pytest
import asyncio
from unittest.mock import MagicMock, patch
from sqlalchemy.orm import Session
from fastapi import HTTPException

# Import the dependencies and services to be tested
from app.api.dependencies import get_db, get_current_user, ServiceContainer
from app.core.services.document import DocumentService
from app.core.services.rag import RAGService
from app.core.services.tts import TTSService
from app.core.services.stt import STTService
from app.core.services.workspace import WorkspaceService
from app.core.vector_store.faiss_store import FaissVectorStore
from app.core.retrievers.base_retriever import Retriever


# --- Fixtures ---

@pytest.fixture
def mock_session():
    """
    Provide a MagicMock restricted to the SQLAlchemy ``Session`` interface.

    Passing ``spec=Session`` means that accessing an attribute ``Session``
    does not define raises instead of silently creating a child mock.
    """
    yield MagicMock(spec=Session)


# --- Tests for get_db dependency ---

@patch('app.api.dependencies.SessionLocal')
def test_get_db_yields_session_and_closes(mock_session_local, mock_session):
    """
    Verify that get_db hands out the session built by SessionLocal and
    closes it once the generator runs to completion.
    """
    # The patched session factory returns our spec'd mock session.
    mock_session_local.return_value = mock_session

    gen = get_db()

    # First step of the generator: it must produce the factory's session.
    yielded = next(gen)
    assert yielded == mock_session

    # Resuming the generator lets it finish; exhaustion surfaces as StopIteration.
    with pytest.raises(StopIteration):
        next(gen)

    # The generator's cleanup must have closed the session exactly once.
    mock_session.close.assert_called_once()


@patch('app.api.dependencies.SessionLocal')
def test_get_db_closes_on_exception(mock_session_local, mock_session):
    """
    Tests that get_db still closes the session even if an exception occurs.

    An exception is thrown into the generator at its suspension point. The
    same exception must come back out of ``throw`` and the session must
    still have been closed exactly once.
    """
    mock_session_local.return_value = mock_session

    db_generator = get_db()
    db = next(db_generator)
    # Sanity check: we are exercising the session the test controls.
    assert db is mock_session

    # Match on the message: a bare pytest.raises(Exception) would also accept
    # any unrelated error raised inside get_db and hide a real failure.
    with pytest.raises(Exception, match="Test exception"):
        db_generator.throw(Exception("Test exception"))

    mock_session.close.assert_called_once()


# --- Tests for get_current_user dependency ---

def test_get_current_user_with_valid_token():
    """
    A valid token should resolve to the expected user dictionary.
    """
    expected_user = {"email": "user@example.com", "id": 1}

    # get_current_user is a coroutine function; drive it on a fresh event loop.
    pending = get_current_user(token="valid_token")
    assert asyncio.run(pending) == expected_user


def test_get_current_user_with_no_token():
    """
    A missing (None) token must be rejected with HTTP 401, and the error
    detail must mention "Unauthorized".
    """
    with pytest.raises(HTTPException) as raised:
        asyncio.run(get_current_user(token=None))

    error = raised.value
    assert error.status_code == 401
    assert "Unauthorized" in str(error.detail)


# --- Tests for ServiceContainer class ---

def test_service_container_initialization():
    """
    ServiceContainer should:

    - build its own DocumentService around the injected vector store;
    - build its own RAGService around the injected retrievers;
    - expose the injected TTS, STT and workspace services as-is.
    """
    # Arrange: spec'd doubles for every collaborator.
    vector_store = MagicMock(spec=FaissVectorStore)
    # Set explicitly so the spec'd mock exposes `embedder`. Presumably it is
    # read while the container is built; confirm against ServiceContainer.
    vector_store.embedder = MagicMock()

    retrievers = [MagicMock(spec=Retriever) for _ in range(2)]
    tts = MagicMock(spec=TTSService)
    stt = MagicMock(spec=STTService)
    workspace = MagicMock(spec=WorkspaceService)

    # Act
    container = ServiceContainer(
        vector_store=vector_store,
        retrievers=retrievers,
        tts_service=tts,
        stt_service=stt,
        workspace_service=workspace,
    )

    # Assert: services the container constructs itself.
    document_service = container.document_service
    assert isinstance(document_service, DocumentService)
    assert document_service.vector_store == vector_store

    rag_service = container.rag_service
    assert isinstance(rag_service, RAGService)
    assert rag_service.retrievers == retrievers

    # Assert: services passed straight through (checked in this order).
    passthrough = (
        ("tts_service", TTSService, tts),
        ("stt_service", STTService, stt),
        ("workspace_service", WorkspaceService, workspace),
    )
    for attr_name, service_cls, injected in passthrough:
        actual = getattr(container, attr_name)
        assert isinstance(actual, service_cls), attr_name
        assert actual == injected, attr_name