import asyncio
import logging
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException, Query, Response, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from app.api.dependencies import ServiceContainer, get_db
from app.api.routes.user import get_current_user_id
from app.api import schemas

logger = logging.getLogger(__name__)

# Characters after which the streaming path may cut the text into chunks.
_SENTENCE_SEPARATORS = frozenset(".?!\n。？！，")

# The first streamed chunk ends at the last separator found at a character
# index in (_FIRST_CHUNK_MIN_INDEX, _FIRST_CHUNK_MAX_INDEX]; texts with no such
# separator are synthesized as a single chunk.
_FIRST_CHUNK_MIN_INDEX = 50
_FIRST_CHUNK_MAX_INDEX = 400

# Preference keys consumed explicitly when building a provider (either read
# into, or colliding with, get_tts_provider's named arguments). They are never
# forwarded through **kwargs, which would otherwise raise a
# "got multiple values for keyword argument" TypeError.
_EXPLICIT_PROVIDER_KEYS = frozenset(
    {"api_key", "model", "voice", "provider_name", "model_name", "voice_name"}
)

# Substring identifying an API key the client only knows in masked form.
_MASKED_KEY_MARKER = "***"

_GCLOUD_VOICES_URL = "https://texttospeech.googleapis.com/v1/voices"


def _resolve_provider_override(services, db, user_id, provider_name):
    """Build a TTS provider from the current user's saved preferences.

    Returns None when there is no user id; callers then use the service's
    default provider. The provider is chosen by precedence: the explicit
    *provider_name* override, then the user's ``tts.active_provider``
    preference, then ``settings.TTS_PROVIDER``.
    """
    if not user_id:
        return None

    # Kept as function-level imports like the original endpoint; presumably to
    # avoid import cycles at module load (unverified).
    from app.config import settings
    from app.core.providers.factory import get_tts_provider

    user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    # `or {}` guards against stored None values, not just missing keys.
    prefs = (user.preferences.get("tts") or {}) if user and user.preferences else {}
    active_provider = provider_name or prefs.get("active_provider") or settings.TTS_PROVIDER
    active_prefs = (prefs.get("providers") or {}).get(active_provider) or {}
    extra_kwargs = {
        key: value
        for key, value in active_prefs.items()
        if key not in _EXPLICIT_PROVIDER_KEYS
    }
    return get_tts_provider(
        provider_name=active_provider,
        api_key=active_prefs.get("api_key"),
        model_name=active_prefs.get("model", ""),
        voice_name=active_prefs.get("voice", ""),
        **extra_kwargs,
    )


def _split_for_streaming(text: str) -> List[str]:
    """Split *text* into a quick-to-synthesize first chunk and the remainder.

    Scans backwards from index ``min(len(text) - 1, 400)`` down to index 51
    for the last separator character and cuts just after it. Returns
    ``[text]`` unchanged when there is no separator in that window, when the
    separator is the final character, or when everything after it is
    whitespace. Synthesizing a blank remainder would only waste a provider
    call (and its retries) at the end of the stream.
    """
    last_index = min(len(text) - 1, _FIRST_CHUNK_MAX_INDEX)
    for index in range(last_index, _FIRST_CHUNK_MIN_INDEX, -1):
        if text[index] in _SENTENCE_SEPARATORS:
            head, tail = text[: index + 1], text[index + 1 :]
            if tail.strip():
                return [head, tail]
            break
    return [text]


async def _fetch_chunk_with_retry(provider, text_chunk, semaphore, retries=3, initial_delay=1.0):
    """Synthesize one chunk, retrying failures with exponential backoff.

    Returns the provider's audio, or None when the provider blocked the chunk
    or every attempt failed. The stream skips such chunks instead of aborting.
    """
    delay = initial_delay
    for attempt in range(retries):
        try:
            async with semaphore:
                return await provider.generate_speech(text_chunk)
        except Exception as e:
            error_str = str(e)
            # These markers are treated as a provider refusal (formatting /
            # safety), so retrying is pointless.
            if "No audio in response" in error_str or "finishReason" in error_str:
                logger.error("TTS chunk blocked by provider formatting/safety: %s", e)
                return None
            if attempt == retries - 1:
                logger.error("TTS chunk failed after %s attempts: %s", retries, e)
                return None
            # Back off outside the semaphore so other chunks can proceed.
            await asyncio.sleep(delay)
            delay *= 2
    return None


async def _stream_audio(
    provider,
    first_pcm: bytes,
    remaining_chunks: List[str],
    as_wav: bool,
) -> AsyncGenerator[bytes, None]:
    """Yield the pre-generated first chunk, then each remaining chunk in order.

    When *as_wav* is true, every chunk is wrapped as its own WAV file;
    otherwise the provider output is yielded unchanged. Chunks whose
    synthesis failed are skipped.
    """
    if as_wav:
        from app.core.services.tts import _create_wav_file as encode
    else:
        def encode(pcm):
            return pcm

    semaphore = asyncio.Semaphore(3)  # limit concurrent external requests
    # Start the remaining chunks now so their synthesis overlaps with the
    # delivery of the first chunk.
    tasks = [
        asyncio.create_task(_fetch_chunk_with_retry(provider, chunk, semaphore))
        for chunk in remaining_chunks
    ]
    try:
        yield encode(first_pcm)
        # Await in submission order so audio is emitted in text order.
        for task in tasks:
            pcm = await task
            if pcm:
                yield encode(pcm)
    finally:
        # If the client disconnects, the generator is closed early; cancel any
        # outstanding synthesis so it stops calling the provider.
        for task in tasks:
            task.cancel()


