# cortex-hub / ai-hub / app / core / services / preference.py
import logging
import copy
from typing import List, Dict, Any

from app.config import settings
from app.api import schemas

# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

class PreferenceService:
    def __init__(self, services):
        self.services = services

    def mask_key(self, k: str) -> str:
        if not k: return None
        if len(k) <= 8: return "****"
        return k[:4] + "*" * (len(k)-8) + k[-4:]

    def merge_user_config(self, user, db) -> schemas.ConfigResponse:
        prefs_dict = user.preferences or {}

        def normalize_section(section_name, default_active):
            section = prefs_dict.get(section_name, {})
            # If already new style, just return a copy
            if isinstance(section, dict) and "providers" in section:
                return copy.deepcopy(section)
            
            # Legacy transformation
            providers = {}
            active = section.get("active_provider") or section.get("provider") or default_active
            
            # Known providers to check for legacy transformation
            legacy_keys = ["openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs"]
            for p in legacy_keys:
                if p in section:
                    providers[p] = section[p]
            
            # If still no providers found but it's not empty, it might be a flat dict of other providers
            if not providers and section and isinstance(section, dict):
                for k, v in section.items():
                    if k not in ["active_provider", "provider", "providers"] and isinstance(v, dict):
                        providers[k] = v

            return {
                "active_provider": str(active) if active else default_active,
                "providers": providers
            }

        llm_prefs = normalize_section("llm", "deepseek")
        tts_prefs = normalize_section("tts", settings.TTS_PROVIDER)
        stt_prefs = normalize_section("stt", settings.STT_PROVIDER)
        
        system_prefs = self.services.user_service.get_system_settings(db)
        system_statuses = system_prefs.get("statuses", {})
        user_statuses = prefs_dict.get("statuses", {})
        
        def is_provider_healthy(section: str, provider_id: str, p_data: dict = None) -> bool:
            status_key = f"{section}_{provider_id}"
            is_success = user_statuses.get(status_key) == "success" or system_statuses.get(status_key) == "success"
            has_key = p_data and p_data.get("api_key") and p_data.get("api_key") not in ("None", "none", "")
            return is_success or bool(has_key)

        # Build effective combined config for processing
        def get_effective_providers(section_name, user_section_providers, sys_defaults):
            # M6/M3: Deep merge system defaults with user overrides.
            # This ensures that if a user configures a model name in the UI, 
            # they don't lose the API key from config.yaml if it's not provided in the UI.
            
            # Start with system defaults (config.yaml / env)
            effective = copy.deepcopy(sys_defaults)
            
            # Layer on overrides (Admin or User)
            if user_section_providers:
                for p_id, p_data in user_section_providers.items():
                    if p_id not in effective:
                        # New provider defined in UI
                        effective[p_id] = copy.deepcopy(p_data)
                    else:
                        # Override existing provider fields
                        for field, val in p_data.items():
                            # Only override if the value is meaningful (not empty/null)
                            if val is not None and val != "" and str(val).lower() != "none":
                                effective[p_id][field] = val
            
            # Filter by health and mask keys for the response
            res = {}
            for p, p_data in effective.items():
                if p_data and is_provider_healthy(section_name, p, p_data):
                    masked_data = copy.deepcopy(p_data)
                    masked_data["api_key"] = self.mask_key(p_data.get("api_key"))
                    res[p] = masked_data
            return res

        def get_merged_system_defaults(section_name, hardcoded_defaults):
            # M6/M3: Merge Admin overrides with hardcoded config.yaml defaults.
            # This prevents the admin from losing system-level keys when they 
            # customize other fields (like model names) in the UI.
            sys_prefs_section = system_prefs.get(section_name, {})
            sys_providers = sys_prefs_section.get("providers", {})
            
            if not sys_providers:
                return hardcoded_defaults
            
            # Start with hardcoded defaults
            merged = copy.deepcopy(hardcoded_defaults)
            for p_id, p_data in sys_providers.items():
                if p_id not in merged:
                    merged[p_id] = p_data
                else:
                    for field, val in p_data.items():
                        if val is not None and val != "" and str(val).lower() != "none":
                            merged[p_id][field] = val
            return merged

        system_llm = get_merged_system_defaults("llm", settings.LLM_PROVIDERS)
        llm_providers_effective = get_effective_providers("llm", llm_prefs["providers"], system_llm)

        system_tts = get_merged_system_defaults("tts", {
            settings.TTS_PROVIDER: {
                "api_key": settings.TTS_API_KEY,
                "model": settings.TTS_MODEL_NAME,
                "voice": settings.TTS_VOICE_NAME
            }
        })
        tts_providers_effective = get_effective_providers("tts", tts_prefs["providers"], system_tts)

        system_stt = get_merged_system_defaults("stt", {
            settings.STT_PROVIDER: {"api_key": settings.STT_API_KEY, "model": settings.STT_MODEL_NAME}
        })
        stt_providers_effective = get_effective_providers("stt", stt_prefs["providers"], system_stt)
        
        effective = {
            "llm": {
                "active_provider": llm_prefs.get("active_provider") or (next(iter(llm_providers_effective), "deepseek")),
                "providers": llm_providers_effective
            },
            "tts": {
                "active_provider": tts_prefs.get("active_provider") or (next(iter(tts_providers_effective), settings.TTS_PROVIDER)),
                "providers": tts_providers_effective
            },
            "stt": {
                "active_provider": stt_prefs.get("active_provider") or (next(iter(stt_providers_effective), settings.STT_PROVIDER)),
                "providers": stt_providers_effective
            }
        }

        group = user.group or self.services.user_service.get_or_create_default_group(db)
        if group:
            policy = group.policy or {}
            def apply_policy(section_key, policy_key):
                allowed = policy.get(policy_key, [])
                if not allowed:
                    effective[section_key]["providers"] = {}
                    effective[section_key]["active_provider"] = ""
                    return
                
                providers = effective[section_key]["providers"]
                filtered_eff = {k: v for k, v in providers.items() if k in allowed}
                effective[section_key]["providers"] = filtered_eff
                
                if effective[section_key].get("active_provider") not in allowed:
                    effective[section_key]["active_provider"] = next(iter(filtered_eff), None) or ""

            apply_policy("llm", "llm")
            apply_policy("tts", "tts")
            apply_policy("stt", "stt")

        def mask_section_prefs(section_dict):
            if not section_dict: return {}
            masked_dict = copy.deepcopy(section_dict)
            providers = masked_dict.get("providers", {})
            for p_name, p_data in providers.items():
                if p_data.get("api_key"):
                    p_data["api_key"] = self.mask_key(p_data["api_key"])
            return masked_dict

        merged_statuses = copy.deepcopy(system_statuses)
        merged_statuses.update(user_statuses)

        return schemas.ConfigResponse(
            preferences=schemas.UserPreferences(
                llm=mask_section_prefs(llm_prefs),
                tts=mask_section_prefs(tts_prefs),
                stt=mask_section_prefs(stt_prefs),
                statuses=merged_statuses
            ),
            effective=effective
        )


    def update_user_config(self, user, prefs: schemas.UserPreferences, db) -> schemas.UserPreferences:
        # When saving, if the api_key contains ****, we must retain the old one from the DB
        old_prefs = user.preferences or {}

        def get_old_providers(section_name):
            section = old_prefs.get(section_name, {})
            if isinstance(section, dict) and "providers" in section:
                return section["providers"]
            
            # Legacy extraction
            providers = {}
            legacy_keys = ["openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs"]
            for p in legacy_keys:
                if p in section:
                    providers[p] = section[p]
            
            if not providers and section and isinstance(section, dict):
                for k, v in section.items():
                    if k not in ["active_provider", "provider", "providers"] and isinstance(v, dict):
                        providers[k] = v
            return providers

        def preserve_masked_keys(section_name, new_section):
            if not new_section or "providers" not in new_section:
                return
            old_section_providers = get_old_providers(section_name)
            for p_name, p_data in new_section["providers"].items():
                if p_data.get("api_key") and "***" in str(p_data["api_key"]):
                    if p_name in old_section_providers:
                        p_data["api_key"] = old_section_providers[p_name].get("api_key")

        if prefs.llm: preserve_masked_keys("llm", prefs.llm)
        if prefs.tts: preserve_masked_keys("tts", prefs.tts)
        if prefs.stt: preserve_masked_keys("stt", prefs.stt)


        current_prefs = dict(user.preferences or {})
        current_prefs.update({
            "llm": prefs.llm,
            "tts": prefs.tts,
            "stt": prefs.stt,
            "statuses": prefs.statuses or {}
        })
        user.preferences = current_prefs
        
        if user.role == "admin":
            from sqlalchemy.orm.attributes import flag_modified
            flag_modified(user, "preferences")
            
            from app.config import settings as global_settings
            
            if prefs.llm and "providers" in prefs.llm:
                global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers", {}))
                
            if prefs.tts and prefs.tts.get("active_provider"):
                p_name = prefs.tts["active_provider"]
                p_data = prefs.tts.get("providers", {}).get(p_name, {})
                if p_data:
                    global_settings.TTS_PROVIDER = p_name
                    global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
                    global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
                    global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY
                    
            if prefs.stt and prefs.stt.get("active_provider"):
                p_name = prefs.stt["active_provider"]
                p_data = prefs.stt.get("providers", {}).get(p_name, {})
                if p_data:
                    global_settings.STT_PROVIDER = p_name
                    global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
                    global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY

            logger.info(f"Saving updated global preferences via admin {user.id}")
        else:
            user.preferences["llm"]["active_provider"] = prefs.llm.get("active_provider")
            user.preferences["tts"]["active_provider"] = prefs.tts.get("active_provider")
            user.preferences["stt"]["active_provider"] = prefs.stt.get("active_provider")
            user.preferences["statuses"] = prefs.statuses or {}
            from sqlalchemy.orm.attributes import flag_modified
            flag_modified(user, "preferences")
            logger.info(f"Saving personal preferences for user {user.id}")

        db.add(user)
        db.commit()
        db.refresh(user)
        
        return schemas.UserPreferences(
            llm=user.preferences.get("llm", {}),
            tts=user.preferences.get("tts", {}),
            stt=user.preferences.get("stt", {}),
            statuses=user.preferences.get("statuses", {})
        )

    def export_config_yaml(self, user, reveal_secrets: bool) -> str:
        import yaml
        from app.core.grpc.utils.crypto import encrypt_value

        prefs_dict = copy.deepcopy(user.preferences) if user.preferences else {}
        sensitive_keys = ["api_key", "client_secret", "webhook_secret", "password", "key_content", "key_file"]

        def process_export(obj):
            if isinstance(obj, dict):
                res = {}
                for k, v in obj.items():
                    if v is None: continue
                    if k in sensitive_keys and v:
                        res[k] = v if reveal_secrets else encrypt_value(v)
                    else:
                        res[k] = process_export(v)
                return res
            elif isinstance(obj, list):
                return [process_export(x) for x in obj]
            return obj

        export_data = {
            "llm": prefs_dict.get("llm", {"providers": {}, "active_provider": "deepseek"}),
            "tts": prefs_dict.get("tts", {"providers": {}, "active_provider": settings.TTS_PROVIDER}),
            "stt": prefs_dict.get("stt", {"providers": {}, "active_provider": settings.STT_PROVIDER})
        }
        
        # Backfill from settings if empty
        if not export_data["llm"].get("providers"):
            export_data["llm"]["providers"] = settings.LLM_PROVIDERS

        return yaml.dump(process_export(export_data), sort_keys=False, default_flow_style=False)

    async def import_config_yaml(self, db, user, content: bytes) -> schemas.UserPreferences:
        import yaml
        from app.core.grpc.utils.crypto import decrypt_value
        from sqlalchemy.orm.attributes import flag_modified

        try: data = yaml.safe_load(content)
        except Exception as e: raise Exception(f"Invalid YAML: {e}")

        def process_import(obj):
            if isinstance(obj, dict): return {k: process_import(v) for k, v in obj.items()}
            elif isinstance(obj, str): return decrypt_value(obj)
            elif isinstance(obj, list): return [process_import(x) for x in obj]
            return obj

        data = process_import(data)
        user.preferences = {
            "llm": data.get("llm", {}),
            "tts": data.get("tts", {}),
            "stt": data.get("stt", {}),
            "statuses": {}
        }
        flag_modified(user, "preferences")
        db.commit()
        return schemas.UserPreferences(llm=user.preferences["llm"], tts=user.preferences["tts"], stt=user.preferences["stt"])

    async def verify_provider(self, db, user, req: schemas.VerifyProviderRequest, section: str) -> schemas.VerifyProviderResponse:
        from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider
        
        # Admin or personal key check
        is_masked = not req.api_key or "***" in str(req.api_key)
        if is_masked and user.role != "admin":
             return schemas.VerifyProviderResponse(success=False, message="Forbidden: Admin only for masked keys")

        actual_key = req.api_key
        prefs = user.preferences.get(section, {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {}
        
        if is_masked:
             actual_key = prefs.get("api_key")
             if not actual_key:
                  s_prefs = self.services.user_service.get_system_settings(db)
                  actual_key = s_prefs.get(section, {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key")

        try:
            if section == "llm":
                llm = get_llm_provider(req.provider_name, model_name=req.model or "", api_key_override=actual_key)
                await llm.acompletion(prompt="Hello")
            elif section == "tts":
                p = get_tts_provider(req.provider_name, api_key=actual_key, model_name=req.model or "", voice_name=req.voice or "")
                await p.generate_speech("Test")
            else:
                get_stt_provider(req.provider_name, api_key=actual_key, model_name=req.model or "")
            return schemas.VerifyProviderResponse(success=True, message="Success!")
        except Exception as e:
            return schemas.VerifyProviderResponse(success=False, message=str(e))

    def resolve_llm_provider(self, db, user, provider_name: str, model_name: str = None) -> Any:
        """
        Unified resolution for LLM providers with full fallback chain:
        User Preference -> System Override (Admin UI) -> Config Defaults (YAML/Env)
        """
        from app.core.providers.factory import get_llm_provider
        
        base_key = provider_name.split("/")[0] if provider_name else ""
        if not provider_name:
            # Fallback to registered active provider
            u_svc = getattr(self.services, "user_service", None)
            sys_prefs = u_svc.get_system_settings(db) if u_svc else {}
            
            user_active = user.preferences.get("llm", {}).get("active_provider") if user and user.preferences else None
            base_key = user_active or sys_prefs.get("llm", {}).get("active_provider")
            provider_name = base_key

        user_providers = user.preferences.get("llm", {}).get("providers", {}) if user and user.preferences else {}
        llm_prefs = user_providers.get(base_key, {})
        
        # Prefix matching: 'gemini' -> 'gemini_gemini-3-flash-preview'
        if not llm_prefs and base_key:
             for pk, pv in user_providers.items():
                 if pk.startswith(f"{base_key}_") or pk == base_key:
                     llm_prefs = pv
                     provider_name = pk
                     base_key = pk
                     # Derive model from the key suffix if the stored model is generic/missing
                     # e.g. key='gemini_gemini-3-flash-preview' -> derived_model='gemini-3-flash-preview'
                     parts = pk.split("_", 1)
                     if len(parts) == 2:
                         derived_model = parts[1]
                         stored_model = pv.get("model", "")
                         # Only use the derived model if stored one looks generic (no dash = no version)
                         if not stored_model or "-" not in stored_model:
                             llm_prefs = dict(pv, model=derived_model)
                     break
                     
        logger.info(f"[Preference] Resolved match for '{base_key}': model={llm_prefs.get('model')}, has_key={'api_key' in llm_prefs}")
        
        # Resolve Resolved Model/Provider names
        provider_name_str = str(provider_name) if provider_name else ""
        has_slash = "/" in provider_name_str
        resolved_model = provider_name_str.split("/")[1] if has_slash else (model_name or llm_prefs.get("model", ""))
        resolved_provider_name = provider_name_str.split("/")[0] if has_slash else provider_name_str

        # 3. Last Resort: Pick the first available provider for the user/tenant if no preference exists
        if not resolved_provider_name or not resolved_model:
            available_providers = self.get_user_llm_providers(user.id, db)
            if available_providers:
                # Use the first available provider as the default
                default_p = available_providers[0]
                resolved_provider_name = default_p.name
                resolved_model = default_p.model
                logger.info(f"[Preference] No preference set for user {user.id}. Defaulting to first available: {resolved_provider_name}")
            else:
                # No providers configured for this user/tenant at all
                logger.error(f"[Preference] No LLM providers configured for user {user.id}")
                raise HTTPException(
                    status_code=400, 
                    detail="No LLM providers configured. Please configure an LLM provider (e.g. Gemini, OpenAI) in your user settings."
                )
        
        logger.info(f"[Preference] Final Resolution for {user.id}: {resolved_provider_name} / {resolved_model}")
        
        # The resolved_provider_name may be a DB-internal composite key like 'gemini_gemini-3-flash-preview'.
        # LiteLLM needs the base provider type (e.g. 'gemini'), not the full key.
        # The model string (e.g. 'gemini/gemini-3-flash-preview') already encodes the routing info.
        litellm_provider = resolved_provider_name.split("_")[0] if "_" in resolved_provider_name else resolved_provider_name
        
        try:
            return get_llm_provider(
                litellm_provider, 
                model_name=resolved_model, 
                api_key_override=llm_prefs.get("api_key"),
                **{k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]}
            ), resolved_provider_name
        except Exception as e:
            raise ValueError(f"Failed to initialize LLM provider '{litellm_provider}' with model '{resolved_model}': {e}")

    async def get_provider_models(self, provider_name: str, section: str = "llm") -> List[Dict[str, Any]]:
        """Fetches supported models for a specific provider and section using LiteLLM."""
        import litellm
        from fastapi.concurrency import run_in_threadpool
        
        def fetch_models():
            try:
                models = litellm.models_by_provider.get(provider_name, [])
                out = []
                for m in models:
                    try:
                        info = litellm.get_model_info(m)
                        if "error" not in info:
                            mode = info.get("mode")
                            is_valid = False
                            if section == "llm": is_valid = mode in ["chat", "text-completion", "custom", None]
                            elif section == "tts": is_valid = mode == "audio_speech"
                            elif section == "stt": is_valid = mode == "audio_transcription"
                            elif section == "image": is_valid = mode == "image_generation"
                            else: is_valid = True

                            if is_valid:
                                out.append({"model_name": m, "max_tokens": info.get("max_tokens"), "max_input_tokens": info.get("max_input_tokens")})
                    except: pass
                return out
            except: return []
                
        return await run_in_threadpool(fetch_models)

    def get_all_providers(self, db, user, section: str = "llm", configured_only: bool = False) -> List[str]:
        """Returns valid providers for a section, optionally filtering by those with configured credentials."""
        import litellm
        from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers
        
        if configured_only:
            system_prefs = self.services.user_service.get_system_settings(db)
            user_prefs = user.preferences if user else {}
            
            configured = set(system_prefs.get(section, {}).get("providers", {}).keys())
            configured.update(user_prefs.get(section, {}).get("providers", {}).keys())
            
            if not configured:
                if section == "llm": configured.update(["deepseek", "gemini"])
                elif section == "tts": configured.add(settings.TTS_PROVIDER)
                elif section == "stt": configured.add(settings.STT_PROVIDER)
            return sorted(list(configured))

        if section == "llm": return ["general"] + [p.value for p in litellm.LlmProviders]
        elif section == "tts": return ["general"] + get_registered_tts_providers() + ["openai"]
        elif section == "stt": return ["general"] + get_registered_stt_providers() + ["openai"]
        return ["general"] + [p.value for p in litellm.LlmProviders]