import litellm
from dspy.clients.base_lm import BaseLM

class GeneralProvider(BaseLM):
    """DSPy language model that routes completion calls through LiteLLM.

    Every request carries the configured API key explicitly. An optional
    system prompt is prepended to conversations that do not already start
    with a system message. LiteLLM failures are re-raised as ``RuntimeError``
    with a message that names the model. Authentication failures get a
    dedicated message.
    """

    # Used when neither the caller nor LiteLLM's model info supplies a limit.
    _DEFAULT_MAX_TOKENS = 8000

    def __init__(self, model_name: str, api_key: str, system_prompt: str = None, **kwargs):
        """Create the provider.

        Args:
            model_name: Model identifier passed to LiteLLM and to ``BaseLM``.
            api_key: Provider API key. It is rejected when empty/falsy or when
                it contains ``*``. Presumably that signals a masked
                placeholder rather than a real key.
            system_prompt: Optional system message to prepend to requests
                whose first message is not already a system message.
            **kwargs: Forwarded to ``BaseLM.__init__``. An explicit
                ``max_tokens`` here takes precedence over the LiteLLM lookup.

        Raises:
            ValueError: If ``api_key`` is missing or contains ``*``.
        """
        # Validate the key before storing anything, so a bad key fails fast.
        if not api_key or "*" in str(api_key):
            raise ValueError(f"Invalid or missing API key for LLM provider '{model_name}'. Please configure it in Settings.")

        self.model_name = model_name
        self.api_key = api_key
        self.system_prompt = system_prompt

        # Pop max_tokens so it is not passed to BaseLM twice, which would raise
        # "got multiple values for keyword argument 'max_tokens'".
        max_tokens = kwargs.pop("max_tokens", None)
        if max_tokens is None:
            max_tokens = self._lookup_max_tokens(model_name)

        super().__init__(model=model_name, max_tokens=max_tokens, **kwargs)

    @classmethod
    def _lookup_max_tokens(cls, model_name):
        """Return LiteLLM's token limit for *model_name*, or the class default."""
        try:
            info = litellm.get_model_info(model_name)
        except Exception:
            # Best-effort lookup. If the model is unknown to LiteLLM (or the
            # lookup fails for any other reason), use the default limit.
            return cls._DEFAULT_MAX_TOKENS
        if not info:
            return cls._DEFAULT_MAX_TOKENS
        # Prefer the explicit output limit. LiteLLM documents ``max_tokens``
        # as a legacy field that may fall back to the input window.
        limit = info.get("max_output_tokens") or info.get("max_tokens")
        return limit or cls._DEFAULT_MAX_TOKENS

    def _prepare_messages(self, prompt=None, messages=None):
        """Return the message list for a request, with the system prompt if configured.

        When ``messages`` is None, a single user message containing ``prompt``
        is built. Otherwise ``messages`` is shallow-copied, so the caller's
        list is never modified by the system-prompt insertion.
        """
        if messages is None:
            prepared = [{"role": "user", "content": prompt}]
        else:
            prepared = list(messages)

        # Only add the system prompt when the conversation doesn't already
        # start with a system message.
        if self.system_prompt and (not prepared or prepared[0].get("role") != "system"):
            prepared.insert(0, {"role": "system", "content": self.system_prompt})

        return prepared

    def _build_request(self, prompt, messages, kwargs):
        """Assemble keyword arguments for ``litellm.completion`` / ``acompletion``.

        Later entries win: ``self.kwargs`` (defaults stored by ``BaseLM``)
        overrides the fixed fields, and per-call ``kwargs`` override both.
        """
        return {
            "model": self.model_name,
            "messages": self._prepare_messages(prompt=prompt, messages=messages),
            "api_key": self.api_key,
            **self.kwargs,
            **kwargs,
        }

    def _wrap_error(self, exc):
        """Translate a LiteLLM exception into the ``RuntimeError`` shown to callers."""
        err_msg = str(exc)
        auth_error_type = getattr(litellm, "AuthenticationError", None)
        # Distinguish credential problems from other failures (network, quota, ...).
        is_auth_error = (
            (auth_error_type is not None and isinstance(exc, auth_error_type))
            or "authentication" in err_msg.lower()
            or "401" in err_msg
        )
        if is_auth_error:
            return RuntimeError(f"Authentication failed for {self.model_name}. Check your API key.")
        return RuntimeError(f"LiteLLM Error ({self.model_name}): {err_msg}")

    def forward(self, prompt=None, messages=None, **kwargs):
        """Run a synchronous completion through ``litellm.completion``.

        Raises:
            RuntimeError: Wrapping any exception raised by LiteLLM. The
                original exception is chained as ``__cause__``.
        """
        request = self._build_request(prompt, messages, kwargs)
        try:
            return litellm.completion(**request)
        except Exception as e:
            raise self._wrap_error(e) from e

    async def aforward(self, prompt=None, messages=None, **kwargs):
        """Run an asynchronous completion through ``litellm.acompletion``.

        Raises:
            RuntimeError: Wrapping any exception raised by LiteLLM. The
                original exception is chained as ``__cause__``.
        """
        request = self._build_request(prompt, messages, kwargs)
        try:
            return await litellm.acompletion(**request)
        except Exception as e:
            raise self._wrap_error(e) from e