def create_tts_router(services: ServiceContainer) -> APIRouter:
    """Build the ``/speech`` router.

    Routes:
        POST /speech         -- synthesize text to audio (complete or streamed).
        GET  /speech/voices  -- list voice names for a TTS provider.

    Args:
        services: container supplying ``user_service`` and ``tts_service``.
    """
    router = APIRouter(prefix="/speech", tags=["TTS"])

    @router.post(
        "",
        summary="Generate speech from text",
        response_description="Audio bytes in WAV or PCM format, either as a complete file or a stream.",
    )
    async def create_speech_response(
        request: schemas.SpeechRequest,
        stream: bool = Query(
            False,
            description="If true, returns a streamed audio response. Otherwise, returns a complete file."
        ),
        as_wav: bool = Query(
            True,
            description="If true, returns WAV format audio. If false, returns raw PCM audio data. Only applies when stream is true."
        ),
        provider_name: Optional[str] = Query(
            None,
            description="Optional session-level override for the TTS provider"
        ),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Synthesize the request text to audio.

        In streaming mode the first chunk is synthesized before the response
        starts, so provider errors surface as a proper HTTP error.
        """
        try:
            provider_override = _resolve_provider_override(services, db, user_id, provider_name)

            if not stream:
                # The non-streaming service path always returns WAV; as_wav does not apply.
                audio_bytes = await services.tts_service.create_speech_non_stream(
                    text=request.text,
                    provider_override=provider_override,
                )
                return Response(content=audio_bytes, media_type="audio/wav")

            # Split into a short first chunk for latency, then the entire rest for smoothness.
            chunks = _split_for_streaming(request.text)
            if not chunks[0].strip():
                raise HTTPException(status_code=400, detail="No text to synthesize.")

            provider = provider_override or services.tts_service.default_tts_provider
            # Pre-flight: once a StreamingResponse has started, a failure shows
            # up in the browser as a network error rather than a meaningful
            # message. Generating the first chunk here surfaces errors cleanly.
            first_pcm = await provider.generate_speech(chunks[0])

            return StreamingResponse(
                _stream_audio(provider, first_pcm, chunks[1:], as_wav),
                media_type="audio/wav" if as_wav else "audio/pcm",
                headers={"X-TTS-Chunk-Count": str(len(chunks))},
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Speech generation failed")
            # NOTE(review): the raw exception text is returned to the client and
            # may expose provider internals; kept for backward compatibility.
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {e}"
            ) from e

    @router.get(
        "/voices",
        summary="List available TTS voices",
        response_description="A list of voice names"
    )
    async def list_voices(
        provider: Optional[str] = Query(None, description="Optional provider name"),
        api_key: Optional[str] = Query(None, description="Optional API key override"),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Return a sorted list of voice names for the requested provider."""
        from app.config import settings
        import httpx
        from app.core.providers.tts.gemini import GeminiTTSProvider
        from app.core.providers.tts.gcloud_tts import GCloudTTSProvider

        key_to_use = api_key
        if key_to_use and _MASKED_KEY_MARKER in key_to_use and user_id:
            # The client only has a masked key; substitute a stored, unmasked
            # one. Restrict to the requested provider's key when one is given.
            user = services.user_service.get_user_by_id(db=db, user_id=user_id)
            if user and user.preferences:
                stored_providers = (user.preferences.get("tts") or {}).get("providers") or {}
                for p_name, p_data in stored_providers.items():
                    stored_key = (p_data or {}).get("api_key")
                    if (
                        stored_key
                        and _MASKED_KEY_MARKER not in stored_key
                        and (not provider or provider == p_name)
                    ):
                        key_to_use = stored_key
                        break

        # Fall back to the server-configured keys.
        if not key_to_use or _MASKED_KEY_MARKER in key_to_use:
            key_to_use = settings.TTS_API_KEY or settings.GEMINI_API_KEY

        # Heuristic: keys starting with "AIza" are treated as likely AI Studio (Gemini) keys.
        is_ai_studio_key = bool(key_to_use) and key_to_use.startswith("AIza")

        if provider == "google_gemini" or (not provider and is_ai_studio_key):
            return sorted(GeminiTTSProvider.AVAILABLE_VOICES)

        # Default or explicit Google Cloud TTS: needs a key to query the API.
        if not key_to_use:
            return []

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                # Pass the key via params so httpx URL-encodes it instead of
                # interpolating it raw into the URL string.
                res = await client.get(_GCLOUD_VOICES_URL, params={"key": key_to_use})
                if res.status_code == 200:
                    voices = res.json().get("voices", [])
                    return sorted(v["name"] for v in voices if "name" in v)

                logger.warning(
                    "Google Cloud TTS voice listing failed with HTTP %s", res.status_code
                )
                # A Cloud TTS failure may mean an AI Studio key was supplied for Gemini.
                if is_ai_studio_key:
                    return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
                return []
        except Exception as e:
            logger.error("Failed to fetch voices: %s", e)
            # Best-effort fallback to a static voice list.
            if is_ai_studio_key:
                return sorted(GeminiTTSProvider.AVAILABLE_VOICES)
            return sorted(GCloudTTSProvider.AVAILABLE_VOICES_EN)

    return router