cortex-hub/ai-hub/app/core/providers/llm/general.py
import litellm
from dspy.clients.base_lm import BaseLM

class GeneralProvider(BaseLM):
    def __init__(self, model_name: str, api_key: str, system_prompt: str | None = None, **kwargs):
        self.model_name = model_name
        self.api_key = api_key
        self.system_prompt = system_prompt
        # Give Gemini models a much larger token budget (they expose very
        # large context windows); every other model defaults to 8,000.
        max_tokens = 8000
        if model_name.startswith("gemini"):
            max_tokens = 10000000
        # Call the parent constructor, which stores max_tokens and any extra
        # kwargs in self.kwargs for the completion calls below.
        super().__init__(model=model_name, max_tokens=max_tokens, **kwargs)

    def _prepare_messages(self, prompt=None, messages=None):
        """Helper to build the messages list, prepending the system prompt if one is set."""
        if messages is None:
            if prompt is None:
                raise ValueError("Either 'prompt' or 'messages' must be provided.")
            messages = [{"role": "user", "content": prompt}]
        else:
            # Copy so the insert below never mutates the caller's list.
            messages = list(messages)

        # Prepend the system prompt unless a system message is already present.
        if self.system_prompt and (not messages or messages[0]["role"] != "system"):
            messages.insert(0, {"role": "system", "content": self.system_prompt})

        return messages
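
    # Illustrative example (comment only, not executed): with
    # system_prompt="Be brief", _prepare_messages(prompt="Hi") returns
    #   [{"role": "system", "content": "Be brief"},
    #    {"role": "user", "content": "Hi"}]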

    def forward(self, prompt=None, messages=None, **kwargs):
        """
        Synchronous forward pass using LiteLLM.
        """
        prepared_messages = self._prepare_messages(prompt=prompt, messages=messages)

        # Per-call kwargs take precedence over the defaults stored by BaseLM.
        request = {
            "model": self.model_name,
            "messages": prepared_messages,
            "api_key": self.api_key,
            **self.kwargs,
            **kwargs,
        }
        try:
            return litellm.completion(**request)
        except Exception as e:
            raise RuntimeError(
                f"Failed to get response from LiteLLM for model '{self.model_name}': {e}"
            ) from e

    async def aforward(self, prompt=None, messages=None, **kwargs):
        """
        Asynchronous forward pass using LiteLLM.
        """
        prepared_messages = self._prepare_messages(prompt=prompt, messages=messages)

        # Mirror forward(): per-call kwargs override the stored defaults.
        request = {
            "model": self.model_name,
            "messages": prepared_messages,
            "api_key": self.api_key,
            **self.kwargs,
            **kwargs,
        }
        try:
            return await litellm.acompletion(**request)
        except Exception as e:
            raise RuntimeError(
                f"Failed to get response from LiteLLM for model '{self.model_name}': {e}"
            ) from e