# cortex-hub / ai-hub / app / core / providers / factory.py
from app.config import settings
from .base import TTSProvider, STTProvider
from .llm.general import GeneralProvider
from .tts.general import GeneralTTSProvider
from .stt.general import GeneralSTTProvider
from .tts.gemini import GeminiTTSProvider
from .tts.gcloud_tts import GCloudTTSProvider
from .stt.gemini import GoogleSTTProvider
from openai import AsyncOpenAI

import litellm

def resolve_provider_info(name, section, registry, litellm_list=None):
    """
    Map a (possibly suffixed) provider instance name to its base provider type.

    Example: 'gemini_2' -> 'gemini'. Checks the registry first, then the
    optional litellm provider list, then progressively shorter underscore
    prefixes (longest first, so 'google_gemini_x' matches 'google_gemini'
    before 'google'). The 'vertex_ai_beta' alias normalizes to 'vertex_ai'.
    Returns the name unchanged when nothing matches.
    """
    def normalize(candidate):
        # litellm's beta alias maps to the canonical vertex_ai type.
        return "vertex_ai" if candidate == "vertex_ai_beta" else candidate

    if name in registry:
        return name
    if litellm_list and name in litellm_list:
        return normalize(name)

    # Suffixed instance names: try every underscore prefix, longest first.
    pieces = name.split("_")
    for cut in range(len(pieces) - 1, 0, -1):
        candidate = "_".join(pieces[:cut])
        if candidate in registry or (litellm_list and candidate in litellm_list):
            return normalize(candidate)
    return name

# --- 1. Initialize API Clients from Central Config ---
# deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"

# --- 2. The Factory Dictionaries ---
# Base LLM provider types recognised before falling back to litellm's own list.
_llm_providers = ["gemini", "deepseek", "openai", "anthropic", "vertex_ai", "google_gemini"]

# TTS registry: base technology type -> concrete provider class.
_tts_registry = {
    "google_gemini": GeminiTTSProvider,
    "gcloud_tts": GCloudTTSProvider
}

# STT registry: base technology type -> concrete provider class.
_stt_registry = {
    "google_gemini": GoogleSTTProvider
}

def get_registered_tts_providers():
    """Return the base type names of all registered TTS providers."""
    return list(_tts_registry)

def get_registered_stt_providers():
    """Return the base type names of all registered STT providers."""
    return list(_stt_registry)

