Newer
Older
cortex-hub / ai-hub / app / core / providers / stt / gemini.py
import os
import aiohttp
import asyncio
import logging
import mimetypes
from typing import Optional
from fastapi import HTTPException
from app.core.providers.base import STTProvider

# Configure logging
logger = logging.getLogger(__name__)

class GoogleSTTProvider(STTProvider):
    """Concrete STT provider for Google Gemini API.

    Transcription is a three-step exchange with the Gemini REST API:
    the audio is uploaded through the Files API (resumable upload protocol),
    ``generateContent`` is asked to transcribe the uploaded file, and the
    uploaded file is then deleted on a best-effort basis.
    """

    # Base URL of the Files API resource collection (used to reference and
    # delete uploaded files).
    _FILES_URL_BASE = "https://generativelanguage.googleapis.com/v1beta/files"

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gemini-2.5-flash"
    ):
        """Create a provider bound to one Gemini model.

        Args:
            api_key: Gemini API key. Falls back to the ``GEMINI_API_KEY``
                environment variable when not given.
            model_name: Gemini model used for ``generateContent``.

        Raises:
            ValueError: if no API key is supplied and ``GEMINI_API_KEY`` is unset/empty.
        """
        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable not set or provided.")

        self.model_name = model_name
        self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
        self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files"

        logger.debug("Initialized GoogleSTTProvider with model: %s", self.model_name)

    async def transcribe_audio(self, audio_data: bytes, mime_type: Optional[str] = None) -> str:
        """Transcribe ``audio_data`` with Gemini and return the transcript.

        Args:
            audio_data: Raw audio bytes.
            mime_type: MIME type of ``audio_data``. When omitted, the type that
                ``mimetypes`` guesses for a ``.wav`` file is used (falling back
                to ``application/octet-stream``), matching the previous behaviour.

        Returns:
            The transcript text returned by the model.

        Raises:
            HTTPException: status 500 for any upload, API, or response-parsing failure.
        """
        logger.debug("Starting transcription process.")

        if mime_type is None:
            mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream"
        num_bytes = len(audio_data)
        logger.debug("Detected MIME type: %s, size: %d bytes.", mime_type, num_bytes)

        try:
            async with aiohttp.ClientSession() as session:
                upload_url = await self._start_upload(session, num_bytes, mime_type)
                file_id = await self._upload_bytes(session, upload_url, audio_data, num_bytes)
                try:
                    data = await self._request_transcription(session, file_id, mime_type)
                finally:
                    # Don't leave the user's audio stored in the Files API;
                    # this never raises, so it cannot mask the original error.
                    await self._delete_file(session, file_id)

            return self._extract_transcript(data)

        except HTTPException:
            # Already carries a specific status/detail; the generic handler
            # below would otherwise re-wrap it and mangle the message.
            raise
        except aiohttp.ClientError as e:
            logger.error("Aiohttp client error occurred: %s", e)
            raise HTTPException(status_code=500, detail=f"API request failed: {e}") from e
        except Exception as e:
            logger.exception("Unexpected error occurred during transcription: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") from e

    @staticmethod
    async def _raise_for_status(resp: aiohttp.ClientResponse, step: str) -> None:
        """Log the response body of a failed request, then raise ``ClientResponseError``."""
        if resp.status >= 400:
            # raise_for_status() drops the body, which holds Google's error explanation.
            body = (await resp.read()).decode("utf-8", errors="replace")
            logger.error("%s failed with status %s: %s", step, resp.status, body)
        resp.raise_for_status()

    async def _start_upload(
        self, session: aiohttp.ClientSession, num_bytes: int, mime_type: str
    ) -> str:
        """Open a resumable upload session and return the URL to send the bytes to."""
        logger.debug("Starting resumable upload...")
        start_headers = {
            "x-goog-api-key": self.api_key,
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": str(num_bytes),
            "X-Goog-Upload-Header-Content-Type": mime_type,
            "Content-Type": "application/json",
        }
        start_payload = {"file": {"display_name": "AUDIO"}}

        async with session.post(
            self.upload_url_base,
            headers=start_headers,
            json=start_payload
        ) as resp:
            logger.debug("Upload start response status: %s", resp.status)
            await self._raise_for_status(resp, "Upload start")
            upload_url = resp.headers.get("X-Goog-Upload-URL")

        if not upload_url:
            raise HTTPException(status_code=500, detail="No upload URL returned from Google API.")
        logger.debug("Received upload URL: %s", upload_url)
        return upload_url

    async def _upload_bytes(
        self,
        session: aiohttp.ClientSession,
        upload_url: str,
        audio_data: bytes,
        num_bytes: int,
    ) -> str:
        """Send all bytes in one finalizing chunk; return the uploaded file's id."""
        logger.debug("Uploading audio file...")
        upload_headers = {
            "Content-Length": str(num_bytes),
            "X-Goog-Upload-Offset": "0",
            "X-Goog-Upload-Command": "upload, finalize",
        }
        async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp:
            logger.debug("File upload response status: %s", resp.status)
            await self._raise_for_status(resp, "File upload")
            file_info = await resp.json()

        # The resource name has the form "files/<id>"; keep only the id.
        return file_info["file"]["name"].split("/")[-1]

    async def _request_transcription(
        self, session: aiohttp.ClientSession, file_id: str, mime_type: str
    ) -> dict:
        """Ask the model to transcribe the uploaded file; return the decoded JSON response."""
        file_uri = f"{self._FILES_URL_BASE}/{file_id}"
        logger.debug("Uploaded file URI: %s", file_uri)

        logger.debug("Requesting transcription from Gemini API...")
        transcription_headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json",
        }
        transcription_payload = {
            "contents": [
                {
                    "parts": [
                        {
                            "fileData": {
                                "mimeType": mime_type,
                                "fileUri": file_uri
                            }
                        },
                        {"text": "Transcribe this audio file."}
                    ]
                }
            ]
        }

        async with session.post(
            self.api_url,
            headers=transcription_headers,
            json=transcription_payload
        ) as resp:
            logger.debug("Transcription request status: %s", resp.status)
            await self._raise_for_status(resp, "Transcription request")
            return await resp.json()

    async def _delete_file(self, session: aiohttp.ClientSession, file_id: str) -> None:
        """Best-effort deletion of an uploaded file; failures are logged, never raised."""
        url = f"{self._FILES_URL_BASE}/{file_id}"
        try:
            async with session.delete(url, headers={"x-goog-api-key": self.api_key}) as resp:
                if resp.status >= 400:
                    logger.warning(
                        "Failed to delete uploaded file %s: status %s", file_id, resp.status
                    )
                else:
                    logger.debug("Deleted uploaded file %s", file_id)
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.warning("Failed to delete uploaded file %s: %s", file_id, e)

    @staticmethod
    def _extract_transcript(data: dict) -> str:
        """Concatenate the text parts of the first candidate in a generateContent response.

        Raises:
            HTTPException: status 500 if the response lacks the expected structure
                or contains no text part.
        """
        try:
            parts = data["candidates"][0]["content"]["parts"]
            # Thought summaries (if ever returned) are not transcript text.
            texts = [
                part["text"]
                for part in parts
                if "text" in part and not part.get("thought")
            ]
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            logger.error("Malformed API response: %s. Full response: %s", e, data)
            raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") from e

        if not texts:
            logger.error("Malformed API response: no text parts. Full response: %s", data)
            raise HTTPException(status_code=500, detail="Malformed API response from Gemini.")

        transcript = "".join(texts)
        logger.debug("Successfully extracted transcript: '%s...'", transcript[:50])
        return transcript