from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider # Assuming GeneralProvider is now in this file or imported
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI
import litellm
# --- 1. API clients (disabled; kept for reference — API keys and model names are read from settings in section 2) ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"
# --- 2. The Factory Dictionaries ---
# Registry of supported LLM providers: provider name -> API key from settings.
# A provider is only usable by get_llm_provider if its key here is truthy.
_llm_providers = dict(
    deepseek=settings.DEEPSEEK_API_KEY,
    gemini=settings.GEMINI_API_KEY,
)
# Provider name -> default model name, used by get_llm_provider when the
# caller does not pass an explicit model_name.
_llm_models = dict(
    deepseek=settings.DEEPSEEK_MODEL_NAME,
    gemini=settings.GEMINI_MODEL_NAME,
)
# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None) -> BaseLM:
    """Return a pre-configured LLM provider, with an optional system prompt.

    Args:
        provider_name: Key into the provider registry (currently ``"deepseek"``
            or ``"gemini"``). It is also used as the prefix of the model name
            handed to ``GeneralProvider`` (``"<provider>/<model>"``).
        model_name: Model to use. When empty (or ``None``), the default model
            configured in ``settings`` for this provider is used.
        system_prompt: Optional system prompt forwarded unchanged to
            ``GeneralProvider``.

    Returns:
        A ``GeneralProvider`` built with model name ``"<provider>/<model>"``
        and the provider's configured API key.

    Raises:
        ValueError: If the provider is not supported, if its API key is not
            configured, or if no model name was given and none is configured.
    """
    if provider_name not in _llm_providers:
        raise ValueError(f"Unsupported model provider: '{provider_name}'. Supported providers are: {list(_llm_providers.keys())}")

    api_key = _llm_providers[provider_name]
    if not api_key:
        # The provider is known, but its key is missing/empty in settings.
        # Say so explicitly instead of claiming the provider is unsupported.
        raise ValueError(f"No API key configured for model provider: '{provider_name}'.")

    # Treat both "" (the default) and None as "use the configured default".
    resolved_model = model_name or _llm_models.get(provider_name)
    if not resolved_model:
        raise ValueError(f"No model name given or configured for model provider: '{provider_name}'.")

    return GeneralProvider(
        model_name=f"{provider_name}/{resolved_model}",
        api_key=api_key,
        system_prompt=system_prompt,
    )
def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str) -> TTSProvider:
    """Build the text-to-speech provider registered under ``provider_name``.

    ``"google_gemini"`` yields a ``GeminiTTSProvider`` built from the API key,
    model name and voice name. ``"gcloud_tts"`` yields a ``GCloudTTSProvider``
    built from the API key and voice name only; ``model_name`` is ignored.

    Raises:
        ValueError: If ``provider_name`` is neither of the names above.
    """
    # Builders are lambdas so that only the selected provider is constructed.
    builders = {
        "google_gemini": lambda: GeminiTTSProvider(api_key=api_key, model_name=model_name, voice_name=voice_name),
        "gcloud_tts": lambda: GCloudTTSProvider(api_key=api_key, voice_name=voice_name),
    }
    build = builders.get(provider_name)
    if build is None:
        raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_gemini', 'gcloud_tts']")
    return build()
def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    """Build the speech-to-text provider registered under ``provider_name``.

    Only ``"google_gemini"`` is supported; it maps to ``GoogleSTTProvider``
    constructed with the given API key and model name.

    Raises:
        ValueError: For any other ``provider_name``.
    """
    if provider_name != "google_gemini":
        raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
    return GoogleSTTProvider(api_key=api_key, model_name=model_name)