import litellm
from dspy.clients.base_lm import BaseLM
class GeneralProvider(BaseLM):
    """DSPy language-model client that sends completion requests through LiteLLM.

    The LiteLLM model identifier and API key are fixed at construction time.
    Each call merges them with ``self.kwargs`` (read here but not assigned by
    this class -- presumably populated by ``BaseLM.__init__``; verify) and with
    any per-call keyword arguments.
    """

    def __init__(self, model_name: str, api_key: str, **kwargs) -> None:
        """Create a provider bound to one LiteLLM model.

        Args:
            model_name: Model identifier passed to LiteLLM as ``model`` and to
                ``BaseLM.__init__`` as ``model``.
            api_key: Credential forwarded to LiteLLM on every request.
            **kwargs: Extra options forwarded unchanged to ``BaseLM.__init__``
                (e.g. generation defaults). With none given, behavior matches
                the previous ``super().__init__(model=model_name)`` call.
        """
        self.model_name = model_name
        self.api_key = api_key
        super().__init__(model=model_name, **kwargs)

    def _make_litellm_request(self, prompt, messages, kwargs: dict) -> dict:
        """Build the keyword arguments for ``litellm.completion``/``acompletion``.

        ``messages`` is used when truthy; otherwise ``prompt`` is wrapped as a
        single user message. Later entries override earlier ones, so the
        precedence is: model/messages/api_key < ``self.kwargs`` < ``kwargs``.
        """
        return {
            "model": self.model_name,
            "messages": messages or [{"role": "user", "content": prompt}],
            "api_key": self.api_key,
            **self.kwargs,
            **kwargs,
        }

    def forward(self, prompt=None, messages=None, **kwargs):
        """Run a synchronous completion via ``litellm.completion``.

        Args:
            prompt: Text used as a single user message when ``messages`` is falsy.
            messages: Chat messages sent as-is when provided.
            **kwargs: Per-call options; they override instance defaults.

        Returns:
            Whatever ``litellm.completion`` returns.

        Raises:
            RuntimeError: If the LiteLLM call raises; the original exception is
                chained as ``__cause__``.
        """
        request = self._make_litellm_request(prompt, messages, kwargs)
        try:
            return litellm.completion(**request)
        except Exception as e:
            # Chain explicitly so the provider's original traceback is kept.
            raise RuntimeError(f"Failed to get response from LiteLLM for model '{self.model_name}': {e}") from e

    async def aforward(self, prompt=None, messages=None, **kwargs):
        """Run an asynchronous completion via ``litellm.acompletion``.

        Args:
            prompt: Text used as a single user message when ``messages`` is falsy.
            messages: Chat messages sent as-is when provided.
            **kwargs: Per-call options; they override instance defaults.

        Returns:
            The awaited result of ``litellm.acompletion``.

        Raises:
            RuntimeError: If the LiteLLM call raises; the original exception is
                chained as ``__cause__``.
        """
        request = self._make_litellm_request(prompt, messages, kwargs)
        try:
            return await litellm.acompletion(**request)
        except Exception as e:
            # Chain explicitly so the provider's original traceback is kept.
            raise RuntimeError(f"Failed to get response from LiteLLM for model '{self.model_name}': {e}") from e