import os
import logging
from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from openai import AsyncOpenAI

# Optional dependency: LiteLLM powers the generic provider fallbacks and
# model-info lookups. The import is guarded so this module still loads when
# litellm is absent; callers must check LITELLM_AVAILABLE before using it.
try:
    import litellm
    LITELLM_AVAILABLE = True
except ImportError:
    LITELLM_AVAILABLE = False


def resolve_provider_info(name, section, registry, litellm_list=None):
    """
    Map a possibly suffixed instance name back to its base provider type.

    Exact registry matches are returned untouched; exact LiteLLM matches are
    normalized (the 'vertex_ai_beta' alias collapses to 'vertex_ai'). Otherwise
    the longest underscore-joined prefix known to either source wins, e.g.
    'gemini_2' -> 'gemini'. Unresolvable names pass through unchanged.
    """
    def _canonical(candidate):
        # LiteLLM exposes a beta alias for Vertex AI; fold it into the base id.
        return "vertex_ai" if candidate == "vertex_ai_beta" else candidate

    if name in registry:
        return name
    if litellm_list and name in litellm_list:
        return _canonical(name)

    # Strip trailing '_suffix' segments one at a time, longest prefix first
    # (important for multi-part types like 'google_gemini').
    prefix = name
    while "_" in prefix:
        prefix = prefix.rsplit("_", 1)[0]
        if prefix in registry or (litellm_list and prefix in litellm_list):
            return _canonical(prefix)

    return name

# --- 1. Client Initialization ---
# NOTE(review): no clients are created eagerly here; API keys and clients are
# resolved lazily inside the factory functions below.

# --- 2. The Factory Dictionaries ---
# Base LLM provider types this factory resolves directly; suffixed instance
# names (e.g. 'gemini_2') are mapped back onto these via resolve_provider_info.
_llm_providers = ["gemini", "deepseek", "openai", "anthropic", "vertex_ai", "google_gemini"]

# TTS: base provider type -> dedicated provider class. Types missing here fall
# back to the generic LiteLLM-backed GeneralTTSProvider in get_tts_provider.
_tts_registry = {
    "google_gemini": GeminiTTSProvider,
    "gcloud_tts": GCloudTTSProvider
}

# STT: base provider type -> dedicated provider class; same fallback rule
# (GeneralSTTProvider) applies in get_stt_provider.
_stt_registry = {
    "google_gemini": GoogleSTTProvider
}

def get_registered_tts_providers():
    """Return the provider-type names that have a dedicated TTS implementation."""
    return [name for name in _tts_registry]

def get_registered_stt_providers():
    """Return the provider-type names that have a dedicated STT implementation."""
    return [name for name in _stt_registry]

# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> GeneralProvider:
    """
    Factory for pre-configured LLM providers.

    Resolution order:
      * API key: explicit override -> settings.LLM_PROVIDERS (hot-loaded via
        admin) -> ``{PROVIDER}_API_KEY`` environment variable.
      * Model: explicit ``model_name`` -> model embedded in a
        'provider/model' name -> heuristic split of a suffixed name like
        'gemini_gemini-1.5-flash'.
      * Base type: ``kwargs['provider_type']`` override -> resolve_provider_info.

    Args:
        provider_name: Provider id, optionally 'provider/model' or a suffixed
            instance name.
        model_name: Explicit model override; takes precedence when non-empty.
        system_prompt: Optional system prompt forwarded to the provider.
        api_key_override: Used verbatim unless empty/masked.
        **kwargs: Extra provider options; 'provider_type' skips resolution.

    Raises:
        ValueError: If no model name can be resolved.
        ImportError: If LiteLLM is unavailable.
    """
    logging.info(f"[Factory] Resolving: {provider_name} / {model_name}")

    def _is_empty(key):
        # Empty, placeholder, or masked (contains '*') keys are unusable.
        return not key or key in ("None", "none", "") or "*" in str(key)

    # Base provider id for key lookups (strip any '/model' suffix).
    base_provider = provider_name.split("/")[0] if "/" in provider_name else provider_name

    # 1. Resolve API key: override -> hot-loaded settings -> environment.
    provider_key = api_key_override
    if _is_empty(provider_key):
        provider_key = settings.LLM_PROVIDERS.get(base_provider, {}).get("api_key")
        if _is_empty(provider_key):
            # Dynamic ENV check: {PROVIDER}_API_KEY
            provider_key = os.getenv(f"{base_provider.upper()}_API_KEY")

    # 2. Resolve model name (always runs, regardless of key resolution path).
    resolved_model = model_name
    if not resolved_model and "/" in provider_name:
        # 'gemini/gemini-1.5-flash' style: model embedded after the slash.
        resolved_model = provider_name.split("/", 1)[1]
    if not resolved_model and "_" in provider_name:
        # 'gemini_gemini-1.5-flash' style: heuristic underscore split, accepted
        # only when the tail looks like a real model id.
        parts = provider_name.split("_")
        if len(parts) > 1 and parts[0] in ["gemini", "openai", "deepseek", "anthropic"]:
            candidate = "_".join(parts[1:])
            if "flash" in candidate or "gpt" in candidate or "chat" in candidate:
                resolved_model = candidate

    if not resolved_model:
        # Final safety check: without a model we cannot initialize a provider.
        raise ValueError(f"Could not resolve model name for provider '{provider_name}'. The caller must provide a valid model string.")

    if not LITELLM_AVAILABLE:
        raise ImportError("LiteLLM is not installed or failed to load. Please check your requirements.txt.")

    # 3. Resolve base type (e.g. 'gemini_2' -> 'gemini') unless supplied.
    litellm_providers = [p.value for p in litellm.LlmProviders]
    base_type = kwargs.get("provider_type") or resolve_provider_info(base_provider, "llm", _llm_providers, litellm_providers)

    # Prevent doubling like 'gemini/gemini/gemini-1.5-flash'.
    full_model = resolved_model if "/" in resolved_model else f"{base_type}/{resolved_model}"

    return GeneralProvider(model_name=full_model, api_key=provider_key, system_prompt=system_prompt, **kwargs)