# --- 3. The Factory Functions ---
def get_llm_provider(provider_name: str, model_name: str = "", system_prompt: str = None, api_key_override: str = None, **kwargs) -> GeneralProvider:
    """
    Factory for a pre-configured LLM provider.

    Args:
        provider_name: Provider id, optionally carrying a model
            ("gemini", "gemini/gemini-2.5-flash", "gemini_2", ...).
        model_name: Explicit model override; resolved from provider_name /
            settings when empty.
        system_prompt: Optional system prompt forwarded to the provider.
        api_key_override: Explicit API key; settings are consulted when it
            is missing or masked.
        **kwargs: Extra options (including optional 'provider_type')
            forwarded to GeneralProvider.

    Returns:
        A GeneralProvider configured with a "provider/model" litellm id.

    Raises:
        ValueError: If no model name can be resolved from any source.
    """
    def is_empty(k):
        # Treat None, textual nulls, and masked keys ("sk-***") as absent.
        return not k or k in ("None", "none", "") or "*" in str(k)

    # "gemini/gemini-2.5-flash" -> "gemini" for key / settings lookups.
    base_provider_for_keys = provider_name.split("/")[0] if "/" in provider_name else provider_name

    # 1. Resolve API key: override -> admin-loaded settings -> env settings.
    provider_key = api_key_override
    if is_empty(provider_key):
        # LLM_PROVIDERS dict is hot-loaded via the admin UI.
        p_info = settings.LLM_PROVIDERS.get(base_provider_for_keys, {})
        provider_key = p_info.get("api_key")

        # Secondary fallback to hardcoded env settings.
        if is_empty(provider_key):
            if base_provider_for_keys == "gemini":
                provider_key = settings.GEMINI_API_KEY
            elif base_provider_for_keys == "deepseek":
                provider_key = settings.DEEPSEEK_API_KEY

    # 2. Resolve model name, in priority order.
    resolved_model = model_name
    if not resolved_model:
        # Priority 1: model embedded in provider_name ("gemini/<model>").
        if "/" in provider_name:
            resolved_model = provider_name.split("/", 1)[1]

        # Priority 2: suffixed instance name like "gemini_gemini-2.5-flash".
        if not resolved_model and "_" in provider_name:
            parts = provider_name.split("_")
            if len(parts) > 1 and parts[0] in ["gemini", "openai", "deepseek", "anthropic"]:
                potential_model = "_".join(parts[1:])
                # Heuristic: only accept suffixes that look like model ids.
                if "flash" in potential_model or "gpt" in potential_model or "chat" in potential_model:
                    resolved_model = potential_model

        # Priority 3: configured default for the base provider.
        if not resolved_model:
            resolved_model = settings.LLM_PROVIDERS.get(base_provider_for_keys, {}).get("model")

        # Priority 4: hard fallback for Gemini-family providers.
        if not resolved_model and "gemini" in base_provider_for_keys:
            resolved_model = "gemini-2.5-flash"

    # Fail with a clear error instead of crashing later on `'/' in None`
    # (TypeError) when every fallback above came up empty.
    if not resolved_model:
        raise ValueError(f"Could not resolve a model name for provider '{provider_name}'")

    # Extract base type (e.g. 'gemini_2' -> 'gemini').
    litellm_providers = [p.value for p in litellm.LlmProviders]
    base_type = kwargs.get("provider_type") or resolve_provider_info(base_provider_for_keys, "llm", _llm_providers, litellm_providers)

    # Prevent doubling like 'gemini/gemini/gemini-2.5-flash'.
    full_model = resolved_model if "/" in resolved_model else f"{base_type}/{resolved_model}"

    # Pass the optional system_prompt and kwargs to the GeneralProvider constructor.
    return GeneralProvider(model_name=full_model, api_key=provider_key, system_prompt=system_prompt, **kwargs)

def get_tts_provider(provider_name: str, api_key: str, model_name: str, voice_name: str, **kwargs) -> TTSProvider:
    """Build a TTS provider, resolving masked/missing API keys from settings."""

    def usable(key):
        # A key is usable when present, not a textual null, and not masked.
        return bool(key) and key not in ("None", "none", "") and "*" not in str(key)

    actual_key = api_key
    if not usable(actual_key):
        if usable(settings.TTS_API_KEY):
            actual_key = settings.TTS_API_KEY
        elif provider_name == "google_gemini" and usable(settings.GEMINI_API_KEY):
            actual_key = settings.GEMINI_API_KEY
        elif provider_name in settings.LLM_PROVIDERS and usable(settings.LLM_PROVIDERS[provider_name].get("api_key")):
            actual_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        elif "google" in provider_name or "gemini" in provider_name:
            # Last resort: borrow any Google/Gemini key found in LLM_PROVIDERS.
            for entry_name, entry in settings.LLM_PROVIDERS.items():
                lowered = entry_name.lower()
                if ("gemini" in lowered or "google" in lowered) and usable(entry.get("api_key")):
                    actual_key = entry.get("api_key")
                    break

    # Resolve base technology type (e.g. 'google_gemini_2' -> 'google_gemini').
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "tts", _tts_registry)

    provider_cls = _tts_registry.get(base_type)
    if provider_cls is not None:
        if base_type == "gcloud_tts":
            # Google Cloud TTS takes no model name, only a voice.
            return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs)
        return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs)

    # Fallback: generic LiteLLM-based TTS with a "provider/model" id.
    if "/" in model_name or base_type in ("google_gemini", "gcloud_tts"):
        full_model = model_name
    else:
        full_model = f"{base_type}/{model_name}"

    return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs)

