from fastapi import APIRouter, HTTPException, Query, Response
from fastapi.responses import StreamingResponse
from app.api.dependencies import ServiceContainer
from app.api import schemas
from typing import AsyncGenerator
def create_tts_router(services: ServiceContainer) -> APIRouter:
    """Build the router that exposes the text-to-speech endpoint.

    Args:
        services: Container whose ``tts_service`` attribute provides
            ``create_speech_stream`` and ``create_speech_non_stream``.

    Returns:
        An ``APIRouter`` with prefix ``/speech`` (tag ``TTS``) carrying a
        single POST route that turns request text into audio.
    """
    router = APIRouter(prefix="/speech", tags=["TTS"])

    # The handler below is documented with comments rather than a docstring:
    # FastAPI publishes a handler docstring as the OpenAPI operation
    # description, which would change the generated API schema.
    @router.post(
        "",
        summary="Generate speech from text",
        response_description="Audio bytes in WAV or PCM format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        ),
        as_wav: bool = Query(
            True,
            description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true."
        ),
    ):
        # Synthesises ``request.text`` and returns either a StreamingResponse
        # (stream=True) or a plain Response holding the complete audio.
        # Raises HTTPException(500) when synthesis fails; an HTTPException
        # raised further down is passed through untouched.
        try:
            if stream:
                # The service is not awaited here: its result is handed to
                # StreamingResponse as-is.
                # NOTE(review): if create_speech_stream produces its chunks
                # lazily, failures raised while the response is being sent
                # happen after this handler has returned and are not
                # converted into the 500 below -- confirm against the service.
                audio_stream_generator: AsyncGenerator[bytes, None] = (
                    services.tts_service.create_speech_stream(
                        text=request.text,
                        as_wav=as_wav,
                    )
                )
                # The advertised content type follows the as_wav flag.
                media_type = "audio/wav" if as_wav else "audio/pcm"
                return StreamingResponse(audio_stream_generator, media_type=media_type)

            # Non-streaming path: as_wav is not forwarded and the result is
            # always labelled as WAV.
            audio_bytes = await services.tts_service.create_speech_non_stream(
                text=request.text
            )
            return Response(content=audio_bytes, media_type="audio/wav")
        except HTTPException:
            # Keep the status/detail chosen by whoever raised it.
            raise
        except Exception as e:
            # Chain the original error so its traceback stays attached as the
            # explicit cause instead of only as implicit context.
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    return router