import os
import logging
from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from openai import AsyncOpenAI
# Optional dependency: litellm powers the generic fallback LLM/TTS/STT
# providers. A missing install is tolerated here and surfaced later with a
# clear error (see get_llm_provider) instead of crashing at import time.
LITELLM_AVAILABLE = True
try:
    import litellm
except ImportError:
    LITELLM_AVAILABLE = False
def resolve_provider_info(name, section, registry, litellm_list=None):
    """
    Resolve the base provider type for a potentially suffixed instance name.

    Example: 'gemini_2' -> 'gemini'. The litellm beta alias 'vertex_ai_beta'
    is normalized to 'vertex_ai'. Returns *name* unchanged when nothing
    matches.
    """
    extra = litellm_list or []

    # Exact registry hit wins and is returned verbatim.
    if name in registry:
        return name

    # Exact litellm hit, normalizing the vertex_ai beta alias.
    if name in extra:
        return "vertex_ai" if name == "vertex_ai_beta" else name

    # Suffixed instance: probe underscore-joined prefixes, longest first,
    # so compound types like 'google_gemini' win over plain 'google'.
    segments = name.split("_")
    for count in range(len(segments) - 1, 0, -1):
        candidate = "_".join(segments[:count])
        if candidate in registry or candidate in extra:
            return "vertex_ai" if candidate == "vertex_ai_beta" else candidate

    return name
# --- 1. Initialize API Clients from Central Config ---
# --- 2. The Factory Dictionaries ---
# Base LLM provider types resolved natively (before deferring to litellm's
# own provider list in resolve_provider_info).
_llm_providers = ["gemini", "deepseek", "openai", "anthropic", "vertex_ai", "google_gemini"]
# Base TTS provider type -> concrete provider class. Anything not listed
# here falls back to the generic litellm-backed GeneralTTSProvider.
_tts_registry = {
    "google_gemini": GeminiTTSProvider,
    "gcloud_tts": GCloudTTSProvider
}
# Base STT provider type -> concrete provider class. Anything not listed
# here falls back to the generic litellm-backed GeneralSTTProvider.
_stt_registry = {
    "google_gemini": GoogleSTTProvider
}
def get_registered_tts_providers():
    """Return the names of all natively registered TTS provider types."""
    return [provider for provider in _tts_registry]
def get_registered_stt_providers():
    """Return the names of all natively registered STT provider types."""
    return [provider for provider in _stt_registry]
# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> GeneralProvider:
    """
    Factory for a pre-configured LLM provider, with optional system prompt.

    API key resolution order: explicit override -> hot-loaded
    settings.LLM_PROVIDERS entry -> '{PROVIDER}_API_KEY' environment variable.
    Model resolution order: explicit model_name -> 'provider/model' split ->
    'provider_model-id' suffix heuristic.

    Raises:
        ValueError: when no model name can be resolved.
        ImportError: when litellm is not installed.
    """
    logging.info(f"[Factory] Resolving: {provider_name} / {model_name}")

    def is_empty(key):
        # Treat None, empty, literal "None"/"none" and masked keys ("sk-***") as unset.
        return not key or key in ("None", "none", "") or "*" in str(key)

    # Base provider for key lookups: strip any embedded '/model' suffix.
    base_provider_for_keys = provider_name.split("/")[0] if "/" in provider_name else provider_name

    # 1. Resolve the provider API key.
    provider_key = api_key_override
    if is_empty(provider_key):
        # Check the LLM_PROVIDERS dict first (hot-loaded via admin)...
        provider_key = settings.LLM_PROVIDERS.get(base_provider_for_keys, {}).get("api_key")
        if is_empty(provider_key):
            # ...then fall back to the conventional {PROVIDER}_API_KEY env var.
            provider_key = os.getenv(f"{base_provider_for_keys.upper()}_API_KEY")

    # 2. Resolve the model name (always runs, regardless of key resolution path).
    resolved_model = model_name
    if not resolved_model and "/" in provider_name:
        # Priority 1: model embedded in the name, e.g. "gemini/gemini-1.5-flash".
        resolved_model = provider_name.split("/", 1)[1]
    if not resolved_model and "_" in provider_name:
        # Priority 2: suffixed form, e.g. "gemini_gemini-1.5-flash".
        parts = provider_name.split("_")
        if len(parts) > 1 and parts[0] in ["gemini", "openai", "deepseek", "anthropic"]:
            potential_model = "_".join(parts[1:])
            # Only accept suffixes that plausibly look like model ids.
            if "flash" in potential_model or "gpt" in potential_model or "chat" in potential_model:
                resolved_model = potential_model
    if not resolved_model:
        # Final safety check: without a model we cannot initialize anything.
        raise ValueError(f"Could not resolve model name for provider '{provider_name}'. The caller must provide a valid model string.")

    if not LITELLM_AVAILABLE:
        raise ImportError("LiteLLM is not installed or failed to load. Please check your requirements.txt.")

    # Extract base provider type (e.g. 'gemini_2' -> 'gemini').
    litellm_providers = [p.value for p in litellm.LlmProviders]
    base_type = kwargs.get("provider_type") or resolve_provider_info(base_provider_for_keys, "llm", _llm_providers, litellm_providers)

    # Prevent doubling like 'gemini/gemini/gemini-1.5-flash'.
    full_model = resolved_model if "/" in resolved_model else f"{base_type}/{resolved_model}"
    return GeneralProvider(model_name=full_model, api_key=provider_key, system_prompt=system_prompt, **kwargs)