def get_stt_provider(provider_name: str, api_key: str, model_name: str, **kwargs) -> STTProvider:
    """Build an STT provider, resolving masked/missing API keys from settings."""

    def usable(key):
        # A key is usable when present, not a textual null, and not masked.
        return bool(key) and key not in ("None", "none", "") and "*" not in str(key)

    actual_key = api_key
    if not usable(actual_key):
        if usable(settings.STT_API_KEY):
            actual_key = settings.STT_API_KEY
        elif provider_name == "google_gemini" and usable(settings.GEMINI_API_KEY):
            actual_key = settings.GEMINI_API_KEY
        elif provider_name in settings.LLM_PROVIDERS and usable(settings.LLM_PROVIDERS[provider_name].get("api_key")):
            actual_key = settings.LLM_PROVIDERS[provider_name].get("api_key")
        elif "google" in provider_name or "gemini" in provider_name:
            # Last resort: borrow any Google/Gemini key found in LLM_PROVIDERS.
            for entry_name, entry in settings.LLM_PROVIDERS.items():
                lowered = entry_name.lower()
                if ("gemini" in lowered or "google" in lowered) and usable(entry.get("api_key")):
                    actual_key = entry.get("api_key")
                    break

    # Resolve base technology type (e.g. 'google_gemini_2' -> 'google_gemini').
    base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "stt", _stt_registry)

    provider_cls = _stt_registry.get(base_type)
    if provider_cls is not None:
        return provider_cls(api_key=actual_key, model_name=model_name, **kwargs)

    # Fallback: generic LiteLLM-based STT with a "provider/model" id.
    if "/" in model_name or base_type in ("google_gemini",):
        full_model = model_name
    else:
        full_model = f"{base_type}/{model_name}"

    return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs)

def get_model_limit(provider_name: str, model_name: str = None) -> int:
    """
    Return the context window (max input tokens) for a provider/model.

    Used for UI progress bars and validation. Falls back to known
    per-family sizes when LiteLLM has no metadata, or reports a value
    that looks implausibly small for the model family.

    Args:
        provider_name: Provider id, optionally "provider/model".
        model_name: Explicit model; resolved from settings / provider_name
            when omitted.

    Returns:
        Context window size in tokens (100000 safety default when unknown).
    """
    base_provider_for_keys = provider_name.split("/")[0] if "/" in provider_name else provider_name

    # 1. Resolve model name: settings -> embedded in provider_name -> env defaults.
    resolved_model = model_name
    if not resolved_model:
        resolved_model = settings.LLM_PROVIDERS.get(base_provider_for_keys, {}).get("model")
        if not resolved_model:
            if "/" in provider_name:
                resolved_model = provider_name.split("/", 1)[1]
            elif base_provider_for_keys == "gemini":
                resolved_model = settings.GEMINI_MODEL_NAME
            elif base_provider_for_keys == "deepseek":
                resolved_model = settings.DEEPSEEK_MODEL_NAME
            elif "gemini" in base_provider_for_keys.lower():
                resolved_model = settings.GEMINI_MODEL_NAME
            elif "deepseek" in base_provider_for_keys.lower():
                resolved_model = settings.DEEPSEEK_MODEL_NAME
            else:
                return 100000  # Safety default for unknown providers.

    # 2. Resolve base type and build the litellm "provider/model" id.
    litellm_providers = [p.value for p in litellm.LlmProviders]
    base_type = resolve_provider_info(base_provider_for_keys, "llm", _llm_providers, litellm_providers)
    full_model = resolved_model if "/" in resolved_model else f"{base_type}/{resolved_model}"

    try:
        info = litellm.get_model_info(full_model)
    except Exception:
        # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt
        # are no longer swallowed. Unknown models fall through to defaults.
        info = None

    if info:
        # Prefer max_input_tokens as it represents the context window.
        input_tokens = info.get("max_input_tokens")

        # If litellm gave us an empty value or a suspiciously low value like
        # 8192 (often max_output_tokens, not the context window), override it.
        if not input_tokens or input_tokens <= 32000:
            lowered = full_model.lower()
            if "gemini" in lowered:
                input_tokens = 1048576  # Gemini 1.5 1M context
            elif "deepseek" in lowered:
                input_tokens = 128000
            elif "gpt-4o" in lowered:
                input_tokens = 128000
            elif "claude" in lowered:
                input_tokens = 200000
            else:
                input_tokens = info.get("max_tokens") or 100000

        return input_tokens

    # Final default behavior if completely unknown to litellm.
    lowered = full_model.lower()
    if "gemini" in lowered:
        return 1048576
    if "deepseek" in lowered:
        return 128000
    return 100000