import os
import json
import httpx
import base64
import logging
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, retry_if_exception
from app.core.providers.base import TTSProvider
from fastapi import HTTPException


# Configure logging
logger = logging.getLogger(__name__)


def is_retryable_exception(exception):
    """Decide whether *exception* is a transient failure worth retrying.

    Transient failures are httpx timeouts and network errors, plus
    ``HTTPException`` instances whose status is 429 (Too Many Requests) or
    any 5xx server error. Everything else is treated as permanent.
    """
    if isinstance(exception, (httpx.TimeoutException, httpx.NetworkError)):
        return True
    if not isinstance(exception, HTTPException):
        return False
    code = exception.status_code
    # Rate limiting and server-side failures are the only retryable HTTP statuses.
    return code == 429 or 500 <= code < 600

class GeminiTTSProvider(TTSProvider):
    """TTS provider using Gemini's audio ``responseModalities``.

    By default requests go to Google AI Studio (``generateContent`` on
    ``generativelanguage.googleapis.com``). Keys starting with ``"AQ."`` are
    routed to Vertex AI (``streamGenerateContent`` on the us-central1
    endpoint). In both cases the key travels only in the ``x-goog-api-key``
    header and is never embedded in the URL, so logging the endpoint cannot
    leak it.
    """

    # Prebuilt voice names usable as ``prebuiltVoiceConfig.voiceName``.
    AVAILABLE_VOICES = [
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
        "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus",
        "Iapetus", "Umbriel", "Algieba", "Despina", "Erinome",
        "Algenib", "Rasalgethi", "Laomedeia", "Achernar", "Alnilam",
        "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat"
    ]

    def __init__(self, api_key: str, model_name: str = "gemini-2.5-flash-preview-tts",
                 voice_name: str = "Kore", **kwargs):
        """Resolve the model id, pick the endpoint, and store the voice.

        Args:
            api_key: AI Studio key (``"AIza..."``) or Vertex key (``"AQ."``).
            model_name: Model id, optionally provider-prefixed
                (e.g. ``"gemini/<id>"`` or ``"vertex_ai/<id>"``). A few short
                aliases are normalised to ``gemini-2.5-flash-preview-tts``;
                falsy values fall back to that default.
            voice_name: Prebuilt voice name (see ``AVAILABLE_VOICES``).
            **kwargs: Accepted for interface compatibility; ignored.
        """
        raw_model = model_name or "gemini-2.5-flash-preview-tts"
        # Strip any provider prefix (e.g. "vertex_ai/model" or "gemini/model") -> keep only the model id.
        model_id = raw_model.split("/")[-1]
        # Normalise short names: "gemini-2-flash-tts" -> "gemini-2.5-flash-preview-tts".
        if model_id in ("gemini-2-flash-tts", "gemini-2.5-flash-tts", "flash-tts", "gemini-2.5-flash"):
            model_id = "gemini-2.5-flash-preview-tts"
            logger.info(f"Normalised model name to: {model_id}")

        # Route to Vertex AI only for keys starting with "AQ."; AI Studio keys
        # start with "AIza" and must use the generativelanguage endpoint.
        self.is_vertex = bool(api_key) and api_key.startswith("AQ.")

        if self.is_vertex:
            self.api_url = (
                "https://us-central1-aiplatform.googleapis.com/v1/publishers/google/"
                f"models/{model_id}:streamGenerateContent"
            )
        else:
            # Google AI Studio: v1beta is required for audio responseModalities.
            # The key is sent as a header, never as a "?key=" query parameter.
            self.api_url = (
                "https://generativelanguage.googleapis.com/v1beta/models/"
                f"{model_id}:generateContent"
            )

        self.api_key = api_key
        self.voice_name = voice_name
        self.model_name = model_id
        logger.debug(f"GeminiTTSProvider: model={self.model_name}, vertex={self.is_vertex}")
        logger.debug(f"  endpoint: {self.api_url[:80]}...")

    def _build_headers(self) -> dict:
        """Return request headers; the API key (when present) goes in ``x-goog-api-key``."""
        headers = {"Content-Type": "application/json"}
        # Both AI Studio and Vertex receive the key. A missing key is left out
        # rather than sent as None, which httpx would reject with a TypeError.
        if self.api_key:
            headers["x-goog-api-key"] = self.api_key
        return headers

    def _build_payload(self, text: str) -> dict:
        """Return the request body asking for AUDIO output in ``self.voice_name``."""
        return {
            "contents": [{"role": "user", "parts": [{"text": text}]}],
            "generationConfig": {
                "responseModalities": ["AUDIO"],
                "speechConfig": {
                    "voiceConfig": {
                        "prebuiltVoiceConfig": {
                            "voiceName": self.voice_name
                        }
                    }
                },
            },
        }

    @staticmethod
    def _error_message(response) -> str:
        """Best-effort extraction of ``error.message`` from an error response body.

        Handles a single JSON object or a JSON list of chunks (as a streaming
        endpoint may return). Falls back to the first 200 characters of the
        raw body.
        """
        fallback = response.text[:200]
        try:
            payload = response.json()
        except ValueError:
            return fallback
        if isinstance(payload, list) and payload:
            payload = payload[0]
        if isinstance(payload, dict):
            err = payload.get("error")
            if isinstance(err, dict) and err.get("message"):
                return str(err["message"])
        return fallback

    @staticmethod
    def _extract_audio(resp_data) -> bytes:
        """Concatenate the base64-decoded ``inlineData.data`` parts of a response.

        Accepts a single object (``generateContent``) or a list of chunks
        (``streamGenerateContent``). Only the first candidate of each chunk is
        read. Returns ``b""`` when no audio parts are present.
        """
        segments = resp_data if isinstance(resp_data, list) else [resp_data]
        fragments = []
        for segment in segments:
            # ``or`` guards against explicit JSON nulls as well as missing keys.
            candidates = segment.get("candidates") or []
            if not candidates:
                continue
            parts = (candidates[0].get("content") or {}).get("parts") or []
            for part in parts:
                data = (part.get("inlineData") or {}).get("data")
                if data:
                    fragments.append(base64.b64decode(data))
        return b"".join(fragments)

    @retry(
        # is_retryable_exception already covers httpx timeouts/network errors
        # as well as HTTPException with 429 or 5xx.
        retry=retry_if_exception(is_retryable_exception),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        reraise=True,
        before_sleep=lambda retry_state: logger.warning(
            f"Retrying Gemini TTS request (attempt {retry_state.attempt_number})..."
        ),
    )
    async def generate_speech(self, text: str) -> bytes:
        """Synthesize *text* and return the decoded audio bytes.

        The audio is the concatenation of every ``inlineData`` part in the
        response. It is logged as PCM; presumably raw PCM as returned by the
        API, so confirm against the part's ``mimeType`` if that matters.

        Up to 3 attempts are made for transient failures (timeouts, network
        errors, HTTP 429/5xx, and the 500s raised below).

        Raises:
            HTTPException: 400 for empty text; the upstream status code on a
                non-200 response; 500 when the response contains no audio or
                on any unexpected failure.
            httpx.TimeoutException, httpx.NetworkError: once retries are exhausted.
        """
        if not text or not text.strip():
            # Fail fast without a network round-trip; 400 is not retried.
            raise HTTPException(status_code=400, detail="Gemini TTS error: text must not be empty")

        # Avoid dumping arbitrarily long (and possibly sensitive) user text at INFO.
        preview = text if len(text) <= 100 else text[:100] + "..."
        logger.info(
            f"TTS request [model={self.model_name}, vertex={self.is_vertex}, "
            f"chars={len(text)}]: '{preview}'"
        )

        headers = self._build_headers()
        payload = self._build_payload(text)
        # Safe to log: the URL never contains the API key.
        logger.debug(f"Calling: {self.api_url}")

        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
                response = await client.post(self.api_url, headers=headers, json=payload)

            logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                logger.error(f"TTS API error {response.status_code}: {response.text[:300]}")
                # 429/5xx are retried by the decorator (via is_retryable_exception);
                # any other status propagates to the caller immediately.
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Gemini TTS error: {self._error_message(response)}",
                )

            resp_data = response.json()
            result = self._extract_audio(resp_data)

            if not result:
                logger.error(f"No audio in response. Full response: {json.dumps(resp_data)[:500]}")
                raise HTTPException(status_code=500, detail="No audio data in Gemini TTS response.")

            logger.debug(f"TTS returned {len(result)} PCM bytes")
            return result

        except (httpx.TimeoutException, httpx.NetworkError) as e:
            logger.error(f"Gemini TTS request failed ({type(e).__name__}): {e}")
            raise  # retried by the decorator
        except HTTPException:
            raise  # already carries the right status; retried only if 429/5xx
        except Exception as e:
            logger.exception(f"Unexpected TTS error: {type(e).__name__}: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {e}") from e