from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI
import litellm
def resolve_provider_info(name, section, registry, litellm_list=None):
    """
    Map a (possibly suffixed) provider instance name to its base provider type.

    An exact match against ``registry`` or ``litellm_list`` is returned as-is.
    Otherwise the name is split on underscores and its proper prefixes are
    tried from the longest to the shortest, so that multi-word types such as
    ``google_gemini`` win over shorter prefixes like ``google``. When nothing
    matches, ``name`` is returned unchanged.

    ``section`` is accepted for call-site symmetry but is not used here.

    Example: 'gemini_2' -> 'gemini'
    """
    def _is_known(candidate):
        # litellm_list may be None or empty; either way it contributes no names.
        if candidate in registry:
            return True
        return bool(litellm_list) and candidate in litellm_list

    if _is_known(name):
        return name
    if "_" not in name:
        return name

    pieces = name.split("_")
    # Proper prefixes, longest first: 'a_b_c' -> 'a_b', 'a'.
    prefixes = ("_".join(pieces[:cut]) for cut in reversed(range(1, len(pieces))))
    return next((prefix for prefix in prefixes if _is_known(prefix)), name)
# --- 1. Credentials ---
# Provider API keys and default model names are read from the central
# ``settings`` object below; no API clients are constructed at import time.

# --- 2. The Factory Dictionaries ---
# Default API key per built-in LLM provider name (evaluated once, at import time).
_llm_providers = {
    "deepseek": settings.DEEPSEEK_API_KEY,
    "gemini": settings.GEMINI_API_KEY,
}
# Default model name per built-in LLM provider name (evaluated once, at import time).
_llm_models = {
    "deepseek": settings.DEEPSEEK_MODEL_NAME,
    "gemini": settings.GEMINI_MODEL_NAME,
}
# Natively implemented TTS provider types -> provider class. Types not listed
# here are handled by the generic LiteLLM-backed GeneralTTSProvider.
_tts_registry = {
    "google_gemini": GeminiTTSProvider,
    "gcloud_tts": GCloudTTSProvider,
}
# Natively implemented STT provider types -> provider class. Types not listed
# here are handled by the generic LiteLLM-backed GeneralSTTProvider.
_stt_registry = {
    "google_gemini": GoogleSTTProvider,
}
def get_registered_tts_providers():
    """Return the natively supported TTS provider type names, in registration order."""
    return [*_tts_registry]
def get_registered_stt_providers():
    """Return the natively supported STT provider type names, in registration order."""
    names = []
    names.extend(_stt_registry)
    return names
# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> BaseLM:
    """Build a pre-configured LLM provider, with an optional system prompt.

    Args:
        provider_name: Configured provider name. It may be a suffixed instance
            name such as ``'gemini_2'``. In that case the base provider type is
            derived by stripping the suffix, unless ``provider_type`` is passed
            in ``kwargs``.
        model_name: Model identifier. If it already contains ``'/'`` it is used
            verbatim; otherwise it is prefixed with ``'<base_type>/'``. When
            empty, the default model configured for ``provider_name`` is used,
            falling back to the default model of its base provider type.
        system_prompt: Optional system prompt forwarded to ``GeneralProvider``.
        api_key_override: API key to use instead of the key configured for
            ``provider_name``.
        **kwargs: Forwarded unchanged to ``GeneralProvider``, including
            ``provider_type`` when given.

    Returns:
        A ``GeneralProvider`` instance.

    Raises:
        ValueError: If no model name is given and none is configured for either
            ``provider_name`` or its base type.
    """
    api_key = api_key_override or _llm_providers.get(provider_name)

    # An explicit provider_type wins. Only otherwise do we build the litellm
    # provider list and strip any instance suffix (e.g. 'gemini_2' -> 'gemini').
    base_type = kwargs.get("provider_type")
    if not base_type:
        litellm_providers = [p.value for p in litellm.LlmProviders]
        base_type = resolve_provider_info(provider_name, "llm", _llm_providers, litellm_providers)

    # Suffixed instances have no entry of their own in _llm_models, so fall
    # back to the default model of the base type they resolve to.
    resolved_model = model_name or _llm_models.get(provider_name) or _llm_models.get(base_type)
    if not resolved_model:
        raise ValueError(f"No model name provided for '{provider_name}'.")

    # LiteLLM expects '<provider>/<model>' unless the model is already qualified.
    full_model = resolved_model if "/" in resolved_model else f'{base_type}/{resolved_model}'
    return GeneralProvider(model_name=full_model, api_key=api_key, system_prompt=system_prompt, **kwargs)
