import litellm
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
class GeneralProvider:
    """Chat-completion provider that forwards requests to ``litellm.acompletion``.

    Holds the model name, API key, an optional system prompt and default
    request keyword arguments, and retries calls that fail with a
    service-unavailable or internal-server error.
    """

    def __init__(self, model_name: str, api_key: str, system_prompt: str = None, **kwargs):
        """Store the provider configuration and validate the API key.

        Args:
            model_name: Model identifier passed to LiteLLM as ``model``.
            api_key: API key passed to LiteLLM as ``api_key``.
            system_prompt: Optional system message prepended to conversations
                that do not already start with one.
            **kwargs: Default keyword arguments merged into every request.

        Raises:
            ValueError: If ``api_key`` is falsy or contains a ``*`` character.
        """
        self.model_name = model_name
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.kwargs = kwargs

        # Validate API key early so misconfiguration fails at construction
        # time rather than on the first request.
        if not api_key or "*" in str(api_key):
            raise ValueError(f"Invalid or missing API key for LLM provider '{model_name}'. Please configure it in Settings.")

    def _prepare_messages(self, prompt=None, messages=None):
        """Build the message list for a request, including the system prompt.

        Args:
            prompt: Text used as a single user message when ``messages`` is
                ``None``.
            messages: Existing list of message mappings. When given, ``prompt``
                is ignored.

        Returns:
            A new list of messages. The caller's ``messages`` list is never
            modified.
        """
        if messages is None:
            prepared = [{"role": "user", "content": prompt}]
        else:
            # Shallow copy: inserting the system prompt must not leak into the
            # caller's own conversation history.
            prepared = list(messages)

        # Only prepend the configured system prompt when the conversation does
        # not already begin with a system message.
        if self.system_prompt and (not prepared or prepared[0]['role'] != 'system'):
            prepared.insert(0, {"role": "system", "content": self.system_prompt})
        return prepared

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((litellm.ServiceUnavailableError, litellm.InternalServerError)),
        reraise=True,
    )
    async def _acompletion_with_retry(self, request):
        """Call ``litellm.acompletion`` with filtered retries.

        Stops after 5 attempts, waits with exponential backoff configured
        with ``min=2`` and ``max=10``, and only retries on
        ``litellm.ServiceUnavailableError`` or ``litellm.InternalServerError``.
        ``reraise=True`` is set on the retry decorator.

        Args:
            request: Keyword arguments passed to ``litellm.acompletion``.

        Returns:
            Whatever ``litellm.acompletion`` returns.

        Raises:
            litellm.ServiceUnavailableError: When the failure message looks
                like a temporary outage ("503", "unavailable", "high demand").
            Exception: Any other error from LiteLLM, re-raised unchanged.
        """
        try:
            return await litellm.acompletion(**request)
        except Exception as e:
            # Some transient failures arrive as exception types that the retry
            # filter does not match; detect them by message and convert them
            # into a retryable error type.
            err_msg = str(e).lower()
            if "503" in err_msg or "unavailable" in err_msg or "high demand" in err_msg:
                provider_hint = self.model_name.split("/")[0] if "/" in self.model_name else "litellm"
                raise litellm.ServiceUnavailableError(
                    f"Temporary Provider Spike detected: {e}",
                    model=self.model_name,
                    llm_provider=provider_hint,
                ) from e
            raise

    async def acompletion(self, prompt=None, messages=None, **kwargs):
        """Run an asynchronous completion through LiteLLM with retries.

        Args:
            prompt: Text sent as a single user message when ``messages`` is
                ``None``.
            messages: Full message list; takes precedence over ``prompt``.
            **kwargs: Per-call request arguments. They are merged after the
                defaults given at construction and after the
                ``model``/``messages``/``api_key`` entries, so a per-call
                value wins on key collisions.

        Returns:
            The result of ``litellm.acompletion``.

        Raises:
            RuntimeError: If the call fails. Authentication failures get a
                dedicated message; every other failure is wrapped as a
                "Core Orchestrator Fault". The original exception is attached
                as ``__cause__``.
        """
        prepared_messages = self._prepare_messages(prompt=prompt, messages=messages)
        request = {
            "model": self.model_name,
            "messages": prepared_messages,
            "api_key": self.api_key,
            **self.kwargs,
            **kwargs,
        }

        try:
            return await self._acompletion_with_retry(request)
        except Exception as e:
            err_msg = str(e)
            if "authentication" in err_msg.lower() or "401" in err_msg:
                raise RuntimeError(f"Authentication failed for {self.model_name}. Check your API key.") from e
            # If we still fail after retries, wrap it in a cleaner runtime
            # error while keeping the underlying exception chained.
            raise RuntimeError(f"Core Orchestrator Fault: {err_msg}") from e