Newer
Older
cortex-hub / ai-hub / app / api / routes / tts.py
from fastapi import APIRouter, HTTPException, Query, Response, Depends
from fastapi.responses import StreamingResponse
from typing import AsyncGenerator
from sqlalchemy.orm import Session
from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas

def create_tts_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/speech`` router: speech synthesis plus voice listing.

    Args:
        services: Container whose ``user_service`` and ``tts_service`` are
            called by the route handlers defined here.

    Returns:
        An ``APIRouter`` exposing ``POST /speech`` and ``GET /speech/voices``.
    """
    import logging

    logger = logging.getLogger(__name__)
    router = APIRouter(prefix="/speech", tags=["TTS"])

    # Preference keys that are handed to the provider factory under dedicated
    # argument names; every other stored key is forwarded as an extra kwarg.
    reserved_pref_keys = ("api_key", "model", "voice")
    gcloud_voices_url = "https://texttospeech.googleapis.com/v1/voices"

    def _resolve_provider_override(db: Session, user_id, provider_name):
        """Return a provider built from the user's stored TTS prefs, or None.

        None is returned when there is no user id or when the user has no
        stored settings for the selected provider; the caller then falls back
        to the service's default provider.
        """
        if not user_id:
            return None

        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        # ``or {}`` guards against keys that are present but explicitly null.
        prefs = (user.preferences.get("tts") or {}) if user and user.preferences else {}

        from app.config import settings
        active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER
        active_prefs = (prefs.get("providers") or {}).get(active_provider) or {}
        if not active_prefs:
            return None

        from app.core.providers.factory import get_tts_provider
        extra_kwargs = {k: v for k, v in active_prefs.items() if k not in reserved_pref_keys}
        return get_tts_provider(
            provider_name=active_provider,
            api_key=active_prefs.get("api_key"),
            model_name=active_prefs.get("model", ""),
            voice_name=active_prefs.get("voice", ""),
            **extra_kwargs
        )

    def _lookup_stored_api_key(db: Session, user_id, provider):
        """Return the first unmasked API key in the user's TTS prefs, or None.

        When ``provider`` is given, only that provider's entry is considered.
        """
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (user and user.preferences):
            return None

        providers = (user.preferences.get("tts") or {}).get("providers") or {}
        for p_name, p_data in providers.items():
            if not isinstance(p_data, dict):
                continue  # tolerate null / malformed provider entries
            stored_key = p_data.get("api_key")
            if stored_key and "***" not in stored_key and (not provider or provider == p_name):
                return stored_key
        return None

    def _passthrough(pcm):
        """Return raw PCM unchanged (used when WAV wrapping is disabled)."""
        return pcm

    @router.post(
        "",
        summary="Generate speech from text",
        response_description="Audio bytes in WAV or PCM format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        ),
        as_wav: bool = Query(
            True,
            description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true."
        ),
        provider_name: str = Query(
            None,
            description="Optional session-level override for the TTS provider"
        ),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        # Synthesizes ``request.text`` either as one WAV response or as a
        # chunk-by-chunk stream. Any non-HTTP failure before the response
        # starts is reported as HTTP 500; an empty chunk list is HTTP 400.
        # NOTE: documented with comments rather than a docstring because
        # FastAPI publishes handler docstrings as the route description.
        try:
            provider_override = _resolve_provider_override(db, user_id, provider_name)

            if not stream:
                # The non-streaming service call only returns WAV.
                audio_bytes = await services.tts_service.create_speech_non_stream(
                    text=request.text,
                    provider_override=provider_override
                )
                return Response(content=audio_bytes, media_type="audio/wav")

            # Pre-flight: do everything that can fail for the first chunk
            # before streaming starts. If we send StreamingResponse and then
            # fail, the browser sees a network error instead of a meaningful
            # error message.
            chunks = await services.tts_service._split_text_into_chunks(request.text)
            if not chunks:
                raise HTTPException(status_code=400, detail="No text to synthesize.")
            provider = provider_override or services.tts_service.default_tts_provider

            if as_wav:
                from app.core.services.tts import _create_wav_file
                encode = _create_wav_file
            else:
                encode = _passthrough

            # Generate AND encode the first chunk up front so that provider or
            # encoding errors surface as an HTTP error, not a broken stream.
            first_pcm = await provider.generate_speech(chunks[0])
            first_audio = encode(first_pcm) if first_pcm else None

            async def full_stream():
                # Empty results are skipped for every chunk, first included.
                if first_audio:
                    yield first_audio
                for chunk in chunks[1:]:
                    try:
                        pcm = await provider.generate_speech(chunk)
                        audio = encode(pcm) if pcm else None
                    except Exception as e:
                        logger.error("TTS chunk error: %s", e)
                        break  # Stop cleanly rather than crashing the stream
                    if audio:
                        yield audio

            media_type = "audio/wav" if as_wav else "audio/pcm"
            return StreamingResponse(full_stream(), media_type=media_type)

        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    @router.get(
        "/voices",
        summary="List available TTS voices",
        response_description="A list of voice names"
    )
    async def list_voices(
        provider: str = Query(None, description="Optional provider name"),
        api_key: str = Query(None, description="Optional API key override"),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        # Returns a sorted list of voice names. Gemini voices come from the
        # provider class's static list; otherwise the Google Cloud TTS voices
        # endpoint is queried, with static lists as the fallback on failure.
        from app.config import settings
        import httpx
        from app.core.providers.tts.gemini import GeminiTTSProvider
        from app.core.providers.tts.gcloud_tts import GCloudTTSProvider

        # Resolve masked key if needed: a key containing "***" is a redacted
        # placeholder, so try to swap it for the real key stored for the user.
        key_to_use = api_key
        if key_to_use and "***" in key_to_use and user_id:
            key_to_use = _lookup_stored_api_key(db, user_id, provider) or key_to_use

        # Fallback to defaults
        if not key_to_use or "***" in key_to_use:
            key_to_use = settings.TTS_API_KEY or settings.GEMINI_API_KEY

        # Keys starting with AIza are treated as (likely) AI Studio keys.
        looks_like_ai_studio_key = bool(key_to_use) and key_to_use.startswith("AIza")

        # If it's Gemini, or no provider was given and the key looks like AI Studio
        if provider == "google_gemini" or (not provider and looks_like_ai_studio_key):
            return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

        # Default or explicit GCloud
        if not key_to_use:
            return []

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                # Pass the key via ``params`` so it is URL-encoded; it may be
                # caller-supplied and must not be able to alter the query string.
                res = await client.get(gcloud_voices_url, params={"key": key_to_use})
                if res.status_code == 200:
                    voices = res.json().get('voices', [])
                    return sorted(v['name'] for v in voices)

                # If Google Cloud TTS fails, maybe it's actually an AI Studio key being used for Gemini?
                # Fallback to Gemini voices if it seems likely
                if looks_like_ai_studio_key:
                    return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

                return []
        except Exception as e:
            logger.error("Failed to fetch voices: %s", e)
            # Final fallback to a static list if the request or parsing failed
            if looks_like_ai_studio_key:
                return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
            return sorted(GCloudTTSProvider.AVAILABLE_VOICES_EN)

    return router