Newer
Older
cortex-hub / ai-hub / app / core / providers / factory.py
from typing import Optional

import litellm
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI

from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider # Assuming GeneralProvider is now in this file or imported
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider


# --- 1. Initialize API Clients from Central Config ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"

# --- 2. The Factory Dictionaries ---
# Supported LLM provider name -> API key read from the central settings.
# Membership in this table is what get_llm_provider treats as "supported".
_llm_providers = dict(
    deepseek=settings.DEEPSEEK_API_KEY,
    gemini=settings.GEMINI_API_KEY,
)

# Supported LLM provider name -> default model name from the central settings,
# used when the caller does not request a specific model.
_llm_models = dict(
    deepseek=settings.DEEPSEEK_MODEL_NAME,
    gemini=settings.GEMINI_MODEL_NAME,
)

# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: Optional[str] = None) -> BaseLM:
    """Build a pre-configured LLM provider, with an optional system prompt.

    Args:
        provider_name: Key into ``_llm_providers`` (e.g. ``"deepseek"`` or
            ``"gemini"``).
        model_name: Model to use. When empty (or ``None``), the provider's
            default model from ``_llm_models`` is used.
        system_prompt: Optional system prompt forwarded to ``GeneralProvider``.

    Returns:
        A ``GeneralProvider`` built with the model string
        ``"<provider_name>/<model>"`` and the provider's API key.

    Raises:
        ValueError: If the provider is not supported, if it has no API key
            configured, or if no model was given and no default is configured.
    """
    if provider_name not in _llm_providers:
        supported = list(_llm_providers.keys())
        raise ValueError(f"Unsupported model provider: '{provider_name}'. Supported providers are: {supported}")

    api_key = _llm_providers[provider_name]
    if not api_key:
        # The provider is known but its key is empty/unset; report that rather
        # than claiming the provider is unsupported.
        raise ValueError(f"No API key configured for model provider: '{provider_name}'.")

    # Any falsy model_name ("" or None) falls back to the configured default.
    resolved_model = model_name or _llm_models.get(provider_name)
    if not resolved_model:
        raise ValueError(f"No model name given and no default model configured for provider: '{provider_name}'.")

    # Pass the optional system_prompt to the GeneralProvider constructor
    return GeneralProvider(model_name=f"{provider_name}/{resolved_model}", api_key=api_key, system_prompt=system_prompt)

def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str) -> TTSProvider:
    """Build the text-to-speech provider selected by ``provider_name``.

    Args:
        provider_name: ``"google_gemini"`` or ``"gcloud_tts"``.
        api_key: Credential forwarded to the provider constructor.
        model_name: Model identifier; only forwarded to the Gemini provider.
        voice_name: Voice forwarded to either provider.

    Raises:
        ValueError: If ``provider_name`` is not one of the supported names.
    """
    if provider_name == "gcloud_tts":
        # The Cloud TTS provider is constructed without a model name.
        return GCloudTTSProvider(api_key=api_key, voice_name=voice_name)
    if provider_name == "google_gemini":
        return GeminiTTSProvider(api_key=api_key, model_name=model_name, voice_name=voice_name)
    supported = ["google_gemini", "gcloud_tts"]
    raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: {supported}")

def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    """Build the speech-to-text provider selected by ``provider_name``.

    Only ``"google_gemini"`` is supported; it maps to ``GoogleSTTProvider``.

    Raises:
        ValueError: If ``provider_name`` is anything other than ``"google_gemini"``.
    """
    if provider_name != "google_gemini":
        raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
    return GoogleSTTProvider(api_key=api_key, model_name=model_name)