"""Factory functions that build pre-configured LLM, TTS and STT providers."""

from typing import Optional

from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider  # Assuming GeneralProvider is now in this file or imported
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI
import litellm

# --- 1. Initialize API Clients from Central Config ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"

# --- 2. The Factory Dictionaries ---
# Provider name -> API key read from the central settings object.
_llm_providers = {
    "deepseek": settings.DEEPSEEK_API_KEY,
    "gemini": settings.GEMINI_API_KEY,
}

# Provider name -> default model name used when the caller does not pass one.
_llm_models = {
    "deepseek": settings.DEEPSEEK_MODEL_NAME,
    "gemini": settings.GEMINI_MODEL_NAME,
}


# --- 3. The Factory Functions ---
def get_llm_provider(
    provider_name: str,
    model_name: str = "",
    system_prompt: Optional[str] = None,
) -> BaseLM:
    """Return a ``GeneralProvider`` configured for the requested LLM provider.

    Args:
        provider_name: Key into the provider table (e.g. ``"deepseek"``,
            ``"gemini"``).
        model_name: Model to use. When empty (or otherwise falsy) the
            provider's default model from settings is used instead.
        system_prompt: Optional system prompt forwarded unchanged to the
            ``GeneralProvider`` constructor.

    Returns:
        A ``GeneralProvider`` built with ``model_name`` set to
        ``"<provider_name>/<model>"`` and the provider's API key.

    Raises:
        ValueError: If the provider is unknown, if its API key is not
            configured, or if no model name can be resolved.
    """
    if provider_name not in _llm_providers:
        raise ValueError(
            f"Unsupported model provider: '{provider_name}'. "
            f"Supported providers are: {list(_llm_providers)}"
        )

    # A known provider with an empty/unset key is a configuration problem,
    # not an unsupported provider, so report it separately.
    api_key = _llm_providers[provider_name]
    if not api_key:
        raise ValueError(
            f"No API key configured for model provider: '{provider_name}'."
        )

    # Fall back to the provider's default model when none was requested.
    resolved_model = model_name or _llm_models.get(provider_name)
    if not resolved_model:
        raise ValueError(
            f"No model name given and no default model configured for "
            f"provider: '{provider_name}'."
        )

    # Pass the optional system_prompt to the GeneralProvider constructor.
    return GeneralProvider(
        model_name=f"{provider_name}/{resolved_model}",
        api_key=api_key,
        system_prompt=system_prompt,
    )


def get_tts_provider(
    provider_name: str,
    api_key: str,
    model_name: str,
    voice_name: str,
) -> TTSProvider:
    """Return the text-to-speech provider registered under ``provider_name``.

    Args:
        provider_name: ``"google_gemini"`` or ``"gcloud_tts"``.
        api_key: API key handed to the provider constructor.
        model_name: Model identifier; only passed to the ``"google_gemini"``
            provider and ignored for ``"gcloud_tts"``.
        voice_name: Voice identifier handed to the provider constructor.

    Raises:
        ValueError: If ``provider_name`` is not one of the supported names.
    """
    if provider_name == "google_gemini":
        return GeminiTTSProvider(
            api_key=api_key, model_name=model_name, voice_name=voice_name
        )
    if provider_name == "gcloud_tts":
        # GCloudTTSProvider is constructed without a model name.
        return GCloudTTSProvider(api_key=api_key, voice_name=voice_name)
    raise ValueError(
        f"Unsupported TTS provider: '{provider_name}'. "
        "Supported providers are: ['google_gemini', 'gcloud_tts']"
    )


def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    """Return the speech-to-text provider registered under ``provider_name``.

    Args:
        provider_name: Currently only ``"google_gemini"`` is supported.
        api_key: API key handed to the provider constructor.
        model_name: Model identifier handed to the provider constructor.

    Raises:
        ValueError: If ``provider_name`` is not a supported name.
    """
    if provider_name == "google_gemini":
        return GoogleSTTProvider(api_key=api_key, model_name=model_name)
    raise ValueError(
        f"Unsupported STT provider: '{provider_name}'. "
        "Supported providers are: ['google_gemini']"
    )