import os
import json
import httpx
import base64
import logging
from app.core.providers.base import TTSProvider
from fastapi import HTTPException
# Module-level logger named after this module; no handlers or levels are set here.
logger = logging.getLogger(__name__)
class GeminiTTSProvider(TTSProvider):
    """TTS provider using Gemini's audio responseModalities.

    Requests are sent either to the Google AI Studio ``generateContent``
    endpoint or, for keys starting with ``"AQ."``, to the Vertex AI
    ``streamGenerateContent`` endpoint. The API key is transmitted only in the
    ``x-goog-api-key`` request header, never in the URL, so that it cannot end
    up in log lines or proxy access logs.
    """

    # Voice names known to this provider. The list is informational: the
    # constructor does not validate ``voice_name`` against it.
    AVAILABLE_VOICES = [
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
        "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus",
        "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome",
        "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam",
        "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat",
    ]

    # Model id used when none is supplied and as the target of the aliases below.
    _DEFAULT_MODEL = "gemini-2.5-flash-preview-tts"
    # Short model names that are normalised to ``_DEFAULT_MODEL``.
    _MODEL_ALIASES = ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts")
    # Key prefix that routes requests to Vertex AI instead of AI Studio.
    _VERTEX_KEY_PREFIX = "AQ."
    # Overall HTTP timeout (seconds) applied to each synthesis request.
    _REQUEST_TIMEOUT_SECONDS = 30.0

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash-preview-tts",
                 voice_name: str = "Kore", **kwargs):
        """Resolve the model id and endpoint for the given credentials.

        Args:
            api_key: API key sent with every request. Keys starting with
                ``"AQ."`` select the Vertex AI endpoint; every other key
                selects the Google AI Studio endpoint.
            model_name: Model identifier, optionally prefixed with a provider
                segment such as ``"gemini/"``; only the last ``/``-separated
                component is kept. A falsy value falls back to the default.
            voice_name: Prebuilt voice name placed in the request's
                ``speechConfig``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        raw_model = model_name or self._DEFAULT_MODEL
        # Strip any provider prefix (e.g. "vertex_ai/model" or "gemini/model").
        model_id = raw_model.split("/")[-1]
        if model_id in self._MODEL_ALIASES:
            model_id = self._DEFAULT_MODEL
            logger.info("Normalised model name to: %s", model_id)

        self.is_vertex = bool(api_key) and api_key.startswith(self._VERTEX_KEY_PREFIX)
        if self.is_vertex:
            self.api_url = (
                "https://us-central1-aiplatform.googleapis.com/v1/publishers/google/"
                f"models/{model_id}:streamGenerateContent"
            )
        else:
            # Google AI Studio: v1beta is required for audio responseModalities.
            # The key is deliberately NOT appended as "?key=..."; it travels in
            # the x-goog-api-key header so the URL is safe to log.
            self.api_url = (
                "https://generativelanguage.googleapis.com/v1beta/models/"
                f"{model_id}:generateContent"
            )

        self.api_key = api_key
        self.voice_name = voice_name
        self.model_name = model_id
        logger.debug("GeminiTTSProvider: model=%s, vertex=%s", self.model_name, self.is_vertex)
        logger.debug(" endpoint: %s", self.api_url)

    @staticmethod
    def _extract_error_message(response) -> str:
        """Return the upstream error message, or a truncated body as fallback."""
        body = response.text
        try:
            payload = response.json()
        except ValueError:
            return body[:200]
        # streamGenerateContent wraps its payload (errors included) in a list.
        if isinstance(payload, list) and payload:
            payload = payload[0]
        if isinstance(payload, dict):
            err = payload.get("error")
            if isinstance(err, dict):
                return err.get("message", body[:200])
        return body[:200]

    @staticmethod
    def _extract_audio(resp_data) -> list:
        """Collect the decoded inline audio chunks from a response payload.

        Accepts both a list of segments (streamGenerateContent) and a single
        object (generateContent). Only the first candidate of each segment is
        inspected. Malformed segments or parts are skipped.
        """
        segments = resp_data if isinstance(resp_data, list) else [resp_data]
        fragments = []
        for segment in segments:
            if not isinstance(segment, dict):
                continue
            candidates = segment.get("candidates") or []
            if not candidates or not isinstance(candidates[0], dict):
                continue
            content = candidates[0].get("content") or {}
            for part in content.get("parts") or []:
                if not isinstance(part, dict):
                    continue
                data = (part.get("inlineData") or {}).get("data")
                if data:
                    fragments.append(base64.b64decode(data))
        return fragments

    async def generate_speech(self, text: str) -> bytes:
        """Synthesize ``text`` and return the concatenated audio bytes.

        Args:
            text: The text to convert to speech. Must not be empty or
                whitespace-only.

        Returns:
            The base64-decoded ``inlineData`` chunks of the response, joined
            in order. No container header is added here.

        Raises:
            HTTPException: 400 for empty text; the upstream status code when
                the API answers with a non-200 status; 504 on timeout; 500
                when the response holds no audio or any other error occurs.
        """
        if not text or not text.strip():
            raise HTTPException(status_code=400, detail="Text for speech synthesis must not be empty.")

        logger.debug("TTS generate_speech: '%s...'", text[:60])

        headers = {"Content-Type": "application/json"}
        # Authenticate via header on both endpoints; previously the Vertex
        # route was sent without any credentials at all.
        if self.api_key:
            headers["x-goog-api-key"] = self.api_key

        # The dedicated TTS models require a system instruction to produce only audio
        json_data = {
            "system_instruction": {
                "parts": [{"text": "You are a text-to-speech system. Convert the user text to speech audio only. Do not generate any text response."}]
            },
            "contents": [{"role": "user", "parts": [{"text": text}]}],
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {
                            "voiceName": self.voice_name
                        }
                    }
                }
            }
        }

        # Safe to log in full: the URL no longer carries the API key.
        logger.debug("Calling: %s", self.api_url)
        timeout = httpx.Timeout(self._REQUEST_TIMEOUT_SECONDS)
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(self.api_url, headers=headers, json=json_data)
            logger.debug("Response status: %s", response.status_code)

            if response.status_code != 200:
                logger.error("TTS API error %s: %s", response.status_code, response.text[:300])
                msg = self._extract_error_message(response)
                raise HTTPException(status_code=response.status_code, detail=f"Gemini TTS error: {msg}")

            resp_data = response.json()
            audio_fragments = self._extract_audio(resp_data)
            if not audio_fragments:
                logger.error("No audio in response. Full response: %s", json.dumps(resp_data)[:500])
                raise HTTPException(status_code=500, detail="No audio data in Gemini TTS response.")

            result = b"".join(audio_fragments)
            logger.debug("TTS returned %d PCM bytes", len(result))
            return result
        except HTTPException:
            # Already shaped for the client; pass through untouched.
            raise
        except httpx.TimeoutException as e:
            logger.error("Gemini TTS request timed out after %gs", self._REQUEST_TIMEOUT_SECONDS)
            raise HTTPException(status_code=504, detail="Gemini TTS request timed out.") from e
        except Exception as e:
            # Boundary handler: log with traceback, then convert to a 500.
            logger.exception("Unexpected TTS error: %s: %s", type(e).__name__, e)
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") from e