# cortex-hub/ai-hub/app/core/llm_providers.py
import os
import httpx
import dspy
from abc import ABC, abstractmethod
from openai import AsyncOpenAI
from typing import final

# --- 1. Load Configuration from Environment ---
# Best practice is to centralize configuration loading at the top.

# API Keys (required)
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
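
# Warn early when a required key is missing (an illustrative guard, not part
# of the original design; adapt the handling to your deployment):
for _key_name, _key_value in (
    ("DEEPSEEK_API_KEY", DEEPSEEK_API_KEY),
    ("GEMINI_API_KEY", GEMINI_API_KEY),
    ("OPENAI_API_KEY", OPENAI_API_KEY),
):
    if not _key_value:
        print(f"Warning: {_key_name} is not set; calls through that provider will fail.")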

# Model Names (optional, with defaults)
# Allows changing the model version without code changes.
DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL_NAME", "deepseek-chat")
GEMINI_MODEL = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
OPENAI_MODEL = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

# --- 2. Initialize API Clients and URLs ---
# Initialize any clients or constants that will be used by the providers.
# Use the async client so provider calls don't block the event loop.
deepseek_client = AsyncOpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")
GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent?key={GEMINI_API_KEY}"

# --- 3. Provider Interface and Implementations ---

class LLMProvider(ABC):
    """Abstract base class ('Interface') for all LLM providers."""
    @abstractmethod
    async def generate_response(self, prompt: str) -> str:
        """Generates a response from the LLM."""
        pass

@final
class DeepSeekProvider(LLMProvider):
    """Provider for the DeepSeek API."""
    def __init__(self, model_name: str):
        self.model = model_name
        print(f"DeepSeekProvider initialized with model: {self.model}")

    async def generate_response(self, prompt: str) -> str:
        try:
            chat_completion = await deepseek_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                stream=False
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            print(f"DeepSeek Error: {e}")
            raise  # Re-raise the exception to be handled by the main app

@final
class GeminiProvider(LLMProvider):
    """Provider for the Google Gemini API."""
    def __init__(self, api_url: str):
        self.url = api_url
        print(f"GeminiProvider initialized for URL: {self.url.split('?')[0]}")

    async def generate_response(self, prompt: str) -> str:
        payload = {"contents": [{"parts": [{"text": prompt}]}]}
        headers = {"Content-Type": "application/json"}
        
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(self.url, json=payload, headers=headers)
                response.raise_for_status()
                data = response.json()
                return data['candidates'][0]['content']['parts'][0]['text']
        except (httpx.HTTPStatusError, KeyError, IndexError) as e:
            print(f"Gemini Error: {e}")
            raise # Re-raise for the main app to handle
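
# Illustrative extension (a sketch, not wired into the registry below): the
# same interface accommodates OpenAI itself, whose OPENAI_API_KEY and
# OPENAI_MODEL are already loaded above. Register an instance in `_providers`
# if you want to enable it.
@final
class OpenAIProvider(LLMProvider):
    """Provider for the OpenAI API (example of extending the interface)."""
    def __init__(self, model_name: str):
        self.model = model_name
        self._client = AsyncOpenAI(api_key=OPENAI_API_KEY)
        print(f"OpenAIProvider initialized with model: {self.model}")

    async def generate_response(self, prompt: str) -> str:
        try:
            chat_completion = await self._client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            print(f"OpenAI Error: {e}")
            raise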

# --- 4. The Factory Function ---
# This is where we instantiate our concrete providers with their configuration.

_providers = {
    "deepseek": DeepSeekProvider(model_name=DEEPSEEK_MODEL),
    "gemini": GeminiProvider(api_url=GEMINI_URL)
}

def get_llm_provider(model_name: str) -> LLMProvider:
    """Factory function to get the appropriate, pre-configured LLM provider."""
    provider = _providers.get(model_name)
    if not provider:
        raise ValueError(f"Unsupported model provider: '{model_name}'. Supported providers are: {list(_providers.keys())}")
    return provider
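
# Example usage (illustrative): from an async caller, e.g. a FastAPI route,
# resolve a provider and await its response:
#
#     provider = get_llm_provider("gemini")
#     answer = await provider.generate_response("Summarize httpx in one line.")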


# --- 5. DSPy-specific Bridge ---
# This factory bridges our provider configuration with DSPy's LM wrapper classes.

def get_dspy_lm(model_name: str, api_key: str) -> dspy.LM:
    """
    Factory function to get a DSPy-compatible language model instance.
    
    Args:
        model_name (str): The name of the model to use.
        api_key (str): The API key for the model.
        
    Returns:
        dspy.LM: An instantiated DSPy language model object.
    
    Raises:
        ValueError: If the provided model name is not supported.
    """
    if model_name == DEEPSEEK_MODEL:
        # Use DSPy's OpenAI wrapper for DeepSeek, with a custom base_url
        return dspy.OpenAI(
            model=DEEPSEEK_MODEL,
            api_key=api_key,
            api_base="https://api.deepseek.com/v1"
        )
    elif model_name == OPENAI_MODEL:
        # Use DSPy's OpenAI wrapper for standard OpenAI models
        return dspy.OpenAI(model=OPENAI_MODEL, api_key=api_key)
    elif model_name == GEMINI_MODEL:
        # Use DSPy's Google wrapper for Gemini
        return dspy.Google(model=GEMINI_MODEL, api_key=api_key)
    else:
        raise ValueError(f"Unsupported DSPy model: '{model_name}'. Supported models are: {DEEPSEEK_MODEL, OPENAI_MODEL, GEMINI_MODEL}")