Newer
Older
cortex-hub / ai-hub / app / core / providers / tts / gemini.py
import os
import json
import httpx
import base64
import logging
from app.core.providers.base import TTSProvider
from fastapi import HTTPException


# Module-level logger, named after this module via __name__ so its output
# can be filtered and configured independently by the application's logging setup.
logger = logging.getLogger(__name__)


class GeminiTTSProvider(TTSProvider):
    """TTS provider using Gemini's audio responseModalities.

    Depending on the shape of the API key, requests are sent either to the
    Google AI Studio endpoint (``generativelanguage.googleapis.com``) or to the
    Vertex AI endpoint (``us-central1-aiplatform.googleapis.com``).

    The API key is transmitted only through the ``x-goog-api-key`` request
    header. It is deliberately kept out of ``api_url`` so that logging the
    endpoint, or surfacing a transport error message, can never leak it.
    """

    AVAILABLE_VOICES = [
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
        "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus",
        "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome",
        "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam",
        "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat"
    ]

    # Model used when none is supplied, and the target of the short aliases below.
    DEFAULT_MODEL = "gemini-2.5-flash-preview-tts"
    # Short/legacy model names that are normalised to DEFAULT_MODEL.
    _MODEL_ALIASES = ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts")
    # Overall HTTP timeout (seconds) applied to each synthesis request.
    REQUEST_TIMEOUT_SECONDS = 30.0
    # System instruction sent with every request so the model returns audio only.
    _SYSTEM_PROMPT = (
        "You are a text-to-speech system. Convert the user text to speech audio only. "
        "Do not generate any text response."
    )

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash-preview-tts",
                 voice_name: str = "Kore", **kwargs):
        """Resolve the model id and endpoint for this provider instance.

        Args:
            api_key: Google API key. Keys starting with ``"AQ."`` are routed to
                the Vertex AI endpoint; all others go to Google AI Studio.
            model_name: Model identifier, optionally carrying a provider prefix
                such as ``"gemini/"`` or ``"vertex_ai/"`` which is stripped.
                A falsy value selects ``DEFAULT_MODEL``.
            voice_name: Prebuilt voice name sent in the speech config.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        raw_model = model_name or self.DEFAULT_MODEL
        # Strip any provider prefix (e.g. "vertex_ai/model" or "gemini/model") → keep only the model id
        model_id = raw_model.split("/")[-1]
        # Normalise short names: "gemini-2-flash-tts" → "gemini-2.5-flash-preview-tts"
        if model_id in self._MODEL_ALIASES:
            model_id = self.DEFAULT_MODEL
            logger.info("Normalised model name to: %s", model_id)

        # Route to Vertex AI ONLY when the key starts with "AQ.".
        # AI Studio keys start with "AIza" and must use the generativelanguage endpoint.
        self.is_vertex = bool(api_key) and api_key.startswith("AQ.")

        if self.is_vertex:
            self.api_url = (
                f"https://us-central1-aiplatform.googleapis.com/v1/publishers/google/"
                f"models/{model_id}:streamGenerateContent"
            )
        else:
            # Google AI Studio — v1beta is required for audio responseModalities.
            # No "?key=" query parameter: the key travels in a header instead.
            self.api_url = (
                f"https://generativelanguage.googleapis.com/v1beta/models/"
                f"{model_id}:generateContent"
            )

        self.api_key = api_key
        self.voice_name = voice_name
        self.model_name = model_id
        logger.debug("GeminiTTSProvider: model=%s, vertex=%s", self.model_name, self.is_vertex)
        logger.debug("  endpoint: %s", self.api_url)

    def _build_payload(self, text: str) -> dict:
        """Return the JSON request body asking for audio-only output of *text*."""
        return {
            # The dedicated TTS models require a system instruction to produce only audio
            "system_instruction": {"parts": [{"text": self._SYSTEM_PROMPT}]},
            "contents": [{"role": "user", "parts": [{"text": text}]}],
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {
                            "voiceName": self.voice_name
                        }
                    }
                }
            }
        }

    @staticmethod
    def _error_message(response) -> str:
        """Extract a human-readable error message from a non-200 response.

        Handles both a single error object and the list-shaped payload that the
        streaming endpoint returns; falls back to the first 200 characters of
        the raw body when no structured message is available.
        """
        fallback = response.text[:200]
        try:
            payload = response.json()
        except ValueError:
            # Body was not valid JSON (e.g. an HTML error page).
            return fallback
        if isinstance(payload, list):
            payload = payload[0] if payload else {}
        if not isinstance(payload, dict):
            return fallback
        err = payload.get("error")
        if isinstance(err, dict):
            return err.get("message") or fallback
        return fallback

    @staticmethod
    def _extract_audio(resp_data) -> bytes:
        """Concatenate every base64 ``inlineData`` chunk found in *resp_data*.

        Accepts both a list of segments (streamGenerateContent) and a single
        object (generateContent). Only the first candidate of each segment is
        read. Returns ``b""`` when the response carries no audio.
        """
        segments = resp_data if isinstance(resp_data, list) else [resp_data]
        fragments = []
        for segment in segments:
            if not isinstance(segment, dict):
                continue
            candidates = segment.get("candidates") or []
            if not candidates:
                continue
            # `content` / `inlineData` may be absent or null; treat both as empty.
            content = candidates[0].get("content") or {}
            for part in content.get("parts") or []:
                encoded = (part.get("inlineData") or {}).get("data")
                if encoded:
                    fragments.append(base64.b64decode(encoded))
        return b"".join(fragments)

    async def generate_speech(self, text: str) -> bytes:
        """Synthesize *text* and return the decoded audio bytes.

        Args:
            text: The text to convert to speech.

        Returns:
            The concatenated, base64-decoded audio payload from the response.

        Raises:
            HTTPException: With the upstream status code when the API returns a
                non-200 response, 504 on timeout, and 500 when the response
                contains no audio or any other unexpected failure occurs.
        """
        logger.debug("TTS generate_speech: '%s...'", text[:60])

        headers = {"Content-Type": "application/json"}
        # Send the key on both routes; previously the Vertex route sent no
        # credential at all, and the AI Studio route also put it in the URL.
        if self.api_key:
            headers["x-goog-api-key"] = self.api_key

        json_data = self._build_payload(text)

        # Safe to log: the URL no longer carries the API key.
        logger.debug("Calling: %s", self.api_url)

        try:
            timeout = httpx.Timeout(self.REQUEST_TIMEOUT_SECONDS)
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(self.api_url, headers=headers, json=json_data)

            logger.debug("Response status: %s", response.status_code)

            if response.status_code != 200:
                logger.error("TTS API error %s: %s", response.status_code, response.text[:300])
                msg = self._error_message(response)
                raise HTTPException(status_code=response.status_code, detail=f"Gemini TTS error: {msg}")

            resp_data = response.json()
            result = self._extract_audio(resp_data)

            if not result:
                logger.error("No audio in response. Full response: %s", json.dumps(resp_data)[:500])
                raise HTTPException(status_code=500, detail="No audio data in Gemini TTS response.")

            logger.debug("TTS returned %d PCM bytes", len(result))
            return result

        except HTTPException:
            # Already shaped for the caller; pass through untouched.
            raise
        except httpx.TimeoutException as e:
            logger.error("Gemini TTS request timed out after %gs", self.REQUEST_TIMEOUT_SECONDS)
            raise HTTPException(status_code=504, detail="Gemini TTS request timed out.") from e
        except Exception as e:
            logger.exception("Unexpected TTS error: %s: %s", type(e).__name__, e)
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") from e