import asyncio
import base64
import logging
import os
from typing import Optional

import aiohttp
from fastapi import HTTPException

from app.core.providers.base import STTProvider
# Configure logging
logger = logging.getLogger(__name__)
class GoogleSTTProvider(STTProvider):
    """Concrete STT provider for the Google Gemini API using inline audio data.

    Audio bytes are base64-encoded and submitted directly in the request body
    (``inline_data``), so no separate upload through the Files API is needed.
    """

    # MPEG audio frame-sync prefixes: MPEG-1/2, layer III, with/without CRC.
    # The original code only matched b'\xff\xfb' and missed the other variants.
    _MP3_SYNCS = (b'\xff\xfb', b'\xff\xfa', b'\xff\xf3', b'\xff\xf2')

    def __init__(self, api_key: Optional[str] = None, model_name: str = 'gemini-1.5-flash', **kwargs):
        """Initialize the provider.

        Args:
            api_key: Gemini API key. Falls back to the ``GEMINI_API_KEY``
                environment variable when not supplied.
            model_name: Gemini model id. A fully-qualified name such as
                ``models/gemini-1.5-flash`` is normalized to its last segment.
            **kwargs: Ignored; accepted for provider-interface compatibility.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError('GEMINI_API_KEY environment variable not set or provided.')
        clean_model = model_name or 'gemini-1.5-flash'
        model_id = clean_model.split('/')[-1]
        self.model_name = model_id
        # Use v1beta — the only endpoint that supports audio inline_data with Gemini 2.x.
        # NOTE(review): the key rides in the URL query string, where it can leak into
        # proxy/server logs; the x-goog-api-key header would be safer. Left as-is since
        # self.api_url is visible instance state, but it is now redacted when logged.
        self.api_url = (
            f'https://generativelanguage.googleapis.com/v1beta/models/'
            f'{model_id}:generateContent?key={self.api_key}'
        )
        # Lazy %-args: formatting happens only if DEBUG logging is enabled.
        logger.debug("Initialized GoogleSTTProvider: model=%s", self.model_name)

    def _detect_mime(self, data: bytes) -> str:
        """Sniff the audio byte signature to determine the real MIME type.

        Args:
            data: Raw audio bytes. May be shorter than any signature (slicing
                a too-short ``bytes`` never raises).

        Returns:
            A best-effort MIME type string. Defaults to ``audio/webm`` because
            browser ``MediaRecorder`` output is the expected common case.
        """
        if data[:4] == b'RIFF':
            return 'audio/wav'
        # Full EBML magic (1A 45 DF A3), shared by WebM and Matroska containers.
        if data[:4] == b'\x1aE\xdf\xa3':
            return 'audio/webm'
        # ID3 tag header, or a bare MPEG frame-sync (all common layer-III variants).
        if data[:3] == b'ID3' or data.startswith(self._MP3_SYNCS):
            return 'audio/mpeg'
        if data[:4] == b'OggS':
            return 'audio/ogg'
        # ISO BMFF (MP4/M4A): 4-byte box size followed by the 'ftyp' tag.
        # ">= 8" (not "> 8"): data[4:8] needs exactly 8 bytes to be meaningful.
        if len(data) >= 8 and data[4:8] == b'ftyp':
            return 'audio/mp4'
        # Truncated EBML prefix — still almost certainly WebM.
        if data[:2] == b'\x1a\x45':
            return 'audio/webm'
        # Default: browsers record as webm.
        return 'audio/webm'

    async def transcribe_audio(self, audio_data: bytes) -> str:
        """Transcribe audio using Gemini's inline_data approach (no Files API needed).

        Args:
            audio_data: Raw audio bytes in any format ``_detect_mime`` recognizes.

        Returns:
            The stripped transcript text, or ``""`` when Gemini returns no
            content parts (e.g. silent or empty audio).

        Raises:
            HTTPException: status 500 on API errors, timeouts, network
                failures, or a malformed response body.
        """
        logger.debug("Starting transcription process.")
        mime_type = self._detect_mime(audio_data)
        logger.debug("Detected MIME type: %s, size: %d bytes.", mime_type, len(audio_data))
        # Encode audio as base64 for inline submission.
        audio_b64 = base64.b64encode(audio_data).decode('utf-8')
        payload = {
            "contents": [
                {
                    "role": "user",
                    "parts": [
                        {
                            "inline_data": {
                                "mime_type": mime_type,
                                "data": audio_b64
                            }
                        },
                        {"text": "Transcribe this audio. Return only the spoken words, nothing else."}
                    ]
                }
            ]
        }
        headers = {"Content-Type": "application/json"}
        # Log the endpoint without the query string so the API key never lands in logs.
        logger.debug("Sending inline audio to: %s", self.api_url.split('?', 1)[0])
        try:
            timeout = aiohttp.ClientTimeout(total=30)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(self.api_url, headers=headers, json=payload) as response:
                    logger.debug("Transcription response status: %s", response.status)
                    if not response.ok:
                        body = await response.text()
                        logger.error("STT API error %s: %s", response.status, body)
                        raise HTTPException(
                            status_code=500,
                            detail=f"API failed ({response.status}): {body[:300]}"
                        )
                    data = await response.json()
                    try:
                        candidate = data["candidates"][0]
                        parts = candidate.get("content", {}).get("parts", [])
                        if not parts:
                            # Gemini returns no parts for silent/empty audio - that's fine
                            logger.debug("Gemini returned no transcript parts (likely silence).")
                            return ""
                        # Join every text part: long transcripts may be split
                        # across several parts, not just parts[0].
                        transcript = "".join(p.get("text", "") for p in parts)
                        logger.debug("Transcript: '%s'", transcript[:80])
                        return transcript.strip()
                    except (KeyError, IndexError) as e:
                        logger.error("Malformed API response: %s. Full: %s", e, data)
                        raise HTTPException(status_code=500, detail="Malformed API response from Gemini STT.")
        except asyncio.TimeoutError:
            # ClientTimeout raises asyncio.TimeoutError, which is not an
            # aiohttp.ClientError — surface it with a clear message instead of
            # letting the generic handler report it.
            logger.error("STT request timed out after 30s.")
            raise HTTPException(status_code=500, detail="Failed to transcribe audio: request timed out.")
        except aiohttp.ClientError as e:
            logger.error("Network error during STT: %s", e)
            raise HTTPException(status_code=500, detail=f"API request failed: {e}")
        except HTTPException:
            # Re-raise our own HTTP errors untouched.
            raise
        except Exception as e:
            # Last-resort boundary handler: convert anything unexpected into a 500.
            logger.error("Unexpected STT error: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}")