# Source path: cortex-hub / ai-hub / tests / api / routes / test_tts.py
import pytest
from unittest.mock import MagicMock, AsyncMock

@pytest.mark.asyncio
async def test_create_speech_response(async_client):
    """Check that a plain POST to /speech answers with the full audio payload.

    The TTS service's non-streaming method is replaced by a mock, and the
    test asserts on the status code, the content type, the response body,
    and the keyword argument the mock was called with.
    """
    client, services = await anext(async_client)
    payload = {"text": "Hello, this is a test"}
    expected_audio = b"fake wav audio bytes"

    # create_speech_non_stream is an async function, hence AsyncMock.
    synth_mock = AsyncMock(return_value=expected_audio)
    services.tts_service.create_speech_non_stream = synth_mock

    result = await client.post("/speech", json=payload)

    assert result.status_code == 200
    assert result.headers["content-type"] == "audio/wav"
    assert result.content == expected_audio

    synth_mock.assert_called_once_with(text=payload["text"])

@pytest.mark.asyncio
async def test_create_speech_stream_response(async_client):
    """Check that POST /speech?stream=true relays the service's audio chunks.

    The TTS service's streaming method is replaced by a mock that returns an
    async generator. The test asserts on the status code, the content type,
    the concatenated streamed body, and the keyword argument the mock was
    called with.
    """
    client, services = await anext(async_client)
    request_text = "Hello, this is a test"
    expected_chunks = [b"chunk1", b"chunk2", b"chunk3"]

    async def fake_stream():
        # Stand-in for the streaming service: emit each chunk in order.
        for piece in expected_chunks:
            yield piece

    # A plain MagicMock is used so that calling it returns the async
    # generator object directly rather than a coroutine.
    stream_mock = MagicMock(return_value=fake_stream())
    services.tts_service.create_speech_stream = stream_mock

    result = await client.post("/speech?stream=true", json={"text": request_text})

    assert result.status_code == 200
    assert result.headers["content-type"] == "audio/wav"

    # Drain the body chunk by chunk, then compare it as one byte string.
    received = [part async for part in result.aiter_bytes()]

    assert b"".join(received) == b"".join(expected_chunks)
    stream_mock.assert_called_once_with(text=request_text)