# File: cortex-hub/ai-hub/app/core/providers/tts/gemini.py
import os
import aiohttp
import asyncio
import base64
from typing import AsyncGenerator
from app.core.providers.base import TTSProvider

# Concrete TTSProvider implementation backed by the Gemini generateContent REST API.
class GeminiTTSProvider(TTSProvider):
    """Text-to-speech provider that calls the Gemini ``generateContent`` REST endpoint.

    Each :meth:`generate_speech` call sends a single request asking for an
    ``AUDIO`` response with one of the prebuilt voices, decodes the base64
    ``inlineData`` payload, and yields it as one ``bytes`` chunk.
    """

    # Prebuilt voice names accepted by ``__init__``; any other name is rejected.
    AVAILABLE_VOICES = [
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
        "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus",
        "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome",
        "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam",
        "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat"
    ]

    def __init__(self, api_key: str, voice_name: str = "Kore", model_name: str = "gemini-2.5-flash-preview-tts"):
        """Configure the provider.

        Args:
            api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
            voice_name: Prebuilt voice to synthesize with; must be one of
                ``AVAILABLE_VOICES``.
            model_name: Model id, used both in the request URL and in the
                request body.

        Raises:
            ValueError: If ``voice_name`` is not in ``AVAILABLE_VOICES``.
        """
        if voice_name not in self.AVAILABLE_VOICES:
            raise ValueError(f"Invalid voice name: {voice_name}. Choose from {self.AVAILABLE_VOICES}")

        self.api_key = api_key
        # The model is part of the endpoint path, so the URL is built per instance.
        self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
        self.voice_name = voice_name
        self.model_name = model_name

    async def generate_speech(self, text: str) -> AsyncGenerator[bytes, None]:
        """Synthesize ``text`` and yield the resulting audio as a single chunk.

        The yielded bytes are the decoded ``inlineData.data`` of the first
        part of the first candidate; their encoding is whatever the API
        returns (NOTE(review): presumably raw PCM for TTS models — confirm
        against the consumer).

        Raises:
            aiohttp.ClientResponseError: If the API answers with an HTTP error
                status; the message includes the response body.
            ValueError: If the response does not contain inline audio data.
        """
        headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json"
        }
        json_data = {
            "contents": [{
                "parts": [{
                    "text": text
                }]
            }],
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {
                            "voiceName": self.voice_name
                        }
                    }
                }
            },
            "model": self.model_name
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(self.api_url, headers=headers, json=json_data) as response:
                if response.status >= 400:
                    # raise_for_status() would discard the body, which is where the
                    # API explains the failure; keep the same exception type but
                    # carry the body in the message.
                    error_body = await response.text(errors="replace")
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=response.status,
                        message=f"{response.reason}: {error_body}",
                        headers=response.headers,
                    )
                response_json = await response.json()

        # The payload is fully read at this point, so yield after the session is
        # closed instead of holding it open while the consumer processes audio.
        yield self._extract_audio(response_json)

    @staticmethod
    def _extract_audio(response_json: dict) -> bytes:
        """Return the base64-decoded inline audio from a ``generateContent`` response.

        Raises:
            ValueError: If the response has no candidate/part carrying
                ``inlineData.data`` (the raw response is included in the message).
        """
        try:
            inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
        except (KeyError, IndexError, TypeError) as exc:
            raise ValueError(
                f"Gemini TTS response did not contain inline audio data: {response_json!r}"
            ) from exc
        return base64.b64decode(inline_data)