import copy
import logging
from typing import Any, Dict, Optional

from app.config import settings
from app.api import schemas

logger = logging.getLogger(__name__)

class PreferenceService:
    """Read and persist per-user provider preferences for the LLM, TTS and STT sections.

    ``user.preferences`` holds one dict per section plus a ``statuses`` map.
    A section is either in the current layout
    ``{"active_provider": str, "providers": {name: config}}`` or in a legacy
    layout where provider configs sit directly in the section (optionally
    next to a ``provider`` / ``active_provider`` key). Both layouts are
    accepted on read; a section saved as ``None`` is treated as empty.
    """

    # Provider ids recognised when converting a legacy section.
    _LEGACY_PROVIDER_KEYS = ("openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs")
    # Keys of a legacy section that are metadata, not provider configs.
    _SECTION_META_KEYS = ("active_provider", "provider", "providers")
    # Preference sections handled by this service, in processing order.
    _SECTIONS = ("llm", "tts", "stt")

    def __init__(self, services):
        # Service registry; only ``services.user_service`` is used here.
        self.services = services

    def mask_key(self, k: Optional[str]) -> Optional[str]:
        """Return a display-safe version of the API key *k*.

        Keys of 8 characters or fewer become ``"****"``; longer keys keep their
        first and last four characters with the middle replaced by ``*``.
        Returns ``None`` for an empty or missing key.
        """
        if not k:
            return None
        if len(k) <= 8:
            return "****"
        return k[:4] + "*" * (len(k) - 8) + k[-4:]

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _extract_providers(cls, section: Any) -> Dict[str, Any]:
        """Return the ``{provider: config}`` map of a stored section, in either layout.

        Anything that is not a dict (e.g. a section saved as ``None``) yields an
        empty map instead of raising. The returned values are not copied.
        """
        if not isinstance(section, dict):
            return {}
        if "providers" in section:
            providers = section["providers"]
            return providers if isinstance(providers, dict) else {}
        providers = {p: section[p] for p in cls._LEGACY_PROVIDER_KEYS if p in section}
        if not providers:
            # Unknown legacy shape: treat every dict-valued, non-metadata key as a provider.
            providers = {
                k: v
                for k, v in section.items()
                if k not in cls._SECTION_META_KEYS and isinstance(v, dict)
            }
        return providers

    def _normalize_section(self, section: Any, default_active: str) -> Dict[str, Any]:
        """Return a private copy of *section* in the ``active_provider``/``providers`` layout."""
        if isinstance(section, dict) and "providers" in section:
            return copy.deepcopy(section)
        legacy = section if isinstance(section, dict) else {}
        active = legacy.get("active_provider") or legacy.get("provider") or default_active
        return {
            "active_provider": str(active) if active else default_active,
            "providers": copy.deepcopy(self._extract_providers(legacy)),
        }

    @staticmethod
    def _is_provider_healthy(
        section_name: str,
        provider_id: str,
        p_data: Any,
        user_statuses: Dict[str, Any],
        system_statuses: Dict[str, Any],
    ) -> bool:
        """Return True if the provider's status is ``"success"`` or it has a non-placeholder API key."""
        status_key = f"{section_name}_{provider_id}"
        if user_statuses.get(status_key) == "success" or system_statuses.get(status_key) == "success":
            return True
        api_key = p_data.get("api_key") if isinstance(p_data, dict) else None
        # "None"/"none" strings are treated as placeholders, not real keys.
        return bool(api_key) and api_key not in ("None", "none", "")

    def _effective_providers(
        self,
        section_name: str,
        user_providers: Any,
        system_defaults: Any,
        user_statuses: Dict[str, Any],
        system_statuses: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return the healthy providers of a section with API keys masked.

        The user's own providers are used when present; otherwise the system
        defaults apply. Inputs are never mutated; empty or non-dict entries are skipped.
        """
        candidates = user_providers or system_defaults or {}
        result = {}
        for name, p_data in candidates.items():
            if not isinstance(p_data, dict) or not p_data:
                continue
            if not self._is_provider_healthy(section_name, name, p_data, user_statuses, system_statuses):
                continue
            entry = copy.deepcopy(p_data)
            entry["api_key"] = self.mask_key(p_data.get("api_key"))
            result[name] = entry
        return result

    @classmethod
    def _apply_group_policy(cls, effective: Dict[str, Any], policy: Dict[str, Any]) -> None:
        """Restrict each section of *effective* in place to the providers allowed by *policy*.

        A section whose policy list is missing or empty is disabled entirely
        (no providers, empty active provider). If the active provider is not
        allowed, the first remaining allowed provider (or ``""``) takes its place.
        """
        for section_name in cls._SECTIONS:
            section = effective[section_name]
            allowed = policy.get(section_name, [])
            if not allowed:
                section["providers"] = {}
                section["active_provider"] = ""
                continue
            section["providers"] = {k: v for k, v in section["providers"].items() if k in allowed}
            if section.get("active_provider") not in allowed:
                section["active_provider"] = next(iter(section["providers"]), None) or ""

    def _mask_section(self, section: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of *section* with every provider's API key masked."""
        if not section:
            return {}
        masked = copy.deepcopy(section)
        providers = masked.get("providers") or {}
        for p_data in providers.values():
            if isinstance(p_data, dict) and p_data.get("api_key"):
                p_data["api_key"] = self.mask_key(p_data["api_key"])
        return masked

    def _looks_masked(self, candidate: Any, stored: Any) -> bool:
        """Return True if *candidate* is a masked placeholder rather than a newly entered key.

        ``"***"`` catches most masks, but ``mask_key`` emits only one or two
        asterisks for 9-10 character keys, so the candidate is also compared
        with the mask of the stored key.
        """
        if not candidate:
            return False
        text = str(candidate)
        if "***" in text:
            return True
        return bool(stored) and text == self.mask_key(str(stored))

    def _restore_masked_keys(self, old_section: Any, new_section: Any) -> None:
        """Replace masked API keys in *new_section* (in place) with the stored originals.

        Keys are returned to clients masked (see ``merge_user_config``); if a
        client sends one back unchanged, saving it would overwrite the real key.
        Providers with no stored entry keep whatever value was sent.
        """
        if not isinstance(new_section, dict):
            return
        new_providers = new_section.get("providers")
        if not isinstance(new_providers, dict):
            return
        old_providers = self._extract_providers(old_section)
        for name, p_data in new_providers.items():
            if not isinstance(p_data, dict):
                continue
            old_data = old_providers.get(name)
            if not isinstance(old_data, dict):
                continue
            if self._looks_masked(p_data.get("api_key"), old_data.get("api_key")):
                p_data["api_key"] = old_data.get("api_key")

    @staticmethod
    def _active_provider_config(section: Any) -> tuple:
        """Return ``(name, config)`` of the section's active provider, or ``(None, {})`` if unavailable."""
        if not isinstance(section, dict) or not section.get("active_provider"):
            return None, {}
        name = section["active_provider"]
        providers = section.get("providers")
        data = providers.get(name) if isinstance(providers, dict) else None
        return name, data if isinstance(data, dict) else {}

    def _sync_global_settings(self, prefs: Any) -> None:
        """Copy an admin's provider choices into the global settings and save them to YAML (best effort)."""
        from app.config import settings as global_settings

        llm = prefs.llm
        if llm and "providers" in llm:
            global_settings.LLM_PROVIDERS = dict(llm.get("providers") or {})

        name, data = self._active_provider_config(prefs.tts)
        if data:
            global_settings.TTS_PROVIDER = name
            global_settings.TTS_MODEL_NAME = data.get("model") or global_settings.TTS_MODEL_NAME
            global_settings.TTS_VOICE_NAME = data.get("voice") or global_settings.TTS_VOICE_NAME
            global_settings.TTS_API_KEY = data.get("api_key") or global_settings.TTS_API_KEY

        name, data = self._active_provider_config(prefs.stt)
        if data:
            global_settings.STT_PROVIDER = name
            global_settings.STT_MODEL_NAME = data.get("model") or global_settings.STT_MODEL_NAME
            global_settings.STT_API_KEY = data.get("api_key") or global_settings.STT_API_KEY

        try:
            global_settings.save_to_yaml()
        except Exception as ey:
            # Deliberately best effort: a YAML write failure must not abort the save.
            logger.exception("Failed to sync settings to YAML: %s", ey)

    # ------------------------------------------------------------- public API

    def merge_user_config(self, user, db) -> "schemas.ConfigResponse":
        """Build the configuration view for *user*.

        ``preferences`` echoes the user's own sections (normalized, keys masked).
        ``effective`` holds the providers the user can actually use: their own
        providers (or the system defaults when they have none), filtered to
        healthy entries and to the user's group policy, with keys masked.
        """
        prefs_dict = user.preferences or {}

        default_active = {
            "llm": "deepseek",
            "tts": settings.TTS_PROVIDER,
            "stt": settings.STT_PROVIDER,
        }
        user_prefs = {
            name: self._normalize_section(prefs_dict.get(name), default_active[name])
            for name in self._SECTIONS
        }

        system_prefs = self.services.user_service.get_system_settings(db) or {}
        system_statuses = system_prefs.get("statuses") or {}
        user_statuses = prefs_dict.get("statuses") or {}

        # Used only when the system settings carry no provider map for a section.
        fallback_defaults = {
            "llm": {
                "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME},
                "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME},
            },
            "tts": {
                settings.TTS_PROVIDER: {
                    "api_key": settings.TTS_API_KEY,
                    "model": settings.TTS_MODEL_NAME,
                    "voice": settings.TTS_VOICE_NAME,
                },
            },
            "stt": {
                settings.STT_PROVIDER: {"api_key": settings.STT_API_KEY, "model": settings.STT_MODEL_NAME},
            },
        }

        effective = {}
        for name in self._SECTIONS:
            system_section = system_prefs.get(name) or {}
            system_defaults = system_section.get("providers", fallback_defaults[name])
            providers = self._effective_providers(
                name,
                user_prefs[name].get("providers"),
                system_defaults,
                user_statuses,
                system_statuses,
            )
            effective[name] = {
                "active_provider": user_prefs[name].get("active_provider")
                or next(iter(providers), default_active[name]),
                "providers": providers,
            }

        group = user.group or self.services.user_service.get_or_create_default_group(db)
        if group:
            self._apply_group_policy(effective, group.policy or {})

        return schemas.ConfigResponse(
            preferences=schemas.UserPreferences(
                llm=self._mask_section(user_prefs["llm"]),
                tts=self._mask_section(user_prefs["tts"]),
                stt=self._mask_section(user_prefs["stt"]),
                statuses=prefs_dict.get("statuses", {}),
            ),
            effective=effective,
        )

    def update_user_config(self, user, prefs: "schemas.UserPreferences", db) -> "schemas.UserPreferences":
        """Persist *prefs* for *user* and return the stored sections.

        Masked API keys sent back by the client are first swapped for the
        stored originals. For admins the chosen providers are also pushed into
        the global settings and saved to YAML (failures are logged, not raised).
        """
        from sqlalchemy.orm.attributes import flag_modified

        old_prefs = user.preferences or {}
        for name in self._SECTIONS:
            new_section = getattr(prefs, name)
            if new_section:
                self._restore_masked_keys(old_prefs.get(name), new_section)

        current_prefs = dict(user.preferences or {})
        current_prefs.update({
            "llm": prefs.llm,
            "tts": prefs.tts,
            "stt": prefs.stt,
            "statuses": prefs.statuses or {},
        })
        user.preferences = current_prefs
        # Force SQLAlchemy to treat the attribute as changed regardless of change tracking.
        flag_modified(user, "preferences")

        if user.role == "admin":
            self._sync_global_settings(prefs)
            logger.info("Saving updated global preferences via admin %s", user.id)
        else:
            for name in self._SECTIONS:
                stored = user.preferences.get(name)
                incoming = getattr(prefs, name)
                if isinstance(stored, dict) and isinstance(incoming, dict):
                    # NOTE(review): stored is presumably the same dict as incoming (assigned just
                    # above), so this only guarantees the key exists. Confirm whether non-admins
                    # were meant to be limited to changing active_provider.
                    stored["active_provider"] = incoming.get("active_provider")
            logger.info("Saving personal preferences for user %s", user.id)

        db.add(user)
        db.commit()
        db.refresh(user)

        saved = user.preferences or {}
        return schemas.UserPreferences(
            llm=saved.get("llm", {}),
            tts=saved.get("tts", {}),
            stt=saved.get("stt", {}),
            statuses=saved.get("statuses", {}),
        )

