Newer
Older
cortex-hub / ai-hub / app / api / routes / stt.py
import logging
from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas
from app.core.services.stt import STTService

# Configure logging
logger = logging.getLogger(__name__)

def create_stt_router(services: ServiceContainer) -> APIRouter:
    """
    Creates and configures the API router for Speech-to-Text (STT) functionality.

    Args:
        services: Container whose ``user_service`` is used to look up the
            calling user's preferences and whose ``stt_service`` performs the
            transcription.

    Returns:
        An ``APIRouter`` mounted under ``/stt`` exposing ``POST /transcribe``.
    """
    router = APIRouter(prefix="/stt", tags=["STT"])

    @router.post(
        "/transcribe",
        summary="Transcribe audio to text",
        response_description="The transcribed text from the audio file.",
        response_model=schemas.STTResponse
    )
    async def transcribe_audio_to_text(
        audio_file: UploadFile = File(...),
        provider_name: str | None = None,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Transcribes an uploaded audio file into text using the configured STT service.

        The audio file is expected to be a common audio format like WAV or MP3,
        though the specific provider implementation will determine supported formats.

        When the calling user has stored preferences for the selected provider,
        a provider instance built from those preferences is passed to the STT
        service as an override; otherwise the service is called without one.

        Raises:
            HTTPException: 415 if the upload does not declare an ``audio/*``
                content type (including when no content type is sent at all);
                500 if anything fails while reading or transcribing the audio.
        """
        logger.info(f"Received transcription request for file: {audio_file.filename}")

        # content_type is None when the multipart part carries no Content-Type
        # header; treat that as "not audio" instead of crashing on None.
        # Media types are case-insensitive, so compare in lowercase.
        content_type = (audio_file.content_type or "").lower()
        if not content_type.startswith("audio/"):
            logger.warning(f"Invalid file type uploaded: {audio_file.content_type}")
            raise HTTPException(
                status_code=415,
                detail="Unsupported media type. Please upload an audio file."
            )

        try:
            # Read the audio bytes from the uploaded file
            audio_bytes = await audio_file.read()

            provider_override = None
            if user_id:
                user = services.user_service.get_user_by_id(db=db, user_id=user_id)
                prefs = user.preferences.get("stt", {}) if user and user.preferences else {}
                # NOTE(review): imported locally as in the original, presumably
                # to avoid an import cycle at module load - confirm before hoisting.
                from app.config import settings
                # Precedence: explicit query parameter, then the user's saved
                # choice, then the application-wide default.
                active_provider = provider_name or prefs.get("active_provider") or settings.STT_PROVIDER
                active_prefs = prefs.get("providers", {}).get(active_provider, {})
                if active_prefs:
                    from app.core.providers.factory import get_stt_provider
                    # api_key and model are passed explicitly below; every
                    # other stored preference is forwarded as-is.
                    kwargs = {
                        k: v for k, v in active_prefs.items()
                        if k not in ("api_key", "model")
                    }
                    provider_override = get_stt_provider(
                        provider_name=active_provider,
                        api_key=active_prefs.get("api_key"),
                        model_name=active_prefs.get("model", ""),
                        **kwargs
                    )

            # Use the STT service to get the transcript
            transcript = await services.stt_service.transcribe(
                audio_bytes,
                provider_override=provider_override
            )

            # Return the transcript in a simple JSON response
            return schemas.STTResponse(transcript=transcript)

        except HTTPException:
            # Re-raise FastAPI exceptions so they're handled correctly
            raise
        except Exception as e:
            # logger.exception logs at ERROR level and keeps the traceback.
            logger.exception("Failed to transcribe audio file: %s", e)
            raise HTTPException(
                status_code=500, detail=f"Failed to transcribe audio: {e}"
            ) from e

    return router