def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """Factory for a configured TTS provider; unknown types use the litellm fallback."""
    def is_masked(candidate):
        # Unset, literal "None"/"none", or masked ("sk-***") keys are unusable.
        return not candidate or candidate in ("None", "none", "") or "*" in str(candidate)

    actual_key = api_key
    if is_masked(actual_key) and provider_name in settings.LLM_PROVIDERS:
        # Share the LLM configuration's key when it holds a usable one.
        shared_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        if not is_masked(shared_key):
            actual_key = shared_key

    # Resolve the base technology type (e.g. 'google_gemini_2' -> 'google_gemini').
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "tts", _tts_registry)

    provider_cls = _tts_registry.get(base_type)
    if provider_cls is not None:
        if base_type == "gcloud_tts":
            # Google Cloud TTS is voice-driven and takes no model name.
            return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs)
        return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs)

    # Fallback: generic LiteLLM implementation, prefixing the base type
    # unless the model already carries a provider prefix.
    full_model = model_name
    if "/" not in full_model and base_type not in ["google_gemini", "gcloud_tts"]:
        full_model = f"{base_type}/{model_name}"
    return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs)
def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """Factory for a configured STT provider; unknown types use the litellm fallback."""
    def is_masked(candidate):
        # Unset, literal "None"/"none", or masked ("sk-***") keys are unusable.
        return not candidate or candidate in ("None", "none", "") or "*" in str(candidate)

    actual_key = api_key
    if is_masked(actual_key) and provider_name in settings.LLM_PROVIDERS:
        # Share the LLM configuration's key when it holds a usable one.
        shared_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        if not is_masked(shared_key):
            actual_key = shared_key

    # Resolve the base technology type (e.g. 'google_gemini_2' -> 'google_gemini').
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "stt", _stt_registry)

    provider_cls = _stt_registry.get(base_type)
    if provider_cls is not None:
        return provider_cls(api_key=actual_key, model_name=model_name, **kwargs)

    # Fallback: generic LiteLLM implementation, prefixing the base type
    # unless the model already carries a provider prefix.
    full_model = model_name
    if "/" not in full_model and base_type not in ["google_gemini"]:
        full_model = f"{base_type}/{model_name}"
    return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs)
def get_model_limit(provider_name: str, model_name: str = None) -> int:
    """
    Get the token limit (context window) for a provider/model using LiteLLM.

    Used for UI progress bars and validation. Returns 10000 when litellm is
    unavailable or does not know the model, and 32000 when litellm knows the
    model but reports no token counts.

    Raises:
        ValueError: if provider_name is empty, or no model name is
            configured for the provider.
    """
    if not provider_name:
        raise ValueError("provider_name must be a valid string to retrieve model limits.")

    # Graceful degradation: without litellm we can only return the
    # conservative default this function already falls back to.
    if not LITELLM_AVAILABLE:
        return 10000

    base_provider_for_keys = provider_name.split("/")[0] if "/" in provider_name else provider_name

    # 1. Resolve the model name: explicit argument, then configured default.
    resolved_model = model_name or settings.LLM_PROVIDERS.get(base_provider_for_keys, {}).get("model")
    if not resolved_model:
        raise ValueError(f"Model name not configured for provider '{provider_name}'. Please go to Settings > LLM Providers in the UI to set a default model for this provider.")

    # 2. Resolve the base type (e.g. 'gemini_2' -> 'gemini').
    litellm_providers = [p.value for p in litellm.LlmProviders]
    base_type = resolve_provider_info(base_provider_for_keys, "llm", _llm_providers, litellm_providers)

    # Prevent doubling like 'gemini/gemini'. resolved_model is guaranteed
    # non-empty here (we raised above), so no further fallback is needed.
    full_model = resolved_model if "/" in resolved_model else f"{base_type}/{resolved_model}"

    try:
        info = litellm.get_model_info(full_model)
        if info:
            # Prefer max_input_tokens as it represents the context window;
            # fall back to max_tokens, then a safe 32000 default.
            return info.get("max_input_tokens") or info.get("max_tokens") or 32000
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # are not swallowed; unknown models drop to the default below.
        pass
    return 10000