# cortex-hub / ai-hub / app / core / providers / stt / gemini.py
import asyncio
import base64
import logging
import os
from typing import Optional

import aiohttp
from fastapi import HTTPException

from app.core.providers.base import STTProvider

# Module-level logger named after this module (``__name__``); all provider
# diagnostics below are emitted through it.
logger = logging.getLogger(__name__)


class GoogleSTTProvider(STTProvider):
    """Concrete STT provider for the Google Gemini API using inline audio data.

    The audio bytes are base64-encoded and sent in a single ``generateContent``
    request together with a transcription instruction; no Files API upload is
    involved.
    """

    # Leading-byte signatures checked in order by ``_detect_mime``; first match wins.
    _PREFIX_SIGNATURES = (
        (b'RIFF', 'audio/wav'),
        (b'\x1aE\xdf\xa3', 'audio/webm'),  # EBML header (WebM/Matroska)
        (b'ID3', 'audio/mpeg'),  # MP3 with a leading ID3v2 tag
        # MPEG-1 / MPEG-2 / MPEG-2.5 Layer III frame sync, without and with CRC.
        (b'\xff\xfb', 'audio/mpeg'),
        (b'\xff\xfa', 'audio/mpeg'),
        (b'\xff\xf3', 'audio/mpeg'),
        (b'\xff\xf2', 'audio/mpeg'),
        (b'\xff\xe3', 'audio/mpeg'),
        (b'\xff\xe2', 'audio/mpeg'),
        # AAC ADTS frame sync (MPEG-4 / MPEG-2 id bit, without and with CRC).
        (b'\xff\xf1', 'audio/aac'),
        (b'\xff\xf0', 'audio/aac'),
        (b'\xff\xf9', 'audio/aac'),
        (b'\xff\xf8', 'audio/aac'),
        (b'OggS', 'audio/ogg'),
        (b'fLaC', 'audio/flac'),
    )

    _TRANSCRIBE_PROMPT = "Transcribe this audio. Return only the spoken words, nothing else."

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = 'gemini-1.5-flash',
        *,
        request_timeout: float = 30.0,
        **kwargs,
    ):
        """Configure the provider.

        Args:
            api_key: Gemini API key; falls back to the ``GEMINI_API_KEY`` env var.
            model_name: Model id. Anything up to the last ``/`` is dropped, and a
                falsy value falls back to ``'gemini-1.5-flash'``.
            request_timeout: Total timeout in seconds for each API request.
            **kwargs: Accepted and ignored (presumably so a shared provider
                factory can pass extra options — confirm against callers).

        Raises:
            ValueError: If no API key is given or found in the environment.
        """
        self.api_key = api_key or os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError('GEMINI_API_KEY environment variable not set or provided.')

        self.model_name = (model_name or 'gemini-1.5-flash').split('/')[-1]
        self._timeout = aiohttp.ClientTimeout(total=request_timeout)

        # Use v1beta — the only endpoint that supports audio inline_data with Gemini 2.x.
        # The API key is deliberately NOT put in the query string: it is sent in the
        # x-goog-api-key header so the URL can be logged safely.
        self.api_url = (
            'https://generativelanguage.googleapis.com/v1beta/models/'
            f'{self.model_name}:generateContent'
        )

        logger.debug("Initialized GoogleSTTProvider: model=%s", self.model_name)

    def _detect_mime(self, data: bytes) -> str:
        """Sniff the leading bytes of *data* and return the matching audio MIME type.

        Unrecognized signatures fall back to ``audio/webm``, which this provider
        assumes is what browser recordings produce.
        """
        head = bytes(data[:12])
        for prefix, mime in self._PREFIX_SIGNATURES:
            if head.startswith(prefix):
                return mime
        # ISO BMFF (MP4/M4A) stores the 'ftyp' box type at byte offset 4.
        if head[4:8] == b'ftyp':
            return 'audio/mp4'
        # IFF container: only accept it when the form type is AIFF/AIFC.
        if head.startswith(b'FORM') and head[8:12] in (b'AIFF', b'AIFC'):
            return 'audio/aiff'
        # Default: browsers record as webm
        return 'audio/webm'

    async def transcribe_audio(self, audio_data: bytes) -> str:
        """Transcribe *audio_data* with Gemini using inline (base64) audio.

        Args:
            audio_data: Raw audio bytes; the container format is sniffed from
                the leading bytes by ``_detect_mime``.

        Returns:
            The transcript with surrounding whitespace stripped, or ``""`` when
            the audio is empty or Gemini returns no content parts.

        Raises:
            HTTPException: status 500 on a non-2xx API response, network error,
                timeout, malformed response body, or any other unexpected failure.
        """
        logger.debug("Starting transcription process.")
        if not audio_data:
            # Nothing to transcribe; skip the API round trip entirely.
            logger.debug("Empty audio payload; returning empty transcript.")
            return ""

        try:
            mime_type = self._detect_mime(audio_data)
            logger.debug("Detected MIME type: %s, size: %d bytes.", mime_type, len(audio_data))
            payload = self._build_payload(audio_data, mime_type)
            data = await self._post_generate_content(payload)
            return self._extract_transcript(data)
        except HTTPException:
            # Already mapped to a client-facing error by a helper.
            raise
        except Exception as e:
            logger.exception("Unexpected STT error: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") from e

    @classmethod
    def _build_payload(cls, audio_data: bytes, mime_type: str) -> dict:
        """Build the generateContent request body with the audio inlined as base64."""
        return {
            "contents": [
                {
                    "role": "user",
                    "parts": [
                        {
                            "inline_data": {
                                "mime_type": mime_type,
                                "data": base64.b64encode(audio_data).decode('ascii'),
                            }
                        },
                        {"text": cls._TRANSCRIBE_PROMPT},
                    ],
                }
            ]
        }

    async def _post_generate_content(self, payload: dict) -> dict:
        """POST *payload* to the generateContent endpoint and return the parsed JSON body."""
        headers = {
            "Content-Type": "application/json",
            # Header auth keeps the key out of URLs, logs and URL-derived error text.
            "x-goog-api-key": self.api_key,
        }
        logger.debug("Sending inline audio to: %s", self.api_url)

        try:
            async with aiohttp.ClientSession(timeout=self._timeout) as session:
                async with session.post(self.api_url, headers=headers, json=payload) as response:
                    logger.debug("Transcription response status: %s", response.status)
                    if not response.ok:
                        body = await response.text()
                        logger.error("STT API error %s: %s", response.status, body)
                        raise HTTPException(
                            status_code=500,
                            detail=f"API failed ({response.status}): {body[:300]}",
                        )
                    return await response.json()
        except asyncio.TimeoutError as e:
            # Checked before ClientError: aiohttp's ServerTimeoutError derives from both.
            logger.error("STT request timed out after %s s.", self._timeout.total)
            raise HTTPException(status_code=500, detail="API request timed out.") from e
        except aiohttp.ClientError as e:
            logger.error("Network error during STT: %s", e)
            raise HTTPException(status_code=500, detail=f"API request failed: {e}") from e

    @staticmethod
    def _extract_transcript(data: dict) -> str:
        """Return the stripped text of the first candidate, or "" if it has no parts."""
        try:
            candidate = data["candidates"][0]
            parts = (candidate.get("content") or {}).get("parts") or []
            # Join every text part (not just the first), skipping parts flagged as
            # model "thought" output; a null "text" contributes nothing.
            transcript = "".join(
                part.get("text") or "" for part in parts if not part.get("thought")
            )
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            logger.error("Malformed API response: %s. Full: %s", e, data)
            raise HTTPException(
                status_code=500, detail="Malformed API response from Gemini STT."
            ) from e

        if not parts:
            # Gemini returns no parts for silent/empty audio - that's fine
            logger.debug("Gemini returned no transcript parts (likely silence).")
            return ""

        logger.debug("Transcript: '%s'", transcript[:80])
        return transcript.strip()