# cortex-hub/ai-hub/tests/api/routes/conftest.py
import pytest
from unittest.mock import MagicMock, AsyncMock
from fastapi import FastAPI
from fastapi.testclient import TestClient
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import Session
from app.api.dependencies import get_db, ServiceContainer
from app.core.services.rag import RAGService
from app.core.services.document import DocumentService
from app.core.services.tts import TTSService
from app.api.routes.api import create_api_router

def _build_mocked_app():
    """Build a FastAPI app wired to fully mocked services and a mocked DB session.

    Shared by the sync and async client fixtures so both test environments
    are constructed identically and cannot drift apart.

    Returns:
        A ``(app, mock_services)`` tuple. ``mock_services`` is a
        ``MagicMock(spec=ServiceContainer)`` whose ``rag_service``,
        ``document_service`` and ``tts_service`` attributes are spec'd
        mocks that tests can configure and assert on.
    """
    test_app = FastAPI()

    mock_services = MagicMock(spec=ServiceContainer)
    mock_services.rag_service = MagicMock(spec=RAGService)
    mock_services.document_service = MagicMock(spec=DocumentService)
    mock_services.tts_service = MagicMock(spec=TTSService)

    mock_db_session = MagicMock(spec=Session)

    def override_get_db():
        # Hand every request the same mock session instead of a real one.
        yield mock_db_session

    api_router = create_api_router(services=mock_services)
    test_app.dependency_overrides[get_db] = override_get_db
    test_app.include_router(api_router)

    return test_app, mock_services


@pytest.fixture(scope="function")
def client():
    """Yield ``(TestClient, mock_services)`` for synchronous endpoint tests.

    Function-scoped: each test gets a fresh app and fresh mocks, so call
    history and configured return values never leak between tests.
    """
    test_app, mock_services = _build_mocked_app()
    test_client = TestClient(test_app)

    yield test_client, mock_services

    # Teardown: release the client's resources once the test has finished.
    test_client.close()


@pytest.fixture(scope="function")
async def async_client():
    """Yield ``(AsyncClient, mock_services)`` for asynchronous endpoint tests.

    Requests are routed in-process to the app through ``ASGITransport``;
    ``base_url`` only supplies a host for relative request URLs. The
    ``async with`` block closes the client when the test finishes.
    """
    # NOTE(review): a plain ``pytest.fixture`` on an ``async def`` relies on
    # pytest-asyncio running in "auto" mode (or an equivalent plugin); under
    # strict mode this would need ``pytest_asyncio.fixture`` — confirm
    # against the project's pytest configuration.
    test_app, mock_services = _build_mocked_app()
    transport = ASGITransport(app=test_app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac, mock_services