import logging
from fastapi import APIRouter, HTTPException, Query, Response, Depends
from fastapi.responses import StreamingResponse
from typing import AsyncGenerator
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas
logger = logging.getLogger(__name__)
# Sentence-ending characters: the preferred places to cut a chunk. Full-width
# forms are listed next to the ideographic full stop so CJK text is split on
# sentence boundaries as well.
_SENTENCE_SEPARATORS = frozenset(".?!\n\u3002\uff1f\uff01")
# Clause-level punctuation: only used when no sentence end is within range.
_CLAUSE_SEPARATORS = frozenset(",;:\uff0c\uff1b\uff1a")
# Last resort before a hard cut: break between words.
_WORD_SEPARATORS = frozenset(" ")


def _split_text_for_streaming(text, first_chunk_size=80, chunk_size=300, min_chunk_size=40):
    """Split *text* into chunks for incremental synthesis.

    The first chunk is kept small (``first_chunk_size``) so audio can start
    quickly; later chunks are larger (``chunk_size``) to give the stream a
    buffer. Within each chunk the cut is placed at the right-most sentence
    separator, then clause separator, then space found between
    ``min_chunk_size`` (exclusive) and the target size (inclusive); if none
    exists the text is cut hard at the target size.

    Returns a list of non-empty, whitespace-stripped strings (empty for
    blank input).
    """
    chunks = []
    remaining = text.strip()
    target_size = first_chunk_size
    while remaining:
        if len(remaining) <= target_size:
            chunks.append(remaining)
            break
        # Hard cutoff unless a separator is found below. ``remaining`` is
        # longer than ``target_size`` here, so index ``target_size`` is valid.
        split_at = target_size
        for separators in (_SENTENCE_SEPARATORS, _CLAUSE_SEPARATORS, _WORD_SEPARATORS):
            position = next(
                (i for i in range(target_size, min_chunk_size, -1) if remaining[i] in separators),
                None,
            )
            if position is not None:
                split_at = position + 1  # keep the separator with the chunk it ends
                break
        chunks.append(remaining[:split_at].strip())
        remaining = remaining[split_at:].strip()
        target_size = chunk_size
    return [chunk for chunk in chunks if chunk]


def _resolve_provider_prefs(personal, system):
    """Return the provider settings to use, falling back to system settings.

    *personal* and *system* are the per-provider settings dicts from the user
    preferences and the system settings. When the personal ``api_key`` is
    usable, *personal* is returned unchanged. When it is missing or masked
    (contains ``*``) and system settings exist, the result is the system
    settings overlaid with the truthy personal values, except ``api_key``,
    so a masked personal key can never replace the real system key.
    """
    personal_key = personal.get("api_key")
    key_usable = bool(personal_key) and "*" not in str(personal_key)
    if key_usable or not system:
        return personal
    merged = dict(system)
    merged.update({k: v for k, v in personal.items() if v and k != "api_key"})
    return merged