def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """
    Factory for TTS providers: resolves a usable API key (falling back to the
    matching LLM provider's key), dispatches to a registered class, or falls
    back to the generic LiteLLM implementation.
    """
    def _unusable(key):
        # None, placeholder strings, or masked values (containing '*').
        return not key or key in ("None", "none", "") or "*" in str(key)

    resolved_key = api_key
    if _unusable(resolved_key) and provider_name in settings.LLM_PROVIDERS:
        # Reuse the key configured for the matching LLM provider when usable.
        shared_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        if not _unusable(shared_key):
            resolved_key = shared_key

    # Resolve base technology type ('provider_type' kwarg wins).
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "tts", _tts_registry)

    registered_cls = _tts_registry.get(base_type)
    if registered_cls is not None:
        if base_type == "gcloud_tts":
            # Google Cloud TTS is configured by voice only, not model.
            return registered_cls(api_key=resolved_key, voice_name=voice_name, **kwargs)
        return registered_cls(api_key=resolved_key, model_name=model_name, voice_name=voice_name, **kwargs)

    # Unregistered type: qualify the model for LiteLLM unless already prefixed.
    qualified_model = model_name
    if "/" not in qualified_model and base_type not in ("google_gemini", "gcloud_tts"):
        qualified_model = f"{base_type}/{model_name}"

    return GeneralTTSProvider(model_name=qualified_model, api_key=resolved_key, voice_name=voice_name, **kwargs)

def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """
    Factory for STT providers: resolves a usable API key (falling back to the
    matching LLM provider's key), dispatches to a registered class, or falls
    back to the generic LiteLLM implementation.
    """
    def _unusable(key):
        # None, placeholder strings, or masked values (containing '*').
        return not key or key in ("None", "none", "") or "*" in str(key)

    resolved_key = api_key
    if _unusable(resolved_key) and provider_name in settings.LLM_PROVIDERS:
        # Reuse the key configured for the matching LLM provider when usable.
        shared_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        if not _unusable(shared_key):
            resolved_key = shared_key

    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "stt", _stt_registry)

    registered_cls = _stt_registry.get(base_type)
    if registered_cls is not None:
        return registered_cls(api_key=resolved_key, model_name=model_name, **kwargs)

    # Unregistered type: qualify the model for LiteLLM unless already prefixed.
    qualified_model = model_name
    if "/" not in qualified_model and base_type not in ("google_gemini",):
        qualified_model = f"{base_type}/{model_name}"

    return GeneralSTTProvider(model_name=qualified_model, api_key=resolved_key, **kwargs)

def get_model_limit(provider_name: str, model_name: str = None) -> int:
    """
    Get the token limit (context window) for a provider/model using LiteLLM.

    Used for UI progress bars and validation.

    Args:
        provider_name: Provider id, optionally 'provider/model'.
        model_name: Explicit model; falls back to the provider's configured
            default in settings.LLM_PROVIDERS.

    Returns:
        The model's max input tokens, falling back to max_tokens, then 32000
        when LiteLLM returns info without token counts, or 10000 when the
        lookup fails entirely.

    Raises:
        ValueError: If provider_name is empty or no model can be determined.
        ImportError: If LiteLLM is unavailable.
    """
    if not provider_name:
        raise ValueError("provider_name must be a valid string to retrieve model limits.")
    if not LITELLM_AVAILABLE:
        # Consistent with get_llm_provider: fail loudly instead of NameError.
        raise ImportError("LiteLLM is not installed or failed to load. Please check your requirements.txt.")

    base_provider = provider_name.split("/")[0] if "/" in provider_name else provider_name

    # 1. Resolve model name: explicit arg -> configured provider default.
    resolved_model = model_name or settings.LLM_PROVIDERS.get(base_provider, {}).get("model")
    if not resolved_model:
        raise ValueError(f"Model name not configured for provider '{provider_name}'. Please go to Settings > LLM Providers in the UI to set a default model for this provider.")

    # 2. Resolve base type (e.g. 'gemini_2' -> 'gemini').
    litellm_providers = [p.value for p in litellm.LlmProviders]
    base_type = resolve_provider_info(base_provider, "llm", _llm_providers, litellm_providers)

    # Prevent doubling like 'gemini/gemini'. resolved_model is guaranteed
    # non-empty here, so no further fallback is needed.
    full_model = resolved_model if "/" in resolved_model else f"{base_type}/{resolved_model}"

    try:
        info = litellm.get_model_info(full_model)
    except Exception:
        # Narrowed from a bare except: unknown models are expected, but
        # KeyboardInterrupt/SystemExit must not be swallowed.
        logging.warning(f"[Factory] Could not fetch model info for '{full_model}'; using default limit.")
        info = None

    if info:
        # Prefer max_input_tokens (the context window); fall back to
        # max_tokens, then a sane default if LiteLLM returned empty values.
        return info.get("max_input_tokens") or info.get("max_tokens") or 32000

    return 10000