def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """Build a text-to-speech provider for ``provider_name``.

    Args:
        provider_name: Configured provider name, possibly a suffixed instance
            name such as ``'google_gemini_2'``.
        api_key: API key for the provider. It may be missing, a
            ``"None"``/``"none"`` placeholder, or redacted (contains ``'*'``).
            In that case the first usable key is taken from, in order:
            ``settings.TTS_API_KEY``; ``settings.GEMINI_API_KEY`` for Gemini
            providers; the ``api_key`` configured under
            ``settings.LLM_PROVIDERS[provider_name]``. If none is usable,
            ``api_key`` is passed through unchanged.
        model_name: Model identifier. ``gcloud_tts`` does not receive it. For
            the generic LiteLLM fallback it is prefixed with
            ``'<base_type>/'`` unless it already contains ``'/'``.
        voice_name: Voice to synthesize with.
        **kwargs: Forwarded unchanged to the provider constructor. A
            ``provider_type`` entry overrides the type derived from
            ``provider_name``.

    Returns:
        An instance of the registered TTS class for the base type, or a
        ``GeneralTTSProvider`` when the type is not registered.
    """
    def is_masked(k):
        # Missing, placeholder ("None"/"none") and redacted ("***") keys are unusable.
        return not k or k in ("None", "none") or "*" in str(k)

    # Resolve the base technology type first, so suffixed instances
    # (e.g. 'google_gemini_2') get the same key fallbacks as the base type.
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "tts", _tts_registry)

    actual_key = api_key
    if is_masked(actual_key):
        if not is_masked(settings.TTS_API_KEY):
            actual_key = settings.TTS_API_KEY
        elif "google_gemini" in (provider_name, base_type) and not is_masked(settings.GEMINI_API_KEY):
            actual_key = settings.GEMINI_API_KEY
        elif provider_name in settings.LLM_PROVIDERS:
            llm_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
            if not is_masked(llm_key):
                actual_key = llm_key

    provider_cls = _tts_registry.get(base_type)
    if provider_cls:
        if base_type == "gcloud_tts":
            # GCloudTTSProvider is constructed without a model_name argument.
            return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs)
        return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs)

    # Not a natively supported type (all registered types returned above):
    # fall back to the generic LiteLLM implementation, which expects
    # '<provider>/<model>' unless the model is already qualified.
    full_model = model_name if "/" in model_name else f"{base_type}/{model_name}"
    return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs)
def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """Build a speech-to-text provider for ``provider_name``.

    Args:
        provider_name: Configured provider name, possibly a suffixed instance
            name such as ``'google_gemini_2'``.
        api_key: API key for the provider. It may be missing, a
            ``"None"``/``"none"`` placeholder, or redacted (contains ``'*'``).
            In that case the first usable key is taken from, in order:
            ``settings.STT_API_KEY``; ``settings.GEMINI_API_KEY`` for Gemini
            providers; the ``api_key`` configured under
            ``settings.LLM_PROVIDERS[provider_name]``. If none is usable,
            ``api_key`` is passed through unchanged.
        model_name: Model identifier. For the generic LiteLLM fallback it is
            prefixed with ``'<base_type>/'`` unless it already contains ``'/'``.
        **kwargs: Forwarded unchanged to the provider constructor. A
            ``provider_type`` entry overrides the type derived from
            ``provider_name``.

    Returns:
        An instance of the registered STT class for the base type, or a
        ``GeneralSTTProvider`` when the type is not registered.
    """
    def is_masked(k):
        # Missing, placeholder ("None"/"none") and redacted ("***") keys are unusable.
        return not k or k in ("None", "none") or "*" in str(k)

    # Resolve the base technology type first, so suffixed instances
    # (e.g. 'google_gemini_2') get the same key fallbacks as the base type.
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "stt", _stt_registry)

    actual_key = api_key
    if is_masked(actual_key):
        if not is_masked(settings.STT_API_KEY):
            actual_key = settings.STT_API_KEY
        elif "google_gemini" in (provider_name, base_type) and not is_masked(settings.GEMINI_API_KEY):
            actual_key = settings.GEMINI_API_KEY
        elif provider_name in settings.LLM_PROVIDERS:
            llm_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
            if not is_masked(llm_key):
                actual_key = llm_key

    provider_cls = _stt_registry.get(base_type)
    if provider_cls:
        return provider_cls(api_key=actual_key, model_name=model_name, **kwargs)

    # Not a natively supported type (all registered types returned above):
    # fall back to the generic LiteLLM implementation, which expects
    # '<provider>/<model>' unless the model is already qualified.
    full_model = model_name if "/" in model_name else f"{base_type}/{model_name}"
    return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs)