# Newer
# Older
# cortex-hub / ai-hub / app / core / providers / tts / gemini.py
import os
import aiohttp
import asyncio
import base64
import logging
from typing import AsyncGenerator
from app.core.providers.base import TTSProvider

# Module-level logger named after this module; no handlers or levels are set here.
logger = logging.getLogger(__name__)

# Concrete TTSProvider that talks to the Gemini generateContent REST endpoint.
class GeminiTTSProvider(TTSProvider):
    """Text-to-speech provider backed by the Gemini ``generateContent`` API.

    The provider is configured once with an API key, a prebuilt voice name and
    a model name; ``generate_speech`` then sends a single POST request per call
    and returns the base64-decoded audio payload from the response.

    NOTE(review): ``HTTPException`` is raised by ``generate_speech`` but is not
    imported in the visible part of this file. It must be in scope (presumably
    ``fastapi.HTTPException`` -- confirm) or the error paths fail with
    ``NameError``.
    """

    # Voice names accepted by ``__init__``; anything else is rejected up front.
    AVAILABLE_VOICES = [
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
        "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus",
        "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome",
        "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam",
        "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat"
    ]

    # Placeholder written to logs instead of the real API key.
    _REDACTED = "***REDACTED***"

    def __init__(self, api_key: str, voice_name: str = "Kore", model_name: str = "gemini-2.5-flash-preview-tts"):
        """Store the credentials and build the request URL.

        Args:
            api_key: Value sent in the ``x-goog-api-key`` request header.
            voice_name: One of ``AVAILABLE_VOICES``.
            model_name: Model identifier interpolated into the endpoint URL and
                also sent in the request body.

        Raises:
            ValueError: If ``voice_name`` is not in ``AVAILABLE_VOICES``.
        """
        if voice_name not in self.AVAILABLE_VOICES:
            raise ValueError(f"Invalid voice name: {voice_name}. Choose from {self.AVAILABLE_VOICES}")

        self.api_key = api_key
        # The endpoint URL embeds the configurable model name.
        self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
        self.voice_name = voice_name
        self.model_name = model_name
        logger.debug("Initialized GeminiTTSProvider with model: %s, voice: %s", self.model_name, self.voice_name)

    @staticmethod
    def _extract_audio(response_json) -> bytes:
        """Return the decoded audio found at candidates[0].content.parts[0].inlineData.data."""
        inline_data = response_json['candidates'][0]['content']['parts'][0]['inlineData']['data']
        logger.debug("Successfully extracted audio data from JSON response.")
        return base64.b64decode(inline_data)

    async def generate_speech(self, text: str) -> bytes:
        """Synthesize ``text`` and return the decoded audio bytes.

        Args:
            text: The text placed in the single ``contents[0].parts[0].text``
                field of the request.

        Returns:
            The base64-decoded ``inlineData.data`` of the first part of the
            first candidate. The audio format is not inspected here.

        Raises:
            HTTPException: With status 500 when the request fails, the response
                does not have the expected shape, or any other error occurs.
        """
        logger.debug("Starting speech generation for text: '%s...'", text[:50])

        headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json"
        }
        json_data = {
            "contents": [{
                "parts": [{
                    "text": text
                }]
            }],
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {
                            "voiceName": self.voice_name
                        }
                    }
                }
            },
            # The model is configurable via the instance attribute.
            "model": self.model_name
        }

        # Never write the real API key to the log; log a redacted copy instead.
        loggable_headers = {**headers, "x-goog-api-key": self._REDACTED}
        logger.debug("API Request URL: %s", self.api_url)
        logger.debug("Request Headers: %s", loggable_headers)
        logger.debug("Request Payload: %s", json_data)

        # Phase 1: perform the HTTP round trip and parse the JSON body.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.api_url, headers=headers, json=json_data) as response:
                    logger.debug("Received API response with status code: %s", response.status)
                    response.raise_for_status()

                    response_json = await response.json()
                    logger.debug("Successfully parsed API response JSON.")
        except aiohttp.ClientError as e:
            logger.error("Aiohttp client error occurred: %s", e)
            raise HTTPException(status_code=500, detail=f"API request failed: {e}") from e
        except Exception as e:
            logger.exception("An unexpected error occurred during speech generation: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") from e

        # Phase 2: pull the audio out of the parsed body. Kept separate so that
        # ``response_json`` is guaranteed to be bound when it is logged below.
        try:
            audio_bytes = self._extract_audio(response_json)
        except (KeyError, IndexError, TypeError) as e:
            # Missing keys, empty candidate/part lists and null nodes all mean
            # the response did not have the expected structure.
            logger.error("Key error in API response: %r. Full response: %s", e, response_json)
            raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") from e
        except Exception as e:
            logger.exception("An unexpected error occurred during speech generation: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") from e

        logger.debug("Decoded audio data, size: %d bytes.", len(audio_bytes))
        return audio_bytes