import litellm
from dspy.clients.base_lm import BaseLM
class GeneralProvider(BaseLM):
    """A generic dspy language-model provider backed by LiteLLM.

    Routes both sync (`forward`) and async (`aforward`) completion calls
    through LiteLLM, optionally prepending a configured system prompt to
    the message list.
    """

    # Default completion budget; Gemini gets a much larger ceiling.
    # NOTE(review): 10_000_000 looks unusually large for max_tokens — confirm intent.
    DEFAULT_MAX_TOKENS = 8000
    GEMINI_MAX_TOKENS = 10_000_000

    def __init__(self, model_name: str, api_key: str, system_prompt: str = None, **kwargs):
        """Initialize the provider.

        Args:
            model_name: LiteLLM model identifier (e.g. "gemini", "gpt-4o").
            api_key: API key forwarded to LiteLLM on every request.
            system_prompt: Optional system message prepended to each request
                unless the caller already supplies one.
            **kwargs: Extra generation options passed through to BaseLM
                (and from there into every completion request).
        """
        self.model_name = model_name
        self.api_key = api_key
        self.system_prompt = system_prompt
        max_tokens = self.GEMINI_MAX_TOKENS if model_name == "gemini" else self.DEFAULT_MAX_TOKENS
        # BaseLM stores model/max_tokens/**kwargs into self.kwargs for later requests.
        super().__init__(model=model_name, max_tokens=max_tokens, **kwargs)

    def _prepare_messages(self, prompt=None, messages=None):
        """Build the message list for a request, injecting the system prompt.

        If `messages` is None, a single user message is built from `prompt`.
        The caller's list is never mutated: a shallow copy is made before
        any system message is inserted.
        """
        if messages is None:
            messages = [{"role": "user", "content": prompt}]
        else:
            # Copy so we never mutate the caller's list in place.
            messages = list(messages)
        if self.system_prompt:
            # Only inject if the caller hasn't already supplied a system message.
            if not messages or messages[0]["role"] != "system":
                messages.insert(0, {"role": "system", "content": self.system_prompt})
        return messages

    def _build_request(self, prompt=None, messages=None, **kwargs):
        """Assemble the LiteLLM request dict shared by forward/aforward.

        Per-call kwargs override the instance-level defaults in self.kwargs.
        """
        return {
            "model": self.model_name,
            "messages": self._prepare_messages(prompt=prompt, messages=messages),
            "api_key": self.api_key,
            **self.kwargs,
            **kwargs,
        }

    def forward(self, prompt=None, messages=None, **kwargs):
        """Synchronous forward pass using LiteLLM.

        Raises:
            RuntimeError: If the underlying LiteLLM call fails; the original
                exception is preserved as the cause.
        """
        request = self._build_request(prompt=prompt, messages=messages, **kwargs)
        try:
            return litellm.completion(**request)
        except Exception as e:
            # Chain the cause so the underlying provider error stays visible.
            raise RuntimeError(
                f"Failed to get response from LiteLLM for model '{self.model_name}': {e}"
            ) from e

    async def aforward(self, prompt=None, messages=None, **kwargs):
        """Asynchronous forward pass using LiteLLM.

        Raises:
            RuntimeError: If the underlying LiteLLM call fails; the original
                exception is preserved as the cause.
        """
        request = self._build_request(prompt=prompt, messages=messages, **kwargs)
        try:
            return await litellm.acompletion(**request)
        except Exception as e:
            # Chain the cause so the underlying provider error stays visible.
            raise RuntimeError(
                f"Failed to get response from LiteLLM for model '{self.model_name}': {e}"
            ) from e