# cortex-hub / ai-hub / app / api / routes / tts.py
import asyncio
import logging
from typing import AsyncGenerator

from fastapi import APIRouter, HTTPException, Query, Response, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas

logger = logging.getLogger(__name__)

# Characters after which the text may be cut to form the first streamed chunk:
# ASCII sentence/clause punctuation, newline, and the CJK / full-width forms
# (ideographic full stop, full-width '?', '!' and ','). The non-ASCII ones are
# written as escapes so they survive re-encoding of this file.
_SPLIT_SEPARATORS = frozenset(".?!\n,\u3002\uff1f\uff01\uff0c")

# The cut position is searched backwards from index 400 down to index 51, so
# the first chunk is never shorter than 52 characters.
_SPLIT_SCAN_START = 400
_SPLIT_SCAN_STOP = 50

# Upper bound on concurrent provider requests while streaming.
_MAX_CONCURRENT_CHUNKS = 3


def _split_first_chunk(text):
    """Split *text* into a leading chunk and the remainder.

    The scan runs backwards from index ``min(len(text) - 1, 400)`` down to
    index 51 and cuts right after the first separator it meets, i.e. the LAST
    separator inside that window. The text is returned unsplit when no
    separator is found or when nothing but whitespace would remain after the
    cut.

    Returns:
        A list with one or two strings whose concatenation equals *text*.
    """
    scan_from = min(len(text) - 1, _SPLIT_SCAN_START)
    for idx in range(scan_from, _SPLIT_SCAN_STOP, -1):
        if text[idx] in _SPLIT_SEPARATORS:
            head, tail = text[:idx + 1], text[idx + 1:]
            if tail.strip():
                return [head, tail]
            # Separator sits at the (effective) end of the text: no split.
            break
    return [text]


