import os
import json
import httpx
import base64
import logging
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, retry_if_exception
from app.core.providers.base import TTSProvider
from fastapi import HTTPException
# Module-level logger (configuration is left to the application)
logger = logging.getLogger(__name__)
def is_retryable_exception(exception):
    """Return True when *exception* represents a transient failure worth retrying."""
    # Transport-level problems (timeouts, connection drops) are always transient.
    if isinstance(exception, (httpx.TimeoutException, httpx.NetworkError)):
        return True
    # HTTP-level: rate limiting (429) and server errors (5xx) may clear on retry.
    if isinstance(exception, HTTPException):
        code = exception.status_code
        return code == 429 or 500 <= code < 600
    return False
class GeminiTTSProvider(TTSProvider):
    """TTS provider using Gemini's audio responseModalities.

    Depending on the supplied API key this talks either to Vertex AI
    (service-account keys, "AQ." prefix) or to Google AI Studio's
    generativelanguage endpoint ("AIza" keys, v1beta — required for
    audio responseModalities).
    """

    # Prebuilt voice names accepted by the Gemini TTS models.
    AVAILABLE_VOICES = [
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
        "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus",
        "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome",
        "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam",
        "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat"
    ]

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash-preview-tts",
                 voice_name: str = "Kore", **kwargs):
        """Configure the provider and pick the endpoint matching the key type.

        Args:
            api_key: Google AI Studio key ("AIza...") or Vertex key ("AQ....").
            model_name: Model id, optionally provider-prefixed ("gemini/<id>").
            voice_name: One of AVAILABLE_VOICES.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        raw_model = model_name or "gemini-2.5-flash-preview-tts"
        # Strip any provider prefix (e.g. "vertex_ai/model" or "gemini/model") → keep only the model id
        model_id = raw_model.split("/")[-1]
        # Normalise short names: "gemini-2-flash-tts" → "gemini-2.5-flash-preview-tts"
        if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts", "gemini-2.5-flash"):
            model_id = "gemini-2.5-flash-preview-tts"
            logger.info(f"Normalised model name to: {model_id}")
        # Route to Vertex AI ONLY when the key is a Vertex service-account key (starting with "AQ.")
        # AI Studio keys start with "AIza" and must use the generativelanguage endpoint.
        is_vertex_key = bool(api_key) and api_key.startswith("AQ.")
        if is_vertex_key:
            # NOTE(review): generate_speech attaches no Authorization header for
            # Vertex requests — confirm auth is injected elsewhere (proxy/middleware).
            self.api_url = (
                f"https://us-central1-aiplatform.googleapis.com/v1/publishers/google/"
                f"models/{model_id}:streamGenerateContent"
            )
            self.is_vertex = True
        else:
            # Google AI Studio — v1beta is required for audio responseModalities
            self.api_url = (
                f"https://generativelanguage.googleapis.com/v1beta/models/"
                f"{model_id}:generateContent?key={api_key}"
            )
            self.is_vertex = False
        self.api_key = api_key
        self.voice_name = voice_name
        self.model_name = model_id
        logger.debug(f"GeminiTTSProvider: model={self.model_name}, vertex={self.is_vertex}")
        # Redact the query string: the AI Studio URL embeds the API key after "?".
        logger.debug(f"  endpoint: {self.api_url.split('?', 1)[0]}")

    @retry(
        # is_retryable_exception already returns True for httpx timeouts/network
        # errors and for HTTPException with status 429/5xx, so the single
        # predicate covers the full retryable set (previous extra
        # retry_if_exception_type clause was redundant).
        retry=retry_if_exception(is_retryable_exception),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
        before_sleep=lambda retry_state: logger.warning(f"Retrying Gemini TTS request (attempt {retry_state.attempt_number})...")
    )
    async def generate_speech(self, text: str) -> bytes:
        """Synthesize *text* and return the raw audio bytes (PCM per the API).

        Raises:
            HTTPException: upstream API errors (status propagated) or a 500
                when the response contains no audio / an unexpected error occurs.
            httpx.TimeoutException, httpx.NetworkError: transport failures,
                re-raised after retries are exhausted (reraise=True).
        """
        logger.info(f"TTS request [model={self.model_name}, vertex={self.is_vertex}]: '{text}'")
        headers = {"Content-Type": "application/json"}
        # The dedicated TTS models require audio-only responseModalities plus a
        # speechConfig selecting one of the prebuilt voices.
        json_data = {
            "contents": [{"role": "user", "parts": [{"text": text}]}],
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {
                            "voiceName": self.voice_name
                        }
                    }
                }
            }
        }
        if not self.is_vertex:
            headers["x-goog-api-key"] = self.api_key
        # Redact the query string so the embedded AI Studio key never hits the logs.
        logger.debug(f"Calling: {self.api_url.split('?', 1)[0]}")
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
                response = await client.post(self.api_url, headers=headers, json=json_data)
                logger.debug(f"Response status: {response.status_code}")
                if response.status_code != 200:
                    body = response.text
                    logger.error(f"TTS API error {response.status_code}: {body[:300]}")
                    try:
                        err = response.json().get("error", {})
                        msg = err.get("message", body[:200])
                    except Exception:
                        msg = body[:200]
                    # Single raise with the upstream status: the retry predicate
                    # decides whether 429/5xx is retried; other statuses propagate.
                    # (Previously an if/else raised the identical exception in
                    # both branches.)
                    raise HTTPException(status_code=response.status_code,
                                        detail=f"Gemini TTS error: {msg}")
                resp_data = response.json()
                audio_fragments = []
                # streamGenerateContent (Vertex) returns a list of segments;
                # generateContent (AI Studio) returns one object — normalise.
                segments = resp_data if isinstance(resp_data, list) else [resp_data]
                for segment in segments:
                    candidates = segment.get("candidates", [])
                    if candidates:
                        parts = candidates[0].get("content", {}).get("parts", [])
                        for part in parts:
                            data = part.get("inlineData", {}).get("data")
                            if data:
                                audio_fragments.append(base64.b64decode(data))
                if not audio_fragments:
                    logger.error(f"No audio in response. Full response: {json.dumps(resp_data)[:500]}")
                    raise HTTPException(status_code=500, detail="No audio data in Gemini TTS response.")
                result = b"".join(audio_fragments)
                logger.debug(f"TTS returned {len(result)} PCM bytes")
                return result
        except (httpx.TimeoutException, httpx.NetworkError) as e:
            logger.error(f"Gemini TTS request failed ({type(e).__name__}) after 60s timeout")
            # Re-raise so tenacity can retry the transport failure.
            raise
        except HTTPException:
            # 429/5xx are retried by tenacity; other statuses propagate as-is.
            raise
        except Exception as e:
            logger.error(f"Unexpected TTS error: {type(e).__name__}: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}")