# cortex-hub / ai-hub / app / core / providers / factory.py
from app.config import settings
from .base import LLMProvider, TTSProvider, STTProvider
from .llm.deepseek import DeepSeekProvider
from .llm.gemini import GeminiProvider
from .tts.gemini import GeminiTTSProvider
from .stt.gemini import GoogleSTTProvider
from openai import AsyncOpenAI

# --- 1. Initialize API Clients from Central Config ---
# DeepSeek exposes an OpenAI-compatible HTTP API, so the async OpenAI client
# is reused and simply pointed at DeepSeek's base URL.
deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
# Gemini is called via its raw REST endpoint; model name and API key both come
# from central settings. NOTE(review): the API key is embedded in the URL query
# string — take care never to log this constant.
GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.GEMINI_MODEL_NAME}:generateContent?key={settings.GEMINI_API_KEY}"

# --- 2. The Factory Dictionaries ---
# Provider singletons, built once at import time and shared by every caller of
# get_llm_provider(). Keys are the provider names that function accepts.
_llm_providers = {
    "deepseek": DeepSeekProvider(model_name=settings.DEEPSEEK_MODEL_NAME, client=deepseek_client),
    "gemini": GeminiProvider(api_url=GEMINI_URL)
}

# --- 3. The Factory Functions ---
def get_llm_provider(model_name: str) -> LLMProvider:
    """Return the shared, pre-configured LLM provider registered under *model_name*.

    Raises:
        ValueError: if *model_name* is not a key in the provider registry.
    """
    try:
        # EAFP: the registry is the single source of truth for supported names.
        return _llm_providers[model_name]
    except KeyError:
        raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_llm_providers.keys())}") from None

def get_tts_provider(provider_name: str, api_key: str) -> TTSProvider:
    """Build and return a TTS provider configured with *api_key*.

    Only the "google_genai" provider is currently supported.

    Raises:
        ValueError: if *provider_name* is not a supported TTS provider.
    """
    # Guard clause: reject unknown providers up front.
    if provider_name != "google_genai":
        raise ValueError(f"Unsupported TTS provider: '{provider_name}'. Supported providers are: ['google_genai']")
    return GeminiTTSProvider(api_key=api_key)

def get_stt_provider(provider_name: str, api_key: str, model_name: str) -> STTProvider:
    """Build and return an STT provider configured with *api_key* and *model_name*.

    Only the "google_gemini" provider is currently supported.

    Raises:
        ValueError: if *provider_name* is not a supported STT provider.
    """
    # Guard clause: reject unknown providers up front.
    if provider_name != "google_gemini":
        raise ValueError(f"Unsupported STT provider: '{provider_name}'. Supported providers are: ['google_gemini']")
    return GoogleSTTProvider(api_key=api_key, model_name=model_name)