def create_tts_router(services: ServiceContainer) -> APIRouter:
    """Create the ``/speech`` router (speech synthesis and voice listing).

    The route handlers close over *services* to reach ``user_service`` and
    ``tts_service``.
    """
    router = APIRouter(prefix="/speech", tags=["TTS"])

    @router.post(
        "",
        summary="Generate speech from text",
        response_description="Audio bytes in WAV or PCM format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        ),
        as_wav: bool = Query(
            True,
            description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true."
        ),
        provider_name: str = Query(
            None,
            description="Optional session-level override for the TTS provider"
        ),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Synthesize speech for the request text.

        The provider is chosen from, in order: the ``provider_name`` query
        parameter, the user's preferences, the system settings, and the
        configured default. Responds with 400 when a streamed request has no
        text to synthesize and 500 when synthesis fails.
        """
        try:
            from app.config import settings
            from app.core.providers.factory import get_tts_provider

            # Resolve provider: User Prefs > Global Settings
            system_tts = (services.user_service.get_system_settings(db) or {}).get("tts") or {}
            user_tts = {}
            if user_id:
                user = services.user_service.get_user_by_id(db=db, user_id=user_id)
                if user and user.preferences:
                    user_tts = user.preferences.get("tts") or {}

            active_provider = (
                provider_name
                or user_tts.get("active_provider")
                or system_tts.get("active_provider")
                or settings.TTS_PROVIDER
            )
            # Fall back to the system settings if the personal key is missing/masked.
            active_prefs = _resolve_provider_prefs(
                (user_tts.get("providers") or {}).get(active_provider) or {},
                (system_tts.get("providers") or {}).get(active_provider) or {},
            )

            # Everything except key/model/voice is forwarded as extra provider settings.
            extra_settings = {
                k: v for k, v in active_prefs.items() if k not in ("api_key", "model", "voice")
            }
            provider_override = get_tts_provider(
                provider_name=active_provider,
                api_key=active_prefs.get("api_key"),
                model_name=active_prefs.get("model", ""),
                voice_name=active_prefs.get("voice", ""),
                **extra_settings
            )
            logger.info(f"Using TTS provider: {type(provider_override).__name__} for user={user_id}")

            if not stream:
                # The non-streaming function only returns WAV
                audio_bytes = await services.tts_service.create_speech_non_stream(
                    text=request.text,
                    provider_override=provider_override
                )
                return Response(content=audio_bytes, media_type="audio/wav")

            import asyncio

            chunks = _split_text_for_streaming(request.text)
            if not chunks:
                raise HTTPException(status_code=400, detail="No text to synthesize.")

            provider = provider_override or services.tts_service.default_tts_provider

            if as_wav:
                from app.core.services.tts import _create_wav_file

            def encode_chunk(pcm):
                """Return *pcm* in the requested output form (WAV-wrapped or raw)."""
                return _create_wav_file(pcm) if as_wav else pcm

            # Pre-flight: generate the first chunk before streaming so provider
            # errors become a proper HTTP error. Once a StreamingResponse has
            # started, a failure only shows up as a network error in the browser.
            first_pcm = await provider.generate_speech(chunks[0])
            logger.info(f"TTS Stream started for session {user_id}. Initial chunk: {len(first_pcm)} bytes.")

            async def full_stream() -> AsyncGenerator[bytes, None]:
                """Yield the pre-generated first chunk, then the rest in order."""
                tasks = []
                try:
                    yield encode_chunk(first_pcm)

                    semaphore = asyncio.Semaphore(1)  # Strict lock for Beta TTS stability

                    async def fetch_chunk(text_chunk, idx):
                        """Generate one chunk with retries; return None if it is given up on."""
                        retries = 3
                        delay = 1.0
                        for attempt in range(retries):
                            try:
                                async with semaphore:
                                    return await provider.generate_speech(text_chunk)
                            except Exception as e:
                                error_str = str(e)
                                # These errors are treated as a blocked chunk: retrying is skipped.
                                if "No audio in response" in error_str or "finishReason" in error_str:
                                    logger.error(f"TTS Chunk {idx} blocked: {e}")
                                    return None
                                if attempt == retries - 1:
                                    logger.error(f"TTS Chunk {idx} failed after {retries} attempts: {e}")
                                    return None
                                # Back off outside the semaphore so other chunks can proceed.
                                await asyncio.sleep(delay)
                                delay *= 2

                    # Schedule every remaining chunk now; yield strictly in order.
                    tasks = [
                        asyncio.create_task(fetch_chunk(chunk, idx))
                        for idx, chunk in enumerate(chunks[1:], start=1)
                    ]
                    for idx, task in enumerate(tasks, start=1):
                        pcm = await task
                        if pcm:
                            yield encode_chunk(pcm)
                            logger.info(f"TTS Chunk {idx} yielded successfully.")
                except Exception as e:
                    logger.error(f"Runtime error in TTS stream: {e}")
                    raise
                finally:
                    # If the stream ends early (client disconnect, error), stop
                    # the outstanding fetches instead of letting them keep
                    # calling the provider in the background.
                    for task in tasks:
                        if not task.done():
                            task.cancel()
                    logger.info(f"TTS Stream finished for session {user_id}")

            # Keep as wav for best client support (also used for raw PCM streams).
            media_type = "audio/wav"
            return StreamingResponse(
                full_stream(),
                media_type=media_type,
                headers={
                    "X-TTS-Chunk-Count": str(len(chunks)),
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive"
                }
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"TTS route error: {e}")
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    @router.get(
        "/voices",
        summary="List available TTS voices",
        response_description="A list of voice names"
    )
    async def list_voices(
        provider: str = Query(None, description="Optional provider name"),
        api_key: str = Query(None, description="Optional API key override"),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Return the sorted voice names available for the provider and key.

        Gemini voices come from a built-in list; otherwise the Google Cloud
        Text-to-Speech voices endpoint is queried. An empty list is returned
        when no API key is available or the lookup is rejected.
        """
        from app.config import settings
        import httpx
        from app.core.providers.tts.gemini import GeminiTTSProvider
        from app.core.providers.tts.gcloud_tts import GCloudTTSProvider

        # Resolve masked key if needed
        key_to_use = api_key
        if key_to_use and "***" in key_to_use and user_id:
            user = services.user_service.get_user_by_id(db=db, user_id=user_id)
            if user and user.preferences:
                stored_providers = (user.preferences.get("tts") or {}).get("providers") or {}
                # Look for the key in any TTS provider since we don't
                # necessarily know which one yet; if a provider was passed,
                # only use its key.
                for p_name, p_data in stored_providers.items():
                    stored_key = p_data.get("api_key")
                    if not stored_key or "***" in stored_key:
                        continue
                    if not provider or provider == p_name:
                        key_to_use = stored_key
                        break

        # Fallback to defaults
        if not key_to_use or "***" in key_to_use:
            key_to_use = settings.TTS_API_KEY or settings.GEMINI_API_KEY

        # If it's Gemini, or the key starts with AIza (common AI Studio key)
        if provider == "google_gemini" or (not provider and key_to_use and key_to_use.startswith("AIza")):
            return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

        # Default or explicit GCloud
        if not key_to_use:
            return []

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                # The key may come straight from the query string, so pass it
                # via ``params`` to have it URL-encoded rather than
                # interpolating it into the URL.
                res = await client.get(
                    "https://texttospeech.googleapis.com/v1/voices",
                    params={"key": key_to_use},
                )
                if res.status_code == 200:
                    return sorted(v["name"] for v in res.json().get("voices", []))
                # If Google Cloud TTS fails, maybe it's actually an AI Studio
                # key being used for Gemini? Fall back to Gemini voices then.
                if key_to_use.startswith("AIza"):
                    return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
                return []
        except Exception as e:
            logger.error(f"Failed to fetch voices: {e}")

        # Final fallback to a standard list if the lookup itself failed.
        if key_to_use.startswith("AIza"):
            return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
        return sorted(GCloudTTSProvider.AVAILABLE_VOICES_EN)

    return router