from typing import Optional

import litellm

from dspy.clients.base_lm import BaseLM

# Default completion token budget; Gemini models get a much larger one.
_DEFAULT_MAX_TOKENS = 8000
_GEMINI_MAX_TOKENS = 10_000_000


class GeneralProvider(BaseLM):
    """DSPy language-model provider backed by LiteLLM.

    Routes both synchronous (`forward`) and asynchronous (`aforward`)
    completion calls through LiteLLM, optionally prepending a configured
    system prompt to every request.
    """

    def __init__(self, model_name: str, api_key: str, system_prompt: Optional[str] = None, **kwargs):
        """
        Args:
            model_name: LiteLLM model identifier (e.g. "gpt-4o", "gemini/...").
            api_key: API key forwarded to LiteLLM on every request.
            system_prompt: Optional system message prepended to each request
                unless the caller already supplies a system message.
            **kwargs: Extra default request parameters passed to BaseLM
                (later merged into each request via ``self.kwargs``).
        """
        self.model_name = model_name
        self.api_key = api_key
        self.system_prompt = system_prompt

        # NOTE(review): 10_000_000 for Gemini looks like an effectively
        # unbounded sentinel rather than a documented model limit -- confirm
        # against the target models' actual max output tokens.
        max_tokens = _GEMINI_MAX_TOKENS if model_name.startswith("gemini") else _DEFAULT_MAX_TOKENS
        super().__init__(model=model_name, max_tokens=max_tokens, **kwargs)

    def _prepare_messages(self, prompt=None, messages=None):
        """Build the messages list, injecting the system prompt if needed.

        Returns a new list; the caller's ``messages`` argument is never
        mutated (the previous implementation inserted into it in place).
        """
        if messages is None:
            messages = [{"role": "user", "content": prompt}]
        else:
            # Copy so we never mutate a list owned by the caller.
            messages = list(messages)
        # Only inject the system prompt when the caller hasn't supplied one.
        if self.system_prompt and (not messages or messages[0]["role"] != "system"):
            messages.insert(0, {"role": "system", "content": self.system_prompt})
        return messages

    def _build_request(self, prompt, messages, extra_kwargs):
        """Assemble the keyword arguments shared by forward/aforward."""
        return {
            "model": self.model_name,
            "messages": self._prepare_messages(prompt=prompt, messages=messages),
            "api_key": self.api_key,
            **self.kwargs,    # constructor-time defaults stored by BaseLM
            **extra_kwargs,   # per-call overrides take precedence
        }

    def forward(self, prompt=None, messages=None, **kwargs):
        """Synchronous completion via ``litellm.completion``.

        Raises:
            RuntimeError: wrapping any exception raised by LiteLLM.
        """
        request = self._build_request(prompt, messages, kwargs)
        try:
            return litellm.completion(**request)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(
                f"Failed to get response from LiteLLM for model '{self.model_name}': {e}"
            ) from e

    async def aforward(self, prompt=None, messages=None, **kwargs):
        """Asynchronous completion via ``litellm.acompletion``.

        Raises:
            RuntimeError: wrapping any exception raised by LiteLLM.
        """
        request = self._build_request(prompt, messages, kwargs)
        try:
            return await litellm.acompletion(**request)
        except Exception as e:
            raise RuntimeError(
                f"Failed to get response from LiteLLM for model '{self.model_name}': {e}"
            ) from e