def create_tts_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/speech`` router.

    Registers ``POST /speech`` (synthesize text, streamed or as one file) and
    ``GET /speech/voices`` (list voice names).

    Args:
        services: Container whose ``user_service`` is used to look up user
            preferences and whose ``tts_service`` supplies the default
            provider and the non-streaming synthesis call.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(prefix="/speech", tags=["TTS"])

    @router.post(
        "",
        summary="Generate speech from text",
        response_description="Audio bytes in WAV or PCM format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        ),
        as_wav: bool = Query(
            True,
            description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true."
        ),
        provider_name: str = Query(
            None,
            description="Optional session-level override for the TTS provider"
        ),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        # NOTE: documented with comments rather than a docstring on purpose;
        # FastAPI publishes a handler docstring as the OpenAPI description.
        #
        # Synthesizes `request.text`. With `stream` false the complete WAV from
        # `tts_service.create_speech_non_stream` is returned. With `stream`
        # true the first chunk is generated before the response starts, then
        # the remainder is fetched and yielded in order.
        # Any non-HTTP error raised before the response starts is converted to
        # a 500 with the error text in `detail`.
        try:
            provider_override = None
            if user_id:
                # Local imports, as in the rest of this module.
                from app.config import settings
                from app.core.providers.factory import get_tts_provider

                user = services.user_service.get_user_by_id(db=db, user_id=user_id)
                prefs = (user.preferences.get("tts") or {}) if user and user.preferences else {}
                # Precedence: query override > stored preference > global default.
                active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER
                active_prefs = (prefs.get("providers") or {}).get(active_provider) or {}
                # Everything except the explicitly mapped keys is forwarded as-is.
                extra_kwargs = {
                    k: v for k, v in active_prefs.items()
                    if k not in ("api_key", "model", "voice")
                }
                provider_override = get_tts_provider(
                    provider_name=active_provider,
                    api_key=active_prefs.get("api_key"),
                    model_name=active_prefs.get("model", ""),
                    voice_name=active_prefs.get("voice", ""),
                    **extra_kwargs
                )

            if stream:
                # Split into a first chunk for latency, then send the entire
                # rest in one piece for smoothness.
                chunks = _split_first_chunk(request.text)
                if not chunks[0].strip():
                    raise HTTPException(status_code=400, detail="No text to synthesize.")

                provider = provider_override or services.tts_service.default_tts_provider

                # Resolve the WAV wrapper up front so an import problem is
                # reported as a normal error response, not a broken stream.
                wrap_wav = None
                if as_wav:
                    from app.core.services.tts import _create_wav_file as wrap_wav

                # Pre-flight: generate the first chunk before streaming starts.
                # Once a StreamingResponse is under way a failure shows up in
                # the browser as a network error instead of a meaningful
                # error message.
                first_pcm = await provider.generate_speech(chunks[0])
                logger.info(f"TTS Stream started for session {user_id}. Initial chunk: {len(first_pcm)} bytes.")

                async def full_stream():
                    semaphore = asyncio.Semaphore(_MAX_CONCURRENT_CHUNKS)
                    tasks = []

                    async def fetch_chunk(text_chunk, idx):
                        # Returns the PCM bytes, or None when the chunk is
                        # blocked by the provider or still failing after the
                        # last retry (the chunk is then skipped).
                        retries = 3
                        delay = 1.0
                        for attempt in range(retries):
                            try:
                                async with semaphore:
                                    pcm_data = await provider.generate_speech(text_chunk)
                                    logger.debug(f"TTS Chunk {idx} generated: {len(pcm_data)} bytes.")
                                    return pcm_data
                            except Exception as e:
                                error_str = str(e)
                                # These messages are treated as a permanent
                                # refusal, so retrying is pointless.
                                if "No audio in response" in error_str or "finishReason" in error_str:
                                    logger.error(f"TTS Chunk {idx} blocked: {e}")
                                    return None

                                if attempt == retries - 1:
                                    logger.error(f"TTS Chunk {idx} failed after {retries} attempts: {e}")
                                    return None
                                # Exponential back-off, outside the semaphore.
                                await asyncio.sleep(delay)
                                delay *= 2
                        return None

                    try:
                        # Yield the already-generated first chunk.
                        yield wrap_wav(first_pcm) if wrap_wav else first_pcm

                        # Fetch the remaining chunks in parallel but yield
                        # them strictly in order.
                        tasks = [
                            asyncio.create_task(fetch_chunk(chunk, idx))
                            for idx, chunk in enumerate(chunks[1:], start=1)
                        ]
                        for idx, task in enumerate(tasks, start=1):
                            pcm = await task
                            if pcm:
                                yield wrap_wav(pcm) if wrap_wav else pcm
                                logger.debug(f"TTS Chunk {idx} yielded.")

                    except Exception as e:
                        logger.error(f"Runtime error in TTS stream: {e}")
                        raise
                    finally:
                        # If the stream is closed early (e.g. client went
                        # away), stop provider calls that are still running.
                        for task in tasks:
                            if not task.done():
                                task.cancel()
                        logger.info(f"TTS Stream finished for session {user_id}")

                media_type = "audio/wav" if as_wav else "application/octet-stream"
                return StreamingResponse(
                    full_stream(),
                    media_type=media_type,
                    headers={
                        "X-TTS-Chunk-Count": str(len(chunks)),
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive"
                    }
                )

            else:
                # The non-streaming function only returns WAV
                audio_bytes = await services.tts_service.create_speech_non_stream(
                    text=request.text,
                    provider_override=provider_override
                )
                return Response(content=audio_bytes, media_type="audio/wav")

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"TTS route error: {e}")
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    @router.get(
        "/voices",
        summary="List available TTS voices",
        response_description="A list of voice names"
    )
    async def list_voices(
        provider: str = Query(None, description="Optional provider name"),
        api_key: str = Query(None, description="Optional API key override"),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        # NOTE: comments instead of a docstring, so the OpenAPI description
        # stays unchanged.
        #
        # Returns a sorted list of voice names. Gemini voices come from the
        # provider's static list; otherwise the Google Cloud TTS voices
        # endpoint is queried, with static lists as fallback on errors.
        from app.config import settings
        import httpx
        from app.core.providers.tts.gemini import GeminiTTSProvider
        from app.core.providers.tts.gcloud_tts import GCloudTTSProvider

        # Resolve a masked key ("***") to the stored one, if there is one.
        key_to_use = api_key
        if key_to_use and "***" in key_to_use and user_id:
            user = services.user_service.get_user_by_id(db=db, user_id=user_id)
            if user and user.preferences:
                stored_providers = (user.preferences.get("tts") or {}).get("providers") or {}
                # Without an explicit provider any stored key qualifies; with
                # one, only that provider's key is used.
                for p_name, p_data in stored_providers.items():
                    if provider and provider != p_name:
                        continue
                    stored_key = (p_data or {}).get("api_key")
                    if stored_key and "***" not in stored_key:
                        key_to_use = stored_key
                        break

        # Fallback to defaults
        if not key_to_use or "***" in key_to_use:
            key_to_use = settings.TTS_API_KEY or settings.GEMINI_API_KEY

        # "AIza" is the prefix the original code treats as an AI Studio key.
        is_ai_studio_key = bool(key_to_use) and key_to_use.startswith("AIza")

        # If it's Gemini, or no provider was given and the key looks like an
        # AI Studio key, answer from the static Gemini list.
        if provider == "google_gemini" or (not provider and is_ai_studio_key):
            return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

        # Default or explicit GCloud
        if not key_to_use:
            return []

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                # Pass the key via `params` so it is URL-encoded correctly.
                res = await client.get(
                    "https://texttospeech.googleapis.com/v1/voices",
                    params={"key": key_to_use},
                )
                if res.status_code == 200:
                    voices = res.json().get('voices', [])
                    return sorted(v['name'] for v in voices)

                # If Google Cloud TTS fails, maybe it's actually an AI Studio
                # key being used for Gemini: fall back to the Gemini voices.
                if is_ai_studio_key:
                    return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

                return []
        except Exception as e:
            logger.error(f"Failed to fetch voices: {e}")
            # Final fallback to a static list if everything else fails.
            if is_ai_studio_key:
                return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
            return sorted(GCloudTTSProvider.AVAILABLE_VOICES_EN)

    return router