Newer
Older
cortex-hub / ai-hub / tests / api / routes / test_tts.py
import pytest
from unittest.mock import MagicMock, AsyncMock

@pytest.mark.asyncio
async def test_create_speech_response(async_client):
    """POST /speech without streaming should return the complete WAV payload.

    The TTS service's non-streaming method is swapped for an ``AsyncMock`` (it
    is an async function, so the route awaits it). The test then checks the
    status code, the content type, the body, and the exact keyword arguments
    the route forwarded to the service.
    """
    client, services = await anext(async_client)
    prompt = "Hello, this is a test"
    fake_wav = b"fake wav audio bytes"

    # create_speech_non_stream is awaited by the route, hence AsyncMock rather than MagicMock.
    speech_mock = AsyncMock(return_value=fake_wav)
    services.tts_service.create_speech_non_stream = speech_mock

    resp = await client.post("/speech", json={"text": prompt})

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/wav"
    assert resp.content == fake_wav

    speech_mock.assert_called_once_with(text=prompt)

@pytest.mark.asyncio
async def test_create_speech_stream_wav_response(async_client):
    """Test the /speech endpoint with stream=true and as_wav=true returns a streamed WAV response.

    ``create_speech_stream`` is replaced with a plain ``MagicMock`` whose return
    value is an async generator of fixed chunks. The response body must equal
    those chunks concatenated in order, and the service must have been called
    exactly once with ``as_wav=True``.
    """
    test_client, mock_services = await anext(async_client)
    text = "Hello, this is a test"
    mock_audio_bytes_chunks = [b"chunk1", b"chunk2", b"chunk3"]

    async def mock_async_generator():
        for chunk in mock_audio_bytes_chunks:
            yield chunk

    # A plain MagicMock (not AsyncMock): calling it hands back the async generator
    # directly. This assumes the route iterates the result without awaiting the call.
    mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())

    # Explicitly set stream=true and as_wav=true
    response = await test_client.post("/speech?stream=true&as_wav=true", json={"text": text})

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/wav"

    # Collect the streamed chunks and join once, rather than growing a bytes object with +=.
    streamed_content = b"".join([chunk async for chunk in response.aiter_bytes()])

    assert streamed_content == b"".join(mock_audio_bytes_chunks)
    mock_services.tts_service.create_speech_stream.assert_called_once_with(text=text, as_wav=True)

@pytest.mark.asyncio
async def test_create_speech_stream_pcm_response(async_client):
    """Test the /speech endpoint with stream=true and as_wav=false returns a streamed PCM response.

    ``create_speech_stream`` is replaced with a plain ``MagicMock`` whose return
    value is an async generator of fixed chunks. The response must carry an
    ``audio/pcm`` content type and a body equal to those chunks concatenated in
    order, and the service must have been called exactly once with ``as_wav=False``.
    """
    test_client, mock_services = await anext(async_client)
    text = "Hello, this is a test"
    mock_audio_bytes_chunks = [b"pcm_chunk1", b"pcm_chunk2", b"pcm_chunk3"]

    async def mock_async_generator():
        for chunk in mock_audio_bytes_chunks:
            yield chunk

    # A plain MagicMock (not AsyncMock): calling it hands back the async generator
    # directly. This assumes the route iterates the result without awaiting the call.
    mock_services.tts_service.create_speech_stream = MagicMock(return_value=mock_async_generator())

    # Set stream=true and as_wav=false
    response = await test_client.post("/speech?stream=true&as_wav=false", json={"text": text})

    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/pcm"

    # Collect the streamed chunks and join once, rather than growing a bytes object with +=.
    streamed_content = b"".join([chunk async for chunk in response.aiter_bytes()])

    assert streamed_content == b"".join(mock_audio_bytes_chunks)
    mock_services.tts_service.create_speech_stream.assert_called_once_with(text=text, as_wav=False)