diff --git a/ai-hub/app/core/providers/stt/gemini.py b/ai-hub/app/core/providers/stt/gemini.py new file mode 100644 index 0000000..0bc57d6 --- /dev/null +++ b/ai-hub/app/core/providers/stt/gemini.py @@ -0,0 +1,125 @@ +import os +import aiohttp +import asyncio +import logging +import mimetypes +from typing import Optional +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class GoogleSTTProvider(STTProvider): + """Concrete STT provider for Google Gemini API.""" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gemini-2.5-flash" + ): + self.api_key = api_key or os.getenv("GEMINI_API_KEY") + if not self.api_key: + raise ValueError("GEMINI_API_KEY environment variable not set or provided.") + + self.model_name = model_name + self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files" + + logger.debug(f"Initialized GoogleSTTProvider with model: {self.model_name}") + + async def transcribe_audio(self, audio_data: bytes) -> str: + logger.debug("Starting transcription process.") + + mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream" + num_bytes = len(audio_data) + logger.debug(f"Detected MIME type: {mime_type}, size: {num_bytes} bytes.") + + try: + async with aiohttp.ClientSession() as session: + # Step 1: Start resumable upload + logger.debug("Starting resumable upload...") + start_headers = { + "x-goog-api-key": self.api_key, + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Header-Content-Length": str(num_bytes), + "X-Goog-Upload-Header-Content-Type": mime_type, + "Content-Type": "application/json", + } + start_payload = {"file": {"display_name": "AUDIO"}} + + async with session.post( + self.upload_url_base, + headers=start_headers, + json=start_payload + ) as resp: + logger.debug(f"Upload start response status: {resp.status}") + resp.raise_for_status() + upload_url = resp.headers.get("X-Goog-Upload-URL") + if not upload_url: + raise HTTPException(status_code=500, detail="No upload URL returned from Google API.") + logger.debug(f"Received upload URL: {upload_url}") + + # Step 2: Upload the file + logger.debug("Uploading audio file...") + upload_headers = { + "Content-Length": str(num_bytes), + "X-Goog-Upload-Offset": "0", + "X-Goog-Upload-Command": "upload, finalize", + } + async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp: + logger.debug(f"File upload response status: {resp.status}") + resp.raise_for_status() + file_info = await resp.json() + + file_name = file_info["file"]["name"].split("/")[-1] + file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}" + logger.debug(f"Uploaded file URI: {file_uri}") + + # Step 3: Request transcription + logger.debug("Requesting transcription from Gemini API...") + transcription_headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + transcription_payload = { + "contents": [ + { + "parts": [ + { + "fileData": { + "mimeType": mime_type, + "fileUri": file_uri + } + }, + {"text": "Transcribe this audio file."} + ] + } + ] + } + + async with session.post( + self.api_url, + headers=transcription_headers, + json=transcription_payload + ) as resp: + logger.debug(f"Transcription request status: {resp.status}") + resp.raise_for_status() + data = await resp.json() + + # Step 4: Extract text + try: + transcript = data["candidates"][0]["content"]["parts"][0]["text"] + logger.debug(f"Successfully extracted transcript: '{transcript[:50]}...'") + return transcript + except (KeyError, IndexError) as e: + logger.error(f"Malformed API response: {e}. Full response: {data}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + + except aiohttp.ClientError as e: + logger.error(f"Aiohttp client error occurred: {e}") + raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error occurred during transcription: {e}") + raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") diff --git a/ai-hub/app/core/providers/stt/gemini.py b/ai-hub/app/core/providers/stt/gemini.py new file mode 100644 index 0000000..0bc57d6 --- /dev/null +++ b/ai-hub/app/core/providers/stt/gemini.py @@ -0,0 +1,125 @@ +import os +import aiohttp +import asyncio +import logging +import mimetypes +from typing import Optional +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class GoogleSTTProvider(STTProvider): + """Concrete STT provider for Google Gemini API.""" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gemini-2.5-flash" + ): + self.api_key = api_key or os.getenv("GEMINI_API_KEY") + if not self.api_key: + raise ValueError("GEMINI_API_KEY environment variable not set or provided.") + + self.model_name = model_name + self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files" + + logger.debug(f"Initialized GoogleSTTProvider with model: {self.model_name}") + + async def transcribe_audio(self, audio_data: bytes) -> str: + logger.debug("Starting transcription process.") + + mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream" + num_bytes = len(audio_data) + logger.debug(f"Detected MIME type: {mime_type}, size: {num_bytes} bytes.") + + try: + async with aiohttp.ClientSession() as session: + # Step 1: Start resumable upload + logger.debug("Starting resumable upload...") + start_headers = { + "x-goog-api-key": self.api_key, + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Header-Content-Length": str(num_bytes), + "X-Goog-Upload-Header-Content-Type": mime_type, + "Content-Type": "application/json", + } + start_payload = {"file": {"display_name": "AUDIO"}} + + async with session.post( + self.upload_url_base, + headers=start_headers, + json=start_payload + ) as resp: + logger.debug(f"Upload start response status: {resp.status}") + resp.raise_for_status() + upload_url = resp.headers.get("X-Goog-Upload-URL") + if not upload_url: + raise HTTPException(status_code=500, detail="No upload URL returned from Google API.") + logger.debug(f"Received upload URL: {upload_url}") + + # Step 2: Upload the file + logger.debug("Uploading audio file...") + upload_headers = { + "Content-Length": str(num_bytes), + "X-Goog-Upload-Offset": "0", + "X-Goog-Upload-Command": "upload, finalize", + } + async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp: + logger.debug(f"File upload response status: {resp.status}") + resp.raise_for_status() + file_info = await resp.json() + + file_name = file_info["file"]["name"].split("/")[-1] + file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}" + logger.debug(f"Uploaded file URI: {file_uri}") + + # Step 3: Request transcription + logger.debug("Requesting transcription from Gemini API...") + transcription_headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + transcription_payload = { + "contents": [ + { + "parts": [ + { + "fileData": { + "mimeType": mime_type, + "fileUri": file_uri + } + }, + {"text": "Transcribe this audio file."} + ] + } + ] + } + + async with session.post( + self.api_url, + headers=transcription_headers, + json=transcription_payload + ) as resp: + logger.debug(f"Transcription request status: {resp.status}") + resp.raise_for_status() + data = await resp.json() + + # Step 4: Extract text + try: + transcript = data["candidates"][0]["content"]["parts"][0]["text"] + logger.debug(f"Successfully extracted transcript: '{transcript[:50]}...'") + return transcript + except (KeyError, IndexError) as e: + logger.error(f"Malformed API response: {e}. Full response: {data}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + + except aiohttp.ClientError as e: + logger.error(f"Aiohttp client error occurred: {e}") + raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error occurred during transcription: {e}") + raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") diff --git a/ai-hub/app/core/services/stt.py b/ai-hub/app/core/services/stt.py new file mode 100644 index 0000000..e1d558b --- /dev/null +++ b/ai-hub/app/core/services/stt.py @@ -0,0 +1,45 @@ +import logging +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class STTService: + """ + Service class for transcribing audio into text using an STT provider. + """ + + def __init__(self, stt_provider: STTProvider): + """ + Initializes the STTService with a concrete STT provider. + """ + self.stt_provider = stt_provider + + async def transcribe(self, audio_bytes: bytes) -> str: + """ + Transcribes the provided audio bytes into text using the STT provider. + """ + logger.info(f"Starting transcription for audio data ({len(audio_bytes)} bytes).") + + if not audio_bytes: + logger.warning("No audio data provided for transcription.") + raise HTTPException(status_code=400, detail="No audio data provided.") + + try: + transcript = await self.stt_provider.transcribe_audio(audio_bytes) + if not transcript: + logger.warning("STT provider returned an empty transcript.") + raise HTTPException(status_code=500, detail="Failed to transcribe audio.") + + logger.info(f"Successfully transcribed audio. Transcript length: {len(transcript)} characters.") + return transcript + + except HTTPException: + raise # Pass through existing HTTPException without wrapping + except Exception as e: + logger.error(f"Unexpected error during transcription: {e}") + raise HTTPException( + status_code=500, + detail=f"Error during transcription: {e}" + ) from e diff --git a/ai-hub/app/core/providers/stt/gemini.py b/ai-hub/app/core/providers/stt/gemini.py new file mode 100644 index 0000000..0bc57d6 --- /dev/null +++ b/ai-hub/app/core/providers/stt/gemini.py @@ -0,0 +1,125 @@ +import os +import aiohttp +import asyncio +import logging +import mimetypes +from typing import Optional +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class GoogleSTTProvider(STTProvider): + """Concrete STT provider for Google Gemini API.""" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gemini-2.5-flash" + ): + self.api_key = api_key or os.getenv("GEMINI_API_KEY") + if not self.api_key: + raise ValueError("GEMINI_API_KEY environment variable not set or provided.") + + self.model_name = model_name + self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files" + + logger.debug(f"Initialized GoogleSTTProvider with model: {self.model_name}") + + async def transcribe_audio(self, audio_data: bytes) -> str: + logger.debug("Starting transcription process.") + + mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream" + num_bytes = len(audio_data) + logger.debug(f"Detected MIME type: {mime_type}, size: {num_bytes} bytes.") + + try: + async with aiohttp.ClientSession() as session: + # Step 1: Start resumable upload + logger.debug("Starting resumable upload...") + start_headers = { + "x-goog-api-key": self.api_key, + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Header-Content-Length": str(num_bytes), + "X-Goog-Upload-Header-Content-Type": mime_type, + "Content-Type": "application/json", + } + start_payload = {"file": {"display_name": "AUDIO"}} + + async with session.post( + self.upload_url_base, + headers=start_headers, + json=start_payload + ) as resp: + logger.debug(f"Upload start response status: {resp.status}") + resp.raise_for_status() + upload_url = resp.headers.get("X-Goog-Upload-URL") + if not upload_url: + raise HTTPException(status_code=500, detail="No upload URL returned from Google API.") + logger.debug(f"Received upload URL: {upload_url}") + + # Step 2: Upload the file + logger.debug("Uploading audio file...") + upload_headers = { + "Content-Length": str(num_bytes), + "X-Goog-Upload-Offset": "0", + "X-Goog-Upload-Command": "upload, finalize", + } + async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp: + logger.debug(f"File upload response status: {resp.status}") + resp.raise_for_status() + file_info = await resp.json() + + file_name = file_info["file"]["name"].split("/")[-1] + file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}" + logger.debug(f"Uploaded file URI: {file_uri}") + + # Step 3: Request transcription + logger.debug("Requesting transcription from Gemini API...") + transcription_headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + transcription_payload = { + "contents": [ + { + "parts": [ + { + "fileData": { + "mimeType": mime_type, + "fileUri": file_uri + } + }, + {"text": "Transcribe this audio file."} + ] + } + ] + } + + async with session.post( + self.api_url, + headers=transcription_headers, + json=transcription_payload + ) as resp: + logger.debug(f"Transcription request status: {resp.status}") + resp.raise_for_status() + data = await resp.json() + + # Step 4: Extract text + try: + transcript = data["candidates"][0]["content"]["parts"][0]["text"] + logger.debug(f"Successfully extracted transcript: '{transcript[:50]}...'") + return transcript + except (KeyError, IndexError) as e: + logger.error(f"Malformed API response: {e}. Full response: {data}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + + except aiohttp.ClientError as e: + logger.error(f"Aiohttp client error occurred: {e}") + raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error occurred during transcription: {e}") + raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") diff --git a/ai-hub/app/core/services/stt.py b/ai-hub/app/core/services/stt.py new file mode 100644 index 0000000..e1d558b --- /dev/null +++ b/ai-hub/app/core/services/stt.py @@ -0,0 +1,45 @@ +import logging +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class STTService: + """ + Service class for transcribing audio into text using an STT provider. + """ + + def __init__(self, stt_provider: STTProvider): + """ + Initializes the STTService with a concrete STT provider. + """ + self.stt_provider = stt_provider + + async def transcribe(self, audio_bytes: bytes) -> str: + """ + Transcribes the provided audio bytes into text using the STT provider. + """ + logger.info(f"Starting transcription for audio data ({len(audio_bytes)} bytes).") + + if not audio_bytes: + logger.warning("No audio data provided for transcription.") + raise HTTPException(status_code=400, detail="No audio data provided.") + + try: + transcript = await self.stt_provider.transcribe_audio(audio_bytes) + if not transcript: + logger.warning("STT provider returned an empty transcript.") + raise HTTPException(status_code=500, detail="Failed to transcribe audio.") + + logger.info(f"Successfully transcribed audio. Transcript length: {len(transcript)} characters.") + return transcript + + except HTTPException: + raise # Pass through existing HTTPException without wrapping + except Exception as e: + logger.error(f"Unexpected error during transcription: {e}") + raise HTTPException( + status_code=500, + detail=f"Error during transcription: {e}" + ) from e diff --git a/ai-hub/tests/core/providers/stt/test_stt_gemini.py b/ai-hub/tests/core/providers/stt/test_stt_gemini.py new file mode 100644 index 0000000..6384c66 --- /dev/null +++ b/ai-hub/tests/core/providers/stt/test_stt_gemini.py @@ -0,0 +1,93 @@ +import pytest +import aiohttp +from unittest.mock import patch, AsyncMock +from app.core.providers.stt.gemini import GoogleSTTProvider +from fastapi import HTTPException +from unittest.mock import AsyncMock, MagicMock + +# Helper to create an async context manager mock with expected behavior +def create_async_context_manager_mock(status_code=200, headers=None, json_data=None): + mock = AsyncMock() + mock.status = status_code + mock.headers = headers or {} + mock.json = AsyncMock(return_value=json_data) + # raise_for_status is synchronous in aiohttp, so use MagicMock, not AsyncMock + mock.raise_for_status = MagicMock(return_value=None) + mock.__aenter__.return_value = mock + mock.__aexit__.return_value = None + return mock + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_success(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock( + headers={"X-Goog-Upload-URL": "https://fake-upload-url"} + ) + mock_upload_resp = create_async_context_manager_mock( + json_data={"file": {"name": "files/123"}} + ) + mock_transcribe_resp = create_async_context_manager_mock( + json_data={ + "candidates": [ + { + "content": { + "parts": [{"text": "Hello world"}] + } + } + ] + } + ) + + # Make the side_effect return the mocks directly (no coroutine) + mock_post.side_effect = [ + mock_start_resp, + mock_upload_resp, + mock_transcribe_resp, + ] + + result = await provider.transcribe_audio(b"fake-bytes") + assert result == "Hello world" + assert mock_post.call_count == 3 + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_no_upload_url(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock(headers={}) + mock_post.side_effect = [mock_start_resp] + + with pytest.raises(HTTPException) as exc: + await provider.transcribe_audio(b"fake-bytes") + + assert exc.value.status_code == 500 + assert "No upload URL" in exc.value.detail + assert mock_post.call_count == 1 + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_malformed_response(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock( + headers={"X-Goog-Upload-URL": "https://fake-upload-url"} + ) + mock_upload_resp = create_async_context_manager_mock( + json_data={"file": {"name": "files/123"}} + ) + mock_transcribe_resp = create_async_context_manager_mock(json_data={}) + + mock_post.side_effect = [ + mock_start_resp, + mock_upload_resp, + mock_transcribe_resp, + ] + + with pytest.raises(HTTPException) as exc: + await provider.transcribe_audio(b"fake-bytes") + + assert exc.value.status_code == 500 + assert "Malformed API response" in exc.value.detail + assert mock_post.call_count == 3 diff --git a/ai-hub/app/core/providers/stt/gemini.py b/ai-hub/app/core/providers/stt/gemini.py new file mode 100644 index 0000000..0bc57d6 --- /dev/null +++ b/ai-hub/app/core/providers/stt/gemini.py @@ -0,0 +1,125 @@ +import os +import aiohttp +import asyncio +import logging +import mimetypes +from typing import Optional +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class GoogleSTTProvider(STTProvider): + """Concrete STT provider for Google Gemini API.""" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gemini-2.5-flash" + ): + self.api_key = api_key or os.getenv("GEMINI_API_KEY") + if not self.api_key: + raise ValueError("GEMINI_API_KEY environment variable not set or provided.") + + self.model_name = model_name + self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files" + + logger.debug(f"Initialized GoogleSTTProvider with model: {self.model_name}") + + async def transcribe_audio(self, audio_data: bytes) -> str: + logger.debug("Starting transcription process.") + + mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream" + num_bytes = len(audio_data) + logger.debug(f"Detected MIME type: {mime_type}, size: {num_bytes} bytes.") + + try: + async with aiohttp.ClientSession() as session: + # Step 1: Start resumable upload + logger.debug("Starting resumable upload...") + start_headers = { + "x-goog-api-key": self.api_key, + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Header-Content-Length": str(num_bytes), + "X-Goog-Upload-Header-Content-Type": mime_type, + "Content-Type": "application/json", + } + start_payload = {"file": {"display_name": "AUDIO"}} + + async with session.post( + self.upload_url_base, + headers=start_headers, + json=start_payload + ) as resp: + logger.debug(f"Upload start response status: {resp.status}") + resp.raise_for_status() + upload_url = resp.headers.get("X-Goog-Upload-URL") + if not upload_url: + raise HTTPException(status_code=500, detail="No upload URL returned from Google API.") + logger.debug(f"Received upload URL: {upload_url}") + + # Step 2: Upload the file + logger.debug("Uploading audio file...") + upload_headers = { + "Content-Length": str(num_bytes), + "X-Goog-Upload-Offset": "0", + "X-Goog-Upload-Command": "upload, finalize", + } + async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp: + logger.debug(f"File upload response status: {resp.status}") + resp.raise_for_status() + file_info = await resp.json() + + file_name = file_info["file"]["name"].split("/")[-1] + file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}" + logger.debug(f"Uploaded file URI: {file_uri}") + + # Step 3: Request transcription + logger.debug("Requesting transcription from Gemini API...") + transcription_headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + transcription_payload = { + "contents": [ + { + "parts": [ + { + "fileData": { + "mimeType": mime_type, + "fileUri": file_uri + } + }, + {"text": "Transcribe this audio file."} + ] + } + ] + } + + async with session.post( + self.api_url, + headers=transcription_headers, + json=transcription_payload + ) as resp: + logger.debug(f"Transcription request status: {resp.status}") + resp.raise_for_status() + data = await resp.json() + + # Step 4: Extract text + try: + transcript = data["candidates"][0]["content"]["parts"][0]["text"] + logger.debug(f"Successfully extracted transcript: '{transcript[:50]}...'") + return transcript + except (KeyError, IndexError) as e: + logger.error(f"Malformed API response: {e}. Full response: {data}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + + except aiohttp.ClientError as e: + logger.error(f"Aiohttp client error occurred: {e}") + raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error occurred during transcription: {e}") + raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") diff --git a/ai-hub/app/core/services/stt.py b/ai-hub/app/core/services/stt.py new file mode 100644 index 0000000..e1d558b --- /dev/null +++ b/ai-hub/app/core/services/stt.py @@ -0,0 +1,45 @@ +import logging +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class STTService: + """ + Service class for transcribing audio into text using an STT provider. + """ + + def __init__(self, stt_provider: STTProvider): + """ + Initializes the STTService with a concrete STT provider. + """ + self.stt_provider = stt_provider + + async def transcribe(self, audio_bytes: bytes) -> str: + """ + Transcribes the provided audio bytes into text using the STT provider. + """ + logger.info(f"Starting transcription for audio data ({len(audio_bytes)} bytes).") + + if not audio_bytes: + logger.warning("No audio data provided for transcription.") + raise HTTPException(status_code=400, detail="No audio data provided.") + + try: + transcript = await self.stt_provider.transcribe_audio(audio_bytes) + if not transcript: + logger.warning("STT provider returned an empty transcript.") + raise HTTPException(status_code=500, detail="Failed to transcribe audio.") + + logger.info(f"Successfully transcribed audio. Transcript length: {len(transcript)} characters.") + return transcript + + except HTTPException: + raise # Pass through existing HTTPException without wrapping + except Exception as e: + logger.error(f"Unexpected error during transcription: {e}") + raise HTTPException( + status_code=500, + detail=f"Error during transcription: {e}" + ) from e diff --git a/ai-hub/tests/core/providers/stt/test_stt_gemini.py b/ai-hub/tests/core/providers/stt/test_stt_gemini.py new file mode 100644 index 0000000..6384c66 --- /dev/null +++ b/ai-hub/tests/core/providers/stt/test_stt_gemini.py @@ -0,0 +1,93 @@ +import pytest +import aiohttp +from unittest.mock import patch, AsyncMock +from app.core.providers.stt.gemini import GoogleSTTProvider +from fastapi import HTTPException +from unittest.mock import AsyncMock, MagicMock + +# Helper to create an async context manager mock with expected behavior +def create_async_context_manager_mock(status_code=200, headers=None, json_data=None): + mock = AsyncMock() + mock.status = status_code + mock.headers = headers or {} + mock.json = AsyncMock(return_value=json_data) + # raise_for_status is synchronous in aiohttp, so use MagicMock, not AsyncMock + mock.raise_for_status = MagicMock(return_value=None) + mock.__aenter__.return_value = mock + mock.__aexit__.return_value = None + return mock + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_success(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock( + headers={"X-Goog-Upload-URL": "https://fake-upload-url"} + ) + mock_upload_resp = create_async_context_manager_mock( + json_data={"file": {"name": "files/123"}} + ) + mock_transcribe_resp = create_async_context_manager_mock( + json_data={ + "candidates": [ + { + "content": { + "parts": [{"text": "Hello world"}] + } + } + ] + } + ) + + # Make the side_effect return the mocks directly (no coroutine) + mock_post.side_effect = [ + mock_start_resp, + mock_upload_resp, + mock_transcribe_resp, + ] + + result = await provider.transcribe_audio(b"fake-bytes") + assert result == "Hello world" + assert mock_post.call_count == 3 + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_no_upload_url(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock(headers={}) + mock_post.side_effect = [mock_start_resp] + + with pytest.raises(HTTPException) as exc: + await provider.transcribe_audio(b"fake-bytes") + + assert exc.value.status_code == 500 + assert "No upload URL" in exc.value.detail + assert mock_post.call_count == 1 + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_malformed_response(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock( + headers={"X-Goog-Upload-URL": "https://fake-upload-url"} + ) + mock_upload_resp = create_async_context_manager_mock( + json_data={"file": {"name": "files/123"}} + ) + mock_transcribe_resp = create_async_context_manager_mock(json_data={}) + + mock_post.side_effect = [ + mock_start_resp, + mock_upload_resp, + mock_transcribe_resp, + ] + + with pytest.raises(HTTPException) as exc: + await provider.transcribe_audio(b"fake-bytes") + + assert exc.value.status_code == 500 + assert "Malformed API response" in exc.value.detail + assert mock_post.call_count == 3 diff --git a/ai-hub/tests/core/providers/tts/test_gemini.py b/ai-hub/tests/core/providers/tts/test_gemini.py deleted file mode 100644 index 5a5e1b2..0000000 --- a/ai-hub/tests/core/providers/tts/test_gemini.py +++ /dev/null @@ -1,81 +0,0 @@ -# Fixed test file -import pytest -import aiohttp -import asyncio -import base64 -from aioresponses import aioresponses -from app.core.providers.tts.gemini import GeminiTTSProvider -from app.core.providers.base import TTSProvider - -@pytest.mark.asyncio -async def test_generate_speech_success(): - """ - Tests that generate_speech correctly makes an API call and processes the response. - """ - api_key = "test_api_key" - text_to_speak = "Hello, world!" - model_name = "gemini-2.5-flash-preview-tts" - - # Create a dummy base64 encoded audio response - dummy_audio_bytes = b"This is a test audio stream." - dummy_base64_data = base64.b64encode(dummy_audio_bytes).decode('utf-8') - - # The mocked JSON response from the API - mock_response_json = { - "candidates": [{ - "content": { - "parts": [{ - "inlineData": { - "data": dummy_base64_data - } - }] - } - }] - } - - # Configure aioresponses to intercept the API call and return our mock data - tts_provider = GeminiTTSProvider(api_key=api_key, model_name=model_name) - with aioresponses() as m: - m.post( - tts_provider.api_url, - status=200, - payload=mock_response_json, - repeat=True - ) - - # Call the method under test, now awaiting the coroutine - audio_data = await tts_provider.generate_speech(text_to_speak) - - # Assert that the returned data is correct - assert audio_data == dummy_audio_bytes - -def test_init_with_valid_voice_name(): - """ - Tests that initialization succeeds with a valid voice name. - """ - api_key = "test_api_key" - voice_name = "Zephyr" - tts_provider = GeminiTTSProvider(api_key=api_key, voice_name=voice_name) - assert tts_provider.api_key == api_key - assert tts_provider.voice_name == voice_name - assert tts_provider.model_name == "gemini-2.5-flash-preview-tts" - assert "gemini-2.5-flash-preview-tts" in tts_provider.api_url - -def test_init_with_invalid_voice_name(): - """ - Tests that initialization fails with an invalid voice name. - """ - api_key = "test_api_key" - invalid_voice_name = "InvalidVoice" - with pytest.raises(ValueError, match="Invalid voice name"): - GeminiTTSProvider(api_key=api_key, voice_name=invalid_voice_name) - -def test_init_with_custom_model_name(): - """ - Tests that the provider can be initialized with a custom model name. - """ - api_key = "test_api_key" - custom_model_name = "gemini-tts-beta" - tts_provider = GeminiTTSProvider(api_key=api_key, model_name=custom_model_name) - assert tts_provider.model_name == custom_model_name - assert custom_model_name in tts_provider.api_url \ No newline at end of file diff --git a/ai-hub/app/core/providers/stt/gemini.py b/ai-hub/app/core/providers/stt/gemini.py new file mode 100644 index 0000000..0bc57d6 --- /dev/null +++ b/ai-hub/app/core/providers/stt/gemini.py @@ -0,0 +1,125 @@ +import os +import aiohttp +import asyncio +import logging +import mimetypes +from typing import Optional +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class GoogleSTTProvider(STTProvider): + """Concrete STT provider for Google Gemini API.""" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gemini-2.5-flash" + ): + self.api_key = api_key or os.getenv("GEMINI_API_KEY") + if not self.api_key: + raise ValueError("GEMINI_API_KEY environment variable not set or provided.") + + self.model_name = model_name + self.api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent" + self.upload_url_base = "https://generativelanguage.googleapis.com/upload/v1beta/files" + + logger.debug(f"Initialized GoogleSTTProvider with model: {self.model_name}") + + async def transcribe_audio(self, audio_data: bytes) -> str: + logger.debug("Starting transcription process.") + + mime_type = mimetypes.guess_type("audio.wav")[0] or "application/octet-stream" + num_bytes = len(audio_data) + logger.debug(f"Detected MIME type: {mime_type}, size: {num_bytes} bytes.") + + try: + async with aiohttp.ClientSession() as session: + # Step 1: Start resumable upload + logger.debug("Starting resumable upload...") + start_headers = { + "x-goog-api-key": self.api_key, + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Header-Content-Length": str(num_bytes), + "X-Goog-Upload-Header-Content-Type": mime_type, + "Content-Type": "application/json", + } + start_payload = {"file": {"display_name": "AUDIO"}} + + async with session.post( + self.upload_url_base, + headers=start_headers, + json=start_payload + ) as resp: + logger.debug(f"Upload start response status: {resp.status}") + resp.raise_for_status() + upload_url = resp.headers.get("X-Goog-Upload-URL") + if not upload_url: + raise HTTPException(status_code=500, detail="No upload URL returned from Google API.") + logger.debug(f"Received upload URL: {upload_url}") + + # Step 2: Upload the file + logger.debug("Uploading audio file...") + upload_headers = { + "Content-Length": str(num_bytes), + "X-Goog-Upload-Offset": "0", + "X-Goog-Upload-Command": "upload, finalize", + } + async with session.post(upload_url, headers=upload_headers, data=audio_data) as resp: + logger.debug(f"File upload response status: {resp.status}") + resp.raise_for_status() + file_info = await resp.json() + + file_name = file_info["file"]["name"].split("/")[-1] + file_uri = f"https://generativelanguage.googleapis.com/v1beta/files/{file_name}" + logger.debug(f"Uploaded file URI: {file_uri}") + + # Step 3: Request transcription + logger.debug("Requesting transcription from Gemini API...") + transcription_headers = { + "x-goog-api-key": self.api_key, + "Content-Type": "application/json", + } + transcription_payload = { + "contents": [ + { + "parts": [ + { + "fileData": { + "mimeType": mime_type, + "fileUri": file_uri + } + }, + {"text": "Transcribe this audio file."} + ] + } + ] + } + + async with session.post( + self.api_url, + headers=transcription_headers, + json=transcription_payload + ) as resp: + logger.debug(f"Transcription request status: {resp.status}") + resp.raise_for_status() + data = await resp.json() + + # Step 4: Extract text + try: + transcript = data["candidates"][0]["content"]["parts"][0]["text"] + logger.debug(f"Successfully extracted transcript: '{transcript[:50]}...'") + return transcript + except (KeyError, IndexError) as e: + logger.error(f"Malformed API response: {e}. Full response: {data}") + raise HTTPException(status_code=500, detail="Malformed API response from Gemini.") + + except aiohttp.ClientError as e: + logger.error(f"Aiohttp client error occurred: {e}") + raise HTTPException(status_code=500, detail=f"API request failed: {e}") + except Exception as e: + logger.error(f"Unexpected error occurred during transcription: {e}") + raise HTTPException(status_code=500, detail=f"Failed to transcribe audio: {e}") diff --git a/ai-hub/app/core/services/stt.py b/ai-hub/app/core/services/stt.py new file mode 100644 index 0000000..e1d558b --- /dev/null +++ b/ai-hub/app/core/services/stt.py @@ -0,0 +1,45 @@ +import logging +from fastapi import HTTPException +from app.core.providers.base import STTProvider + +# Configure logging +logger = logging.getLogger(__name__) + +class STTService: + """ + Service class for transcribing audio into text using an STT provider. + """ + + def __init__(self, stt_provider: STTProvider): + """ + Initializes the STTService with a concrete STT provider. + """ + self.stt_provider = stt_provider + + async def transcribe(self, audio_bytes: bytes) -> str: + """ + Transcribes the provided audio bytes into text using the STT provider. + """ + logger.info(f"Starting transcription for audio data ({len(audio_bytes)} bytes).") + + if not audio_bytes: + logger.warning("No audio data provided for transcription.") + raise HTTPException(status_code=400, detail="No audio data provided.") + + try: + transcript = await self.stt_provider.transcribe_audio(audio_bytes) + if not transcript: + logger.warning("STT provider returned an empty transcript.") + raise HTTPException(status_code=500, detail="Failed to transcribe audio.") + + logger.info(f"Successfully transcribed audio. Transcript length: {len(transcript)} characters.") + return transcript + + except HTTPException: + raise # Pass through existing HTTPException without wrapping + except Exception as e: + logger.error(f"Unexpected error during transcription: {e}") + raise HTTPException( + status_code=500, + detail=f"Error during transcription: {e}" + ) from e diff --git a/ai-hub/tests/core/providers/stt/test_stt_gemini.py b/ai-hub/tests/core/providers/stt/test_stt_gemini.py new file mode 100644 index 0000000..6384c66 --- /dev/null +++ b/ai-hub/tests/core/providers/stt/test_stt_gemini.py @@ -0,0 +1,93 @@ +import pytest +import aiohttp +from unittest.mock import patch, AsyncMock +from app.core.providers.stt.gemini import GoogleSTTProvider +from fastapi import HTTPException +from unittest.mock import AsyncMock, MagicMock + +# Helper to create an async context manager mock with expected behavior +def create_async_context_manager_mock(status_code=200, headers=None, json_data=None): + mock = AsyncMock() + mock.status = status_code + mock.headers = headers or {} + mock.json = AsyncMock(return_value=json_data) + # raise_for_status is synchronous in aiohttp, so use MagicMock, not AsyncMock + mock.raise_for_status = MagicMock(return_value=None) + mock.__aenter__.return_value = mock + mock.__aexit__.return_value = None + return mock + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_success(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock( + headers={"X-Goog-Upload-URL": "https://fake-upload-url"} + ) + mock_upload_resp = create_async_context_manager_mock( + json_data={"file": {"name": "files/123"}} + ) + mock_transcribe_resp = create_async_context_manager_mock( + json_data={ + "candidates": [ + { + "content": { + "parts": [{"text": "Hello world"}] + } + } + ] + } + ) + + # Make the side_effect return the mocks directly (no coroutine) + mock_post.side_effect = [ + mock_start_resp, + mock_upload_resp, + mock_transcribe_resp, + ] + + result = await provider.transcribe_audio(b"fake-bytes") + assert result == "Hello world" + assert mock_post.call_count == 3 + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_no_upload_url(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock(headers={}) + mock_post.side_effect = [mock_start_resp] + + with pytest.raises(HTTPException) as exc: + await provider.transcribe_audio(b"fake-bytes") + + assert exc.value.status_code == 500 + assert "No upload URL" in exc.value.detail + assert mock_post.call_count == 1 + +@pytest.mark.asyncio +@patch('aiohttp.ClientSession.post') +async def test_transcribe_audio_malformed_response(mock_post): + provider = GoogleSTTProvider(api_key="fake-key") + + mock_start_resp = create_async_context_manager_mock( + headers={"X-Goog-Upload-URL": "https://fake-upload-url"} + ) + mock_upload_resp = create_async_context_manager_mock( + json_data={"file": {"name": "files/123"}} + ) + mock_transcribe_resp = create_async_context_manager_mock(json_data={}) + + mock_post.side_effect = [ + mock_start_resp, + mock_upload_resp, + mock_transcribe_resp, + ] + + with pytest.raises(HTTPException) as exc: + await provider.transcribe_audio(b"fake-bytes") + + assert exc.value.status_code == 500 + assert "Malformed API response" in exc.value.detail + assert mock_post.call_count == 3 diff --git a/ai-hub/tests/core/providers/tts/test_gemini.py b/ai-hub/tests/core/providers/tts/test_gemini.py deleted file mode 100644 index 5a5e1b2..0000000 --- a/ai-hub/tests/core/providers/tts/test_gemini.py +++ /dev/null @@ -1,81 +0,0 @@ -# Fixed test file -import pytest -import aiohttp -import asyncio -import base64 -from aioresponses import aioresponses -from app.core.providers.tts.gemini import GeminiTTSProvider -from app.core.providers.base import TTSProvider - -@pytest.mark.asyncio -async def test_generate_speech_success(): - """ - Tests that generate_speech correctly makes an API call and processes the response. - """ - api_key = "test_api_key" - text_to_speak = "Hello, world!" - model_name = "gemini-2.5-flash-preview-tts" - - # Create a dummy base64 encoded audio response - dummy_audio_bytes = b"This is a test audio stream." - dummy_base64_data = base64.b64encode(dummy_audio_bytes).decode('utf-8') - - # The mocked JSON response from the API - mock_response_json = { - "candidates": [{ - "content": { - "parts": [{ - "inlineData": { - "data": dummy_base64_data - } - }] - } - }] - } - - # Configure aioresponses to intercept the API call and return our mock data - tts_provider = GeminiTTSProvider(api_key=api_key, model_name=model_name) - with aioresponses() as m: - m.post( - tts_provider.api_url, - status=200, - payload=mock_response_json, - repeat=True - ) - - # Call the method under test, now awaiting the coroutine - audio_data = await tts_provider.generate_speech(text_to_speak) - - # Assert that the returned data is correct - assert audio_data == dummy_audio_bytes - -def test_init_with_valid_voice_name(): - """ - Tests that initialization succeeds with a valid voice name. - """ - api_key = "test_api_key" - voice_name = "Zephyr" - tts_provider = GeminiTTSProvider(api_key=api_key, voice_name=voice_name) - assert tts_provider.api_key == api_key - assert tts_provider.voice_name == voice_name - assert tts_provider.model_name == "gemini-2.5-flash-preview-tts" - assert "gemini-2.5-flash-preview-tts" in tts_provider.api_url - -def test_init_with_invalid_voice_name(): - """ - Tests that initialization fails with an invalid voice name. - """ - api_key = "test_api_key" - invalid_voice_name = "InvalidVoice" - with pytest.raises(ValueError, match="Invalid voice name"): - GeminiTTSProvider(api_key=api_key, voice_name=invalid_voice_name) - -def test_init_with_custom_model_name(): - """ - Tests that the provider can be initialized with a custom model name. - """ - api_key = "test_api_key" - custom_model_name = "gemini-tts-beta" - tts_provider = GeminiTTSProvider(api_key=api_key, model_name=custom_model_name) - assert tts_provider.model_name == custom_model_name - assert custom_model_name in tts_provider.api_url \ No newline at end of file diff --git a/ai-hub/tests/core/providers/tts/test_tts_gemini.py b/ai-hub/tests/core/providers/tts/test_tts_gemini.py new file mode 100644 index 0000000..5a5e1b2 --- /dev/null +++ b/ai-hub/tests/core/providers/tts/test_tts_gemini.py @@ -0,0 +1,81 @@ +# Fixed test file +import pytest +import aiohttp +import asyncio +import base64 +from aioresponses import aioresponses +from app.core.providers.tts.gemini import GeminiTTSProvider +from app.core.providers.base import TTSProvider + +@pytest.mark.asyncio +async def test_generate_speech_success(): + """ + Tests that generate_speech correctly makes an API call and processes the response. + """ + api_key = "test_api_key" + text_to_speak = "Hello, world!" + model_name = "gemini-2.5-flash-preview-tts" + + # Create a dummy base64 encoded audio response + dummy_audio_bytes = b"This is a test audio stream." + dummy_base64_data = base64.b64encode(dummy_audio_bytes).decode('utf-8') + + # The mocked JSON response from the API + mock_response_json = { + "candidates": [{ + "content": { + "parts": [{ + "inlineData": { + "data": dummy_base64_data + } + }] + } + }] + } + + # Configure aioresponses to intercept the API call and return our mock data + tts_provider = GeminiTTSProvider(api_key=api_key, model_name=model_name) + with aioresponses() as m: + m.post( + tts_provider.api_url, + status=200, + payload=mock_response_json, + repeat=True + ) + + # Call the method under test, now awaiting the coroutine + audio_data = await tts_provider.generate_speech(text_to_speak) + + # Assert that the returned data is correct + assert audio_data == dummy_audio_bytes + +def test_init_with_valid_voice_name(): + """ + Tests that initialization succeeds with a valid voice name. + """ + api_key = "test_api_key" + voice_name = "Zephyr" + tts_provider = GeminiTTSProvider(api_key=api_key, voice_name=voice_name) + assert tts_provider.api_key == api_key + assert tts_provider.voice_name == voice_name + assert tts_provider.model_name == "gemini-2.5-flash-preview-tts" + assert "gemini-2.5-flash-preview-tts" in tts_provider.api_url + +def test_init_with_invalid_voice_name(): + """ + Tests that initialization fails with an invalid voice name. + """ + api_key = "test_api_key" + invalid_voice_name = "InvalidVoice" + with pytest.raises(ValueError, match="Invalid voice name"): + GeminiTTSProvider(api_key=api_key, voice_name=invalid_voice_name) + +def test_init_with_custom_model_name(): + """ + Tests that the provider can be initialized with a custom model name. + """ + api_key = "test_api_key" + custom_model_name = "gemini-tts-beta" + tts_provider = GeminiTTSProvider(api_key=api_key, model_name=custom_model_name) + assert tts_provider.model_name == custom_model_name + assert custom_model_name in tts_provider.api_url \ No newline at end of file