# Newer
# Older
# cortex-hub / ai-hub / app / core / providers / factory.py
from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI

import litellm

def resolve_provider_info(name, section, registry, litellm_list=None):
    """
    Map a (possibly suffixed) provider instance name to its base provider type.

    A name that is a key of ``registry``, or a member of a non-empty
    ``litellm_list``, is returned unchanged. Otherwise trailing
    underscore-separated segments are stripped one at a time and the first
    known prefix (i.e. the longest one) is returned. When nothing matches,
    ``name`` itself is returned.

    Examples: 'gemini_2' -> 'gemini', 'google_gemini_2' -> 'google_gemini'

    ``section`` is accepted for the callers' benefit but is not used here.
    """
    def is_known(candidate):
        if candidate in registry:
            return True
        return bool(litellm_list and candidate in litellm_list)

    if is_known(name):
        return name

    # Peel one "_segment" off the right-hand side per iteration, so longer
    # prefixes (e.g. 'google_gemini') are tried before shorter ones ('google').
    candidate = name
    while "_" in candidate:
        candidate = candidate.rpartition("_")[0]
        if is_known(candidate):
            return candidate

    return name

# --- 1. Initialize API Clients from Central Config ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"

# --- 2. The Factory Dictionaries ---
# Default API key per LLM provider name, read from settings when this module
# is imported. get_llm_provider() uses it when no api_key_override is given.
_llm_providers = {
    "deepseek": settings.DEEPSEEK_API_KEY,
    "gemini": settings.GEMINI_API_KEY
}

# Default model name per LLM provider name. get_llm_provider() uses it when
# the caller passes an empty model_name.
_llm_models = {
    "deepseek": settings.DEEPSEEK_MODEL_NAME,
    "gemini": settings.GEMINI_MODEL_NAME
}

# Dedicated TTS provider classes keyed by base provider type. Types missing
# from this table are served by GeneralTTSProvider in get_tts_provider().
_tts_registry = {
    "google_gemini": GeminiTTSProvider,
    "gcloud_tts": GCloudTTSProvider
}

# Dedicated STT provider classes keyed by base provider type. Types missing
# from this table are served by GeneralSTTProvider in get_stt_provider().
_stt_registry = {
    "google_gemini": GoogleSTTProvider
}

def get_registered_tts_providers():
    """Return the base provider types that have a dedicated TTS class, as a new list."""
    return list(_tts_registry)

def get_registered_stt_providers():
    """Return the base provider types that have a dedicated STT class, as a new list."""
    return list(_stt_registry)

# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> BaseLM:
    """
    Build a GeneralProvider for the requested LLM provider instance.

    Args:
        provider_name: Provider instance name; may carry a suffix such as
            'gemini_2', in which case the base type is resolved from it.
        model_name: Model to use. When empty, the default registered for
            ``provider_name`` in ``_llm_models`` is used.
        system_prompt: Optional system prompt handed to the provider.
        api_key_override: Key to use instead of the one registered for
            ``provider_name`` in ``_llm_providers``.
        **kwargs: Forwarded unchanged to GeneralProvider. A truthy
            ``provider_type`` entry is used as the base type directly.

    Raises:
        ValueError: If no model name is given and none is registered.
    """
    api_key = api_key_override or _llm_providers.get(provider_name)

    # Base type, e.g. 'gemini_2' -> 'gemini'; an explicit provider_type wins.
    known_litellm_types = [member.value for member in litellm.LlmProviders]
    base_type = kwargs.get("provider_type")
    if not base_type:
        base_type = resolve_provider_info(provider_name, "llm", _llm_providers, known_litellm_types)

    resolved_model = model_name or _llm_models.get(provider_name)
    if not resolved_model:
        raise ValueError(f"No model name provided for '{provider_name}'.")

    # A model that already contains '/' is taken as fully qualified.
    if '/' in resolved_model:
        full_model = resolved_model
    else:
        full_model = f'{base_type}/{resolved_model}'

    return GeneralProvider(
        model_name=full_model,
        api_key=api_key,
        system_prompt=system_prompt,
        **kwargs,
    )

def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """
    Build the TTS provider for ``provider_name``.

    Args:
        provider_name: Provider instance name, possibly suffixed.
        api_key: Key to use. If it is empty, the text 'None'/'none', or
            contains '*', a replacement is looked up in settings (see below).
        model_name: Model identifier; not passed to the 'gcloud_tts' class.
        voice_name: Voice identifier passed to the provider.
        **kwargs: Forwarded unchanged to the provider constructor. A truthy
            ``provider_type`` entry is used as the base type directly.

    Key fallback order: settings.TTS_API_KEY, then settings.GEMINI_API_KEY
    (only when ``provider_name`` is 'google_gemini'), then the 'api_key' entry
    of settings.LLM_PROVIDERS[provider_name]. If none is usable, ``api_key``
    is passed through as given.
    """
    def is_masked(value):
        if not value:
            return True
        if value in ("None", "none", ""):
            return True
        return "*" in str(value)

    def fallback_key():
        if not is_masked(settings.TTS_API_KEY):
            return settings.TTS_API_KEY
        if provider_name == "google_gemini" and not is_masked(settings.GEMINI_API_KEY):
            return settings.GEMINI_API_KEY
        if provider_name in settings.LLM_PROVIDERS:
            llm_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
            if not is_masked(llm_key):
                return llm_key
        return api_key

    actual_key = fallback_key() if is_masked(api_key) else api_key

    # Resolve base technology type; an explicit provider_type wins.
    base_type = kwargs.get("provider_type")
    if not base_type:
        base_type = resolve_provider_info(provider_name, "tts", _tts_registry)

    provider_cls = _tts_registry.get(base_type)
    if not provider_cls:
        # No dedicated class: fall back to the general LiteLLM implementation,
        # prefixing the model with the base type unless already qualified.
        needs_prefix = not ("/" in model_name or base_type in ["google_gemini", "gcloud_tts"])
        full_model = f"{base_type}/{model_name}" if needs_prefix else model_name
        return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs)

    if base_type == "gcloud_tts":
        return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs)
    return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs)

def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """
    Build the STT provider for ``provider_name``.

    Args:
        provider_name: Provider instance name, possibly suffixed.
        api_key: Key to use. If it is empty, the text 'None'/'none', or
            contains '*', a replacement is looked up in settings (see below).
        model_name: Model identifier passed to the provider.
        **kwargs: Forwarded unchanged to the provider constructor. A truthy
            ``provider_type`` entry is used as the base type directly.

    Key fallback order: settings.STT_API_KEY, then settings.GEMINI_API_KEY
    (only when ``provider_name`` is 'google_gemini'), then the 'api_key' entry
    of settings.LLM_PROVIDERS[provider_name]. If none is usable, ``api_key``
    is passed through as given.
    """
    def is_masked(value):
        if not value:
            return True
        if value in ("None", "none", ""):
            return True
        return "*" in str(value)

    def fallback_key():
        if not is_masked(settings.STT_API_KEY):
            return settings.STT_API_KEY
        if provider_name == "google_gemini" and not is_masked(settings.GEMINI_API_KEY):
            return settings.GEMINI_API_KEY
        if provider_name in settings.LLM_PROVIDERS:
            llm_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
            if not is_masked(llm_key):
                return llm_key
        return api_key

    actual_key = fallback_key() if is_masked(api_key) else api_key

    # Resolve base technology type; an explicit provider_type wins.
    base_type = kwargs.get("provider_type")
    if not base_type:
        base_type = resolve_provider_info(provider_name, "stt", _stt_registry)

    provider_cls = _stt_registry.get(base_type)
    if not provider_cls:
        # No dedicated class: fall back to the general LiteLLM implementation,
        # prefixing the model with the base type unless already qualified.
        needs_prefix = not ("/" in model_name or base_type in ["google_gemini"])
        full_model = f"{base_type}/{model_name}" if needs_prefix else model_name
        return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs)

    return provider_cls(api_key=actual_key, model_name=model_name, **kwargs)