# cortex-hub / ai-hub / app / core / providers / factory.py
from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from dspy.clients.base_lm import BaseLM
from openai import AsyncOpenAI

import litellm


# --- 1. Initialize API Clients from Central Config ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"

# --- 2. The Factory Dictionaries ---
# NOTE: these literals are evaluated once, when this module is imported, so the
# settings-derived values below are a snapshot of ``settings`` at import time.

# Default API key per built-in LLM provider name (read by get_llm_provider).
_llm_providers = {
    "deepseek": settings.DEEPSEEK_API_KEY,
    "gemini": settings.GEMINI_API_KEY,
}

# Default model name per built-in LLM provider name (used when the caller
# does not pass an explicit model).
_llm_models = {
    "deepseek": settings.DEEPSEEK_MODEL_NAME,
    "gemini": settings.GEMINI_MODEL_NAME,
}

# TTS provider names that have a dedicated implementation class; names not
# listed here go through the generic GeneralTTSProvider fallback.
_tts_registry = {
    "google_gemini": GeminiTTSProvider,
    "gcloud_tts": GCloudTTSProvider,
}

# STT provider names that have a dedicated implementation class; names not
# listed here go through the generic GeneralSTTProvider fallback.
_stt_registry = {
    "google_gemini": GoogleSTTProvider,
}

def get_registered_tts_providers():
    """Return the names of TTS providers that have a dedicated class.

    The result is a fresh list in registry insertion order, so callers may
    modify it without affecting the registry.
    """
    return list(_tts_registry)

def get_registered_stt_providers():
    """Return the names of STT providers that have a dedicated class.

    The result is a fresh list in registry insertion order, so callers may
    modify it without affecting the registry.
    """
    return list(_stt_registry)

# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> BaseLM:
    """Build a pre-configured ``GeneralProvider`` LLM for ``provider_name``.

    Args:
        provider_name: Provider identifier (e.g. ``"deepseek"``, ``"gemini"``).
            It is also used as the model prefix when ``model_name`` contains
            no ``/``.
        model_name: Model to use. When empty, the default from
            ``_llm_models`` for ``provider_name`` is used.
        system_prompt: Optional system prompt forwarded to ``GeneralProvider``.
        api_key_override: Explicit API key. It is ignored when empty, the
            literal string ``"None"``/``"none"``, or masked (contains ``*``).
            In those cases the configured key is used instead, matching how
            ``get_tts_provider``/``get_stt_provider`` treat their ``api_key``.
        **kwargs: Extra keyword arguments forwarded to ``GeneralProvider``.

    Returns:
        A ``GeneralProvider`` instance.

    Raises:
        ValueError: If no model name is given and none is configured for
            ``provider_name``.
    """
    def is_masked(k):
        # Blank values, "None"/"none" placeholders and masked keys are unusable.
        return not k or k in ("None", "none", "") or "*" in str(k)

    provider_key = api_key_override
    if is_masked(provider_key):
        # Fall back to the built-in key captured from settings at import time.
        provider_key = _llm_providers.get(provider_name)
    if is_masked(provider_key):
        # Then fall back to a per-provider entry in settings.LLM_PROVIDERS, if any.
        llm_configs = getattr(settings, "LLM_PROVIDERS", None) or {}
        if provider_name in llm_configs:
            provider_key = llm_configs[provider_name].get("api_key")
    if is_masked(provider_key):
        # Never forward a placeholder such as "None" or "sk-****" as a real key.
        provider_key = None

    model = model_name or _llm_models.get(provider_name)
    if not model:
        raise ValueError(f"No model name provided for '{provider_name}'.")

    # Qualify bare model names as "<provider>/<model>"; leave names that
    # already contain a "/" untouched.
    full_model = model if "/" in model else f"{provider_name}/{model}"

    # Pass the optional system_prompt and kwargs to the GeneralProvider constructor.
    return GeneralProvider(model_name=full_model, api_key=provider_key, system_prompt=system_prompt, **kwargs)

def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """Construct the text-to-speech provider for ``provider_name``.

    API key resolution:
        ``api_key`` is used as given unless it is empty, the literal string
        ``"None"``/``"none"``, or masked (contains ``*``). In that case the
        first usable value is taken from, in order:
        ``settings.TTS_API_KEY``, then ``settings.GEMINI_API_KEY`` (only for
        ``"google_gemini"``), then
        ``settings.LLM_PROVIDERS[provider_name]["api_key"]``. If none of them
        is usable, the original ``api_key`` is passed through unchanged.

    Provider selection:
        Names found in ``_tts_registry`` are instantiated directly, and
        ``"gcloud_tts"`` is built without ``model_name``. Any other name falls
        back to ``GeneralTTSProvider``. In that case ``model_name`` is
        prefixed as ``"<provider_name>/<model_name>"`` unless it already
        contains ``/``.
    """
    def unusable(value):
        # Treat blanks, "None"/"none" placeholders and masked keys as missing.
        return not value or value in ("None", "none", "") or "*" in str(value)

    def configured_key():
        # Walk the settings-based fallbacks in priority order.
        if not unusable(settings.TTS_API_KEY):
            return settings.TTS_API_KEY
        if provider_name == "google_gemini" and not unusable(settings.GEMINI_API_KEY):
            return settings.GEMINI_API_KEY
        if provider_name in settings.LLM_PROVIDERS:
            llm_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
            if not unusable(llm_key):
                return llm_key
        # Nothing better is configured: keep the caller's value as-is.
        return api_key

    resolved_key = configured_key() if unusable(api_key) else api_key

    registered_cls = _tts_registry.get(provider_name)
    if registered_cls:
        ctor_args = {"api_key": resolved_key}
        if provider_name != "gcloud_tts":
            # The gcloud_tts provider is constructed without a model name.
            ctor_args["model_name"] = model_name
        ctor_args["voice_name"] = voice_name
        return registered_cls(**ctor_args, **kwargs)

    # Fallback to General LiteLLM implementation.
    qualify = "/" not in model_name and provider_name not in ("google_gemini", "gcloud_tts")
    litellm_model = f"{provider_name}/{model_name}" if qualify else model_name
    return GeneralTTSProvider(model_name=litellm_model, api_key=resolved_key, voice_name=voice_name, **kwargs)

def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """Construct the speech-to-text provider for ``provider_name``.

    API key resolution:
        ``api_key`` is used as given unless it is empty, the literal string
        ``"None"``/``"none"``, or masked (contains ``*``). In that case the
        first usable value is taken from, in order:
        ``settings.STT_API_KEY``, then ``settings.GEMINI_API_KEY`` (only for
        ``"google_gemini"``), then
        ``settings.LLM_PROVIDERS[provider_name]["api_key"]``. If none of them
        is usable, the original ``api_key`` is passed through unchanged.

    Provider selection:
        Names found in ``_stt_registry`` are instantiated directly. Any other
        name falls back to ``GeneralSTTProvider``. In that case
        ``model_name`` is prefixed as ``"<provider_name>/<model_name>"``
        unless it already contains ``/``.
    """
    def unusable(value):
        # Treat blanks, "None"/"none" placeholders and masked keys as missing.
        return not value or value in ("None", "none", "") or "*" in str(value)

    def configured_key():
        # Walk the settings-based fallbacks in priority order.
        if not unusable(settings.STT_API_KEY):
            return settings.STT_API_KEY
        if provider_name == "google_gemini" and not unusable(settings.GEMINI_API_KEY):
            return settings.GEMINI_API_KEY
        if provider_name in settings.LLM_PROVIDERS:
            llm_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
            if not unusable(llm_key):
                return llm_key
        # Nothing better is configured: keep the caller's value as-is.
        return api_key

    resolved_key = configured_key() if unusable(api_key) else api_key

    registered_cls = _stt_registry.get(provider_name)
    if registered_cls:
        return registered_cls(api_key=resolved_key, model_name=model_name, **kwargs)

    # Fallback to General LiteLLM implementation.
    qualify = "/" not in model_name and provider_name != "google_gemini"
    litellm_model = f"{provider_name}/{model_name}" if qualify else model_name
    return GeneralSTTProvider(model_name=litellm_model, api_key=resolved_key, **kwargs)