import logging
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas
from app.core.services.stt import STTService
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
def _build_provider_override(
    services: ServiceContainer,
    db: Session,
    user_id: str | None,
    provider_name: str | None,
):
    """Build a per-request STT provider from the user's saved STT preferences.

    Args:
        services: Container whose ``user_service`` is used to load the user.
        db: Database session passed through to ``get_user_by_id``.
        user_id: Current user's id; when falsy, no override is built.
        provider_name: Explicit provider requested by the client. It takes
            precedence over the user's ``active_provider`` and
            ``settings.STT_PROVIDER``.

    Returns:
        A provider instance from ``get_stt_provider``, or ``None`` to let
        ``stt_service`` use its default provider. ``None`` is returned when
        there is no user id, or when no preferences are stored for the
        selected provider.
    """
    if not user_id:
        # NOTE(review): provider_name is ignored entirely when there is no
        # user id -- confirm whether anonymous requests should honor it.
        return None

    # NOTE(review): this looks like a synchronous DB call inside an async
    # request handler, which would block the event loop -- consider
    # offloading to a threadpool if get_user_by_id is sync.
    user = services.user_service.get_user_by_id(db=db, user_id=user_id)

    # `or {}` guards against keys that exist but are stored as None.
    stt_prefs = (user.preferences.get("stt") or {}) if user and user.preferences else {}

    # Imported lazily, as in the original handler (possibly to avoid import
    # cycles at module load time -- unverified).
    from app.config import settings

    active_provider = provider_name or stt_prefs.get("active_provider") or settings.STT_PROVIDER
    active_prefs = (stt_prefs.get("providers") or {}).get(active_provider) or {}
    if not active_prefs:
        # NOTE(review): an explicitly requested provider_name with no stored
        # prefs silently falls back to the default provider.
        return None

    from app.core.providers.factory import get_stt_provider

    # "api_key" and "model" are mapped to dedicated parameters; every other
    # stored preference is forwarded as an extra keyword argument.
    extra_kwargs = {k: v for k, v in active_prefs.items() if k not in ("api_key", "model")}
    return get_stt_provider(
        provider_name=active_provider,
        api_key=active_prefs.get("api_key"),
        model_name=active_prefs.get("model", ""),
        **extra_kwargs,
    )


def create_stt_router(services: ServiceContainer) -> APIRouter:
    """
    Create and configure the API router for Speech-to-Text (STT) functionality.

    Args:
        services: Container providing ``user_service`` (used to load per-user
            STT preferences) and ``stt_service`` (used to transcribe audio).

    Returns:
        An ``APIRouter`` with prefix ``/stt`` exposing ``POST /stt/transcribe``.
    """
    router = APIRouter(prefix="/stt", tags=["STT"])

    @router.post(
        "/transcribe",
        summary="Transcribe audio to text",
        response_description="The transcribed text from the audio file.",
        response_model=schemas.STTResponse,
    )
    async def transcribe_audio_to_text(
        audio_file: UploadFile = File(...),
        provider_name: str | None = None,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id),
    ):
        """
        Transcribe an uploaded audio file into text using the configured STT service.

        The audio file is expected to be a common audio format like WAV or MP3,
        though the specific provider implementation determines which formats
        are supported.

        Raises:
            HTTPException(415): the upload's content type is missing or is not ``audio/*``.
            HTTPException(400): the uploaded file is empty.
            HTTPException(500): any other failure while resolving the provider
                or transcribing. Details are logged server-side, not returned.
        """
        logger.info("Received transcription request for file: %s", audio_file.filename)

        # content_type is None when the multipart part carries no Content-Type
        # header; treat that as unsupported rather than crashing on .startswith.
        content_type = audio_file.content_type or ""
        if not content_type.startswith("audio/"):
            logger.warning("Invalid file type uploaded: %s", audio_file.content_type)
            raise HTTPException(
                status_code=415,
                detail="Unsupported media type. Please upload an audio file.",
            )

        try:
            audio_bytes = await audio_file.read()
            if not audio_bytes:
                raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")

            provider_override = _build_provider_override(services, db, user_id, provider_name)

            transcript = await services.stt_service.transcribe(
                audio_bytes,
                provider_override=provider_override,
            )
            return schemas.STTResponse(transcript=transcript)
        except HTTPException:
            # Re-raise FastAPI exceptions so their status codes are preserved.
            raise
        except Exception as e:
            # Log the full traceback server-side. Do not echo the raw exception
            # text to the client: provider errors may embed URLs, keys, or
            # other internals.
            logger.exception("Failed to transcribe audio file: %s", e)
            raise HTTPException(
                status_code=500, detail="Failed to transcribe audio."
            ) from e

    return router