from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI
import litellm
# --- 1. API Clients from Central Config (disabled: the client set-up below is commented out) ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
# --- 2. The Factory Dictionaries ---
# Default API key per LLM provider, read from settings once at import time.
# NOTE: despite its name, this maps provider name -> API key (not a provider object).
_llm_providers = dict(
    deepseek=settings.DEEPSEEK_API_KEY,
    gemini=settings.GEMINI_API_KEY,
)

# Default model name per LLM provider; get_llm_provider uses it when the
# caller does not supply a model name.
_llm_models = dict(
    deepseek=settings.DEEPSEEK_MODEL_NAME,
    gemini=settings.GEMINI_MODEL_NAME,
)
# Provider name -> class for the dedicated TTS implementations. Names not
# listed here fall back to the generic implementation in get_tts_provider.
_tts_registry = dict(
    google_gemini=GeminiTTSProvider,
    gcloud_tts=GCloudTTSProvider,
)

# Provider name -> class for the dedicated STT implementations. Names not
# listed here fall back to the generic implementation in get_stt_provider.
_stt_registry = dict(
    google_gemini=GoogleSTTProvider,
)
def get_registered_tts_providers():
    """Return a new list with the names of the dedicated TTS providers."""
    return [*_tts_registry]


def get_registered_stt_providers():
    """Return a new list with the names of the dedicated STT providers."""
    return [*_stt_registry]
# --- 3. The Factory Functions ---
def _is_masked(key) -> bool:
    """Return True when *key* cannot be used as a real API key.

    Treated as unusable: falsy values (``None``, ``""``), the literal strings
    ``"None"`` / ``"none"``, and anything whose ``str()`` contains ``*``
    (a redacted value such as ``"sk-****"``).
    """
    return not key or key in ("None", "none") or "*" in str(key)


def _resolve_api_key(api_key, provider_name: str, service_key_attr: str):
    """Return the API key a TTS/STT provider should be constructed with.

    *api_key* is returned as-is when it is usable. Otherwise the first usable
    key from this chain is returned:

    1. ``getattr(settings, service_key_attr)`` (``TTS_API_KEY`` / ``STT_API_KEY``);
    2. ``settings.GEMINI_API_KEY``, only when *provider_name* is ``"google_gemini"``;
    3. the ``"api_key"`` entry of ``settings.LLM_PROVIDERS[provider_name]``.

    Settings are only read when *api_key* is unusable. If nothing in the chain
    is usable, the original *api_key* is returned unchanged.
    """
    if not _is_masked(api_key):
        return api_key
    service_key = getattr(settings, service_key_attr)
    if not _is_masked(service_key):
        return service_key
    if provider_name == "google_gemini" and not _is_masked(settings.GEMINI_API_KEY):
        return settings.GEMINI_API_KEY
    if provider_name in settings.LLM_PROVIDERS:
        configured_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        if not _is_masked(configured_key):
            return configured_key
    # NOTE(review): a key that is still masked here is forwarded to the provider
    # unchanged (preserved behavior); confirm whether None would be preferable.
    return api_key


def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> BaseLM:
    """Build a pre-configured LLM provider, optionally with a system prompt.

    Args:
        provider_name: Provider name (e.g. ``"deepseek"``, ``"gemini"``); also
            used as the ``provider/`` prefix of the model identifier.
        model_name: Model to use. Empty means the provider's default from
            ``_llm_models``. A name that already contains ``/`` is used verbatim.
        system_prompt: Optional system prompt forwarded to ``GeneralProvider``.
        api_key_override: Key to use instead of the configured one. Masked or
            placeholder values (see ``_is_masked``) are ignored, the same way
            the TTS/STT factories treat caller-supplied keys.
        **kwargs: Forwarded to ``GeneralProvider``.

    Returns:
        A ``GeneralProvider`` instance.

    Raises:
        ValueError: If no model name is given and the provider has no default.
    """
    if _is_masked(api_key_override):
        api_key = _llm_providers.get(provider_name)
    else:
        api_key = api_key_override

    resolved_model = model_name or _llm_models.get(provider_name)
    if not resolved_model:
        raise ValueError(f"No model name provided for '{provider_name}'.")

    # Models already qualified as "provider/model" are passed through untouched.
    full_model = resolved_model if "/" in resolved_model else f"{provider_name}/{resolved_model}"
    return GeneralProvider(model_name=full_model, api_key=api_key, system_prompt=system_prompt, **kwargs)


def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """Build a text-to-speech provider.

    Names registered in ``_tts_registry`` get their dedicated class. Any other
    name falls back to ``GeneralTTSProvider``, with the model qualified as
    ``provider_name/model_name`` unless it already contains ``/``.

    Args:
        provider_name: Registry key, or provider prefix for the generic fallback.
        api_key: Caller-supplied key; if masked, resolved via ``_resolve_api_key``
            starting from ``settings.TTS_API_KEY``.
        model_name: Model identifier (not passed to the ``gcloud_tts`` provider).
        voice_name: Voice identifier forwarded to the provider.
        **kwargs: Forwarded to the provider constructor.
    """
    actual_key = _resolve_api_key(api_key, provider_name, "TTS_API_KEY")

    provider_cls = _tts_registry.get(provider_name)
    if provider_cls:
        if provider_name == "gcloud_tts":
            # GCloudTTSProvider is constructed without a model_name argument.
            return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs)
        return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs)

    # Fallback to General LiteLLM implementation.
    full_model = model_name
    # Defensive guard: registered names only reach this point if the registry
    # was altered at runtime.
    if "/" not in full_model and provider_name not in ("google_gemini", "gcloud_tts"):
        full_model = f"{provider_name}/{model_name}"
    return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs)


def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """Build a speech-to-text provider.

    Names registered in ``_stt_registry`` get their dedicated class. Any other
    name falls back to ``GeneralSTTProvider``, with the model qualified as
    ``provider_name/model_name`` unless it already contains ``/``.

    Args:
        provider_name: Registry key, or provider prefix for the generic fallback.
        api_key: Caller-supplied key; if masked, resolved via ``_resolve_api_key``
            starting from ``settings.STT_API_KEY``.
        model_name: Model identifier forwarded to the provider.
        **kwargs: Forwarded to the provider constructor.
    """
    actual_key = _resolve_api_key(api_key, provider_name, "STT_API_KEY")

    provider_cls = _stt_registry.get(provider_name)
    if provider_cls:
        return provider_cls(api_key=actual_key, model_name=model_name, **kwargs)

    # Fallback to General LiteLLM implementation.
    full_model = model_name
    # Defensive guard: registered names only reach this point if the registry
    # was altered at runtime.
    if "/" not in full_model and provider_name not in ("google_gemini",):
        full_model = f"{provider_name}/{model_name}"
    return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs)