import copy
import logging
from typing import Any, Dict, Optional

from app.api import schemas
from app.config import settings
# Module-level logger; PreferenceService uses it for save / _clone_from diagnostics.
logger = logging.getLogger(__name__)
class PreferenceService:
    """Read and persist per-user LLM / TTS / STT provider preferences.

    ``merge_user_config`` builds the masked view of a user's preferences plus
    the providers that are effectively usable for them.
    ``update_user_config`` stores edited preferences, restoring real API keys
    wherever the client echoed back a masked value.
    """

    # Substring present in every value produced by ``mask_key``; used to
    # recognise masked keys echoed back by the client.
    _MASK_MARKER = "***"

    def __init__(self, services):
        # ``services`` must expose ``user_service`` (system settings and
        # default-group lookup are delegated to it below).
        self.services = services

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _section_dict(container: Any, key: str) -> Dict[str, Any]:
        """Return ``container[key]`` when both are dicts, otherwise ``{}``.

        Stored preferences may hold explicit ``None`` sections; treating those
        as empty avoids ``AttributeError`` on chained ``.get`` calls.
        """
        value = container.get(key) if isinstance(container, dict) else None
        return value if isinstance(value, dict) else {}

    @classmethod
    def _is_masked(cls, value: Any) -> bool:
        """Return True if ``value`` is a string that looks like ``mask_key`` output."""
        return isinstance(value, str) and cls._MASK_MARKER in value

    @staticmethod
    def _provider_key(providers: Dict[str, Any], name: str) -> Any:
        """Return the ``api_key`` stored for provider ``name``, or None."""
        entry = providers.get(name)
        return entry.get("api_key") if isinstance(entry, dict) else None

    def mask_key(self, k: str) -> Optional[str]:
        """Return a display-safe form of the secret ``k``.

        Empty/None keys yield None. Keys of 8 characters or fewer are fully
        hidden as ``"****"``. Longer keys keep their first and last four
        characters. The hidden middle is always at least four asterisks, so
        every masked value contains ``_MASK_MARKER``. That lets
        ``update_user_config`` recognise it when echoed back, instead of saving
        it as a real key.
        """
        if not k:
            return None
        if len(k) <= 8:
            return "****"
        return k[:4] + "*" * max(len(k) - 8, 4) + k[-4:]

    def _providers_or_fallback(self, user_providers, system_prefs, section, fallback):
        """Pick the provider map for ``section``.

        Returns the user's own map when non-empty, else the system map for the
        section when non-empty, else ``fallback()``. ``fallback`` is called
        lazily so the built-in settings are only read when actually needed.
        """
        if user_providers:
            return user_providers
        system_providers = self._section_dict(system_prefs, section).get("providers", {})
        return system_providers if system_providers else fallback()

    def _effective_providers(self, providers, extra_fields, is_healthy):
        """Summarise each healthy provider as ``{"api_key": <masked>, <extra_fields>...}``.

        Empty or non-dict provider entries are skipped.
        """
        return {
            name: {
                "api_key": self.mask_key(data.get("api_key")),
                **{field: data.get(field) for field in extra_fields},
            }
            for name, data in providers.items()
            if isinstance(data, dict) and data and is_healthy(name, data)
        }

    @staticmethod
    def _apply_group_policy(effective_section, allowed, section_prefs):
        """Restrict one section to the provider names listed in ``allowed``.

        Mutates ``effective_section`` and ``section_prefs`` in place. An empty
        or missing ``allowed`` removes every provider and clears the active
        provider. Otherwise, a disallowed active provider is replaced by the
        first remaining effective provider, or by ``""`` when none remain.
        """
        if not allowed:
            effective_section["providers"] = {}
            effective_section["active_provider"] = ""
            if section_prefs and "providers" in section_prefs:
                section_prefs["providers"] = {}
            return

        filtered = {
            name: data
            for name, data in effective_section["providers"].items()
            if name in allowed
        }
        effective_section["providers"] = filtered
        if section_prefs and "providers" in section_prefs:
            section_prefs["providers"] = {
                name: data
                for name, data in (section_prefs["providers"] or {}).items()
                if name in allowed
            }
        if effective_section.get("active_provider") not in allowed:
            effective_section["active_provider"] = next(iter(filtered), None) or ""

    def _mask_section(self, section_dict):
        """Return a deep copy of a preferences section with every provider API key masked."""
        if not section_dict:
            return {}
        masked = copy.deepcopy(section_dict)
        for p_data in (masked.get("providers") or {}).values():
            if isinstance(p_data, dict) and p_data.get("api_key"):
                p_data["api_key"] = self.mask_key(p_data["api_key"])
        return masked

    # ------------------------------------------------------------------ read

    def merge_user_config(self, user, db) -> schemas.ConfigResponse:
        """Build the configuration payload shown to ``user``.

        Returns a ``schemas.ConfigResponse`` with two parts:

        * ``preferences``: the user's stored llm/tts/stt sections with API keys
          masked. For non-admins these are also filtered by the group policy.
        * ``effective``: per section, the active provider and the usable
          providers.

        Usable providers are the user's own providers; if the user has none,
        the system providers; if those are empty too, built-in defaults from
        ``settings``. Only healthy providers are kept, meaning one whose
        status is ``"success"`` or that has a non-placeholder API key. For
        non-admins the result is then restricted by their group's policy.
        """
        # Deep-copy first: policy filtering below rewrites these section dicts,
        # and user.preferences is live ORM state that a read must not modify.
        prefs_dict = copy.deepcopy(user.preferences or {})
        llm_prefs = self._section_dict(prefs_dict, "llm")
        tts_prefs = self._section_dict(prefs_dict, "tts")
        stt_prefs = self._section_dict(prefs_dict, "stt")
        user_statuses = self._section_dict(prefs_dict, "statuses")

        system_prefs = self.services.user_service.get_system_settings(db) or {}
        system_statuses = self._section_dict(system_prefs, "statuses")

        def is_provider_healthy(section: str, provider_id: str, p_data: dict = None) -> bool:
            """True if the user or system status for this provider is "success",
            or the provider carries a real (non-placeholder) API key."""
            status_key = f"{section}_{provider_id}"
            is_success = (
                user_statuses.get(status_key) == "success"
                or system_statuses.get(status_key) == "success"
            )
            api_key = (p_data or {}).get("api_key")
            has_key = bool(api_key) and api_key not in ("None", "none", "")
            return is_success or has_key

        llm_source = self._providers_or_fallback(
            llm_prefs.get("providers", {}),
            system_prefs,
            "llm",
            lambda: {
                "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME},
                "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME},
            },
        )
        tts_source = self._providers_or_fallback(
            tts_prefs.get("providers", {}),
            system_prefs,
            "tts",
            lambda: {
                settings.TTS_PROVIDER: {
                    "api_key": settings.TTS_API_KEY,
                    "model": settings.TTS_MODEL_NAME,
                    "voice": settings.TTS_VOICE_NAME,
                }
            },
        )
        # STT providers may also be stored one level deeper under "stt"; when
        # that nested map is non-empty it takes precedence.
        stt_source = self._providers_or_fallback(
            self._section_dict(stt_prefs, "stt").get("providers", {}) or stt_prefs.get("providers", {}),
            system_prefs,
            "stt",
            lambda: {
                settings.STT_PROVIDER: {"api_key": settings.STT_API_KEY, "model": settings.STT_MODEL_NAME}
            },
        )

        llm_effective = self._effective_providers(
            llm_source, ("model",), lambda name, data: is_provider_healthy("llm", name, data)
        )
        tts_effective = self._effective_providers(
            tts_source, ("model", "voice"), lambda name, data: is_provider_healthy("tts", name, data)
        )
        stt_effective = self._effective_providers(
            stt_source, ("model",), lambda name, data: is_provider_healthy("stt", name, data)
        )

        effective = {
            "llm": {
                "active_provider": llm_prefs.get("active_provider")
                or next(iter(llm_effective), None)
                or "deepseek",
                "providers": llm_effective,
            },
            "tts": {
                "active_provider": tts_prefs.get("active_provider")
                or next(iter(tts_effective), None)
                or settings.TTS_PROVIDER,
                "providers": tts_effective,
            },
            "stt": {
                "active_provider": stt_prefs.get("active_provider")
                or next(iter(stt_effective), None)
                or settings.STT_PROVIDER,
                "providers": stt_effective,
            },
        }

        # The group is resolved for every user (this may create the default
        # group), but the policy is only enforced for non-admins.
        group = user.group or self.services.user_service.get_or_create_default_group(db)
        if group and user.role != "admin":
            policy = group.policy or {}
            for section, section_prefs in (("llm", llm_prefs), ("tts", tts_prefs), ("stt", stt_prefs)):
                self._apply_group_policy(effective[section], policy.get(section, []), section_prefs)

        return schemas.ConfigResponse(
            preferences=schemas.UserPreferences(
                llm=self._mask_section(llm_prefs),
                tts=self._mask_section(tts_prefs),
                stt=self._mask_section(stt_prefs),
                statuses=user_statuses,
            ),
            effective=effective,
        )

    # ------------------------------------------------------------------ write

    def update_user_config(self, user, prefs: schemas.UserPreferences, db) -> schemas.UserPreferences:
        """Persist ``prefs`` for ``user`` and return the stored preferences.

        The client receives masked API keys (see ``merge_user_config``), so
        before saving:

        * a masked key is replaced by the real key already stored for the same
          provider;
        * a provider carrying ``_clone_from`` inherits the real key of the
          named source provider, from the user's stored prefs first and then
          the system settings. The ``_clone_from`` marker itself is removed.

        For admins the saved values are also mirrored into the global settings
        and, best effort, written to YAML.
        """
        from sqlalchemy.orm.attributes import flag_modified

        old_prefs = user.preferences or {}
        system_prefs_cache = []

        def system_providers(section_name):
            # System settings are fetched lazily and at most once per call.
            if not system_prefs_cache:
                system_prefs_cache.append(self.services.user_service.get_system_settings(db))
            return self._section_dict(self._section_dict(system_prefs_cache[0], section_name), "providers")

        for section_name, new_section in (("llm", prefs.llm), ("tts", prefs.tts), ("stt", prefs.stt)):
            if not new_section or "providers" not in new_section:
                continue
            old_providers = self._section_dict(self._section_dict(old_prefs, section_name), "providers")
            for p_name, p_data in (new_section["providers"] or {}).items():
                if not isinstance(p_data, dict):
                    continue

                # A masked key echoed back means "unchanged": restore the real one.
                if self._is_masked(p_data.get("api_key")) and p_name in old_providers:
                    old_entry = old_providers[p_name]
                    if isinstance(old_entry, dict):
                        p_data["api_key"] = old_entry.get("api_key")

                clone_source = p_data.pop("_clone_from", None)
                if not clone_source:
                    continue
                real_key = self._provider_key(old_providers, clone_source) or self._provider_key(
                    system_providers(section_name), clone_source
                )
                if real_key and not self._is_masked(str(real_key)):
                    p_data["api_key"] = real_key
                    logger.info(
                        "Resolved _clone_from: %s inherited api_key from %s [%s]",
                        p_name, clone_source, section_name,
                    )
                else:
                    logger.warning(
                        "Could not resolve _clone_from for %s: source '%s' key not found or masked.",
                        p_name, clone_source,
                    )

        current_prefs = dict(user.preferences or {})
        current_prefs.update({
            "llm": prefs.llm,
            "tts": prefs.tts,
            "stt": prefs.stt,
            "statuses": prefs.statuses or {},
        })
        user.preferences = current_prefs

        if user.role == "admin":
            self._sync_global_settings(prefs)
            logger.info("Saving updated global preferences via admin %s", user.id)
        else:
            # Ensure each saved section carries an "active_provider" key (None
            # when the client omitted it); absent/None sections are skipped.
            for section_name in ("llm", "tts", "stt"):
                section = user.preferences.get(section_name)
                if isinstance(section, dict):
                    section.setdefault("active_provider", None)
            logger.info("Saving personal preferences for user %s", user.id)

        # Force SQLAlchemy to treat the preferences attribute as dirty; nested
        # in-place edits (above) may otherwise go unnoticed.
        flag_modified(user, "preferences")

        db.add(user)
        db.commit()
        db.refresh(user)
        return schemas.UserPreferences(
            llm=user.preferences.get("llm", {}),
            tts=user.preferences.get("tts", {}),
            stt=user.preferences.get("stt", {}),
            statuses=user.preferences.get("statuses", {}),
        )

    def _sync_global_settings(self, prefs) -> None:
        """Mirror an admin's saved preferences into the process-wide settings.

        Sets the LLM provider map and the active TTS/STT provider with its
        model, voice and key. Empty values keep the current setting, and masked
        API keys are never copied. Finally tries to persist settings via
        ``save_to_yaml``; failures are logged, not raised.
        """
        # Imported at call time so the current ``app.config.settings`` object
        # is the one updated.
        from app.config import settings as global_settings

        llm = prefs.llm
        if llm and "providers" in llm:
            global_settings.LLM_PROVIDERS = dict(llm.get("providers") or {})

        tts = prefs.tts
        if tts and tts.get("active_provider"):
            p_name = tts["active_provider"]
            p_data = (tts.get("providers") or {}).get(p_name)
            if isinstance(p_data, dict) and p_data:
                global_settings.TTS_PROVIDER = p_name
                global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
                global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
                api_key = p_data.get("api_key")
                if api_key and not self._is_masked(api_key):
                    global_settings.TTS_API_KEY = api_key

        stt = prefs.stt
        if stt and stt.get("active_provider"):
            p_name = stt["active_provider"]
            p_data = (stt.get("providers") or {}).get(p_name)
            if isinstance(p_data, dict) and p_data:
                global_settings.STT_PROVIDER = p_name
                global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
                api_key = p_data.get("api_key")
                if api_key and not self._is_masked(api_key):
                    global_settings.STT_API_KEY = api_key

        try:
            global_settings.save_to_yaml()
        except Exception as exc:  # best effort: the DB save must still go through
            logger.exception("Failed to sync settings to YAML: %s", exc)