from fastapi import APIRouter, HTTPException, Query, Response, Depends
from fastapi.responses import StreamingResponse
from typing import AsyncGenerator
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas
def create_tts_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/speech`` router whose handlers close over *services*.

    Registers two endpoints on a fresh ``APIRouter``:

    * ``POST /speech`` - synthesize ``request.text`` either as one complete
      WAV response or as a chunk-by-chunk ``StreamingResponse``.
    * ``GET /speech/voices`` - list voice names for the resolved provider/key.

    Args:
        services: Container providing ``user_service`` and ``tts_service``.

    Returns:
        The configured router.
    """
    # The handlers are documented with comments rather than docstrings on
    # purpose: FastAPI publishes a handler's docstring as the operation
    # description in the OpenAPI schema, which would change runtime output.

    # Function-scope import, matching the lazy-import style used throughout
    # this factory.
    import logging

    logger = logging.getLogger(__name__)
    router = APIRouter(prefix="/speech", tags=["TTS"])

    # POST /speech
    # Resolves an optional per-user provider override from the user's stored
    # "tts" preferences, then either streams audio chunk by chunk or returns a
    # single WAV payload.
    # Raises HTTPException(400) when streaming is requested and the text
    # splits into no chunks, and HTTPException(500) for any other failure.
    @router.post(
        "",
        summary="Generate speech from text",
        response_description="Audio bytes in WAV or PCM format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        ),
        as_wav: bool = Query(
            True,
            description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true."
        ),
        provider_name: str = Query(
            None,
            description="Optional session-level override for the TTS provider"
        ),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        try:
            provider_override = None
            if user_id:
                user = services.user_service.get_user_by_id(db=db, user_id=user_id)
                # `or {}` guards against preference entries stored as None,
                # which would otherwise raise AttributeError on `.get`.
                user_prefs = (user.preferences or {}) if user else {}
                prefs = user_prefs.get("tts") or {}

                from app.config import settings
                active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER
                active_prefs = (prefs.get("providers") or {}).get(active_provider) or {}

                # NOTE(review): an override is only built when the user has
                # stored settings for the active provider; otherwise
                # `provider_name` has no effect and the default provider is
                # used. Confirm this is intended.
                if active_prefs:
                    from app.core.providers.factory import get_tts_provider
                    # These three keys are passed explicitly below; everything
                    # else is forwarded as provider-specific keyword arguments.
                    explicit_keys = ("api_key", "model", "voice")
                    extra_kwargs = {
                        k: v for k, v in active_prefs.items() if k not in explicit_keys
                    }
                    provider_override = get_tts_provider(
                        provider_name=active_provider,
                        api_key=active_prefs.get("api_key"),
                        model_name=active_prefs.get("model", ""),
                        voice_name=active_prefs.get("voice", ""),
                        **extra_kwargs
                    )

            if stream:
                # Pre-flight: generate the first chunk before streaming so
                # errors are reported cleanly. Once a StreamingResponse has
                # started, a failure reaches the browser as a network error
                # instead of a meaningful error message.
                chunks = await services.tts_service._split_text_into_chunks(request.text)
                provider = provider_override or services.tts_service.default_tts_provider
                if not chunks:
                    raise HTTPException(status_code=400, detail="No text to synthesize.")

                # Awaited before the response is created so that a provider
                # failure lands in the `except` below.
                first_pcm = await provider.generate_speech(chunks[0])

                # Import the WAV wrapper up front (only when needed) so an
                # import failure also surfaces before streaming begins.
                if as_wav:
                    from app.core.services.tts import _create_wav_file
                    encode = _create_wav_file
                else:
                    def encode(pcm):
                        return pcm

                async def full_stream():
                    # Emit the already-generated first chunk. It is guarded the
                    # same way as later chunks so an empty result is skipped
                    # instead of breaking the stream.
                    if first_pcm:
                        yield encode(first_pcm)
                    # Then generate and emit the remaining chunks one by one.
                    for chunk in chunks[1:]:
                        try:
                            pcm = await provider.generate_speech(chunk)
                            if pcm:
                                yield encode(pcm)
                        except Exception as e:
                            logger.error("TTS chunk error: %s", e)
                            break  # Stop cleanly rather than crashing the stream

                media_type = "audio/wav" if as_wav else "audio/pcm"
                return StreamingResponse(full_stream(), media_type=media_type)
            else:
                # The non-streaming path is always served as audio/wav;
                # `as_wav` is not consulted here.
                audio_bytes = await services.tts_service.create_speech_non_stream(
                    text=request.text,
                    provider_override=provider_override
                )
                return Response(content=audio_bytes, media_type="audio/wav")
        except HTTPException:
            # Deliberate HTTP errors (e.g. the 400 above) pass through untouched.
            raise
        except Exception as e:
            # Record the traceback server-side and keep the cause chained.
            logger.exception("Failed to generate speech")
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    # GET /speech/voices
    # Returns a sorted list of voice names. The key is resolved in order:
    # the `api_key` query value, a stored user key when that value is masked
    # (contains "***"), then the configured defaults. Gemini voices come from
    # a static list; otherwise the Google Cloud TTS voices endpoint is queried.
    @router.get(
        "/voices",
        summary="List available TTS voices",
        response_description="A list of voice names"
    )
    async def list_voices(
        provider: str = Query(None, description="Optional provider name"),
        api_key: str = Query(None, description="Optional API key override"),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        from app.config import settings
        import httpx
        from app.core.providers.tts.gemini import GeminiTTSProvider
        from app.core.providers.tts.gcloud_tts import GCloudTTSProvider

        # Resolve a masked key to the real stored key if possible.
        key_to_use = api_key
        if key_to_use and "***" in key_to_use and user_id:
            user = services.user_service.get_user_by_id(db=db, user_id=user_id)
            if user and user.preferences:
                # Search every stored TTS provider, since the caller may not
                # have said which one the key belongs to.
                stored_providers = (user.preferences.get("tts") or {}).get("providers") or {}
                for p_name, p_data in stored_providers.items():
                    stored_key = (p_data or {}).get("api_key")
                    if stored_key and "***" not in stored_key:
                        # If a provider was passed, only use that provider's key.
                        if not provider or provider == p_name:
                            key_to_use = stored_key
                            break

        # Fall back to configured defaults when no usable key was found.
        if not key_to_use or "***" in key_to_use:
            key_to_use = settings.TTS_API_KEY or settings.GEMINI_API_KEY

        # Explicit Gemini, or no provider and a key starting with "AIza"
        # (common AI Studio key prefix).
        if provider == "google_gemini" or (not provider and key_to_use and key_to_use.startswith("AIza")):
            return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

        # Default or explicit Google Cloud TTS: nothing to query without a key.
        if not key_to_use:
            return []

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                # Pass the key via `params` so it is URL-encoded. It can come
                # straight from the query string, and interpolating it into
                # the URL would let "&" or "#" alter the outgoing request.
                res = await client.get(
                    "https://texttospeech.googleapis.com/v1/voices",
                    params={"key": key_to_use},
                )
                if res.status_code == 200:
                    data = res.json()
                    voices = data.get('voices', [])
                    return sorted(v['name'] for v in voices)

                # Google Cloud TTS rejected the request. The key may really be
                # an AI Studio key meant for Gemini, so offer Gemini voices.
                if key_to_use.startswith("AIza"):
                    return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
                return []
        except Exception as e:
            logger.error("Failed to fetch voices: %s", e)
            # Final fallback to a static list when the lookup itself failed.
            if key_to_use and key_to_use.startswith("AIza"):
                return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
            return sorted(GCloudTTSProvider.AVAILABLE_VOICES_EN)

    return router