import copy
import logging
from typing import Any, Dict, Optional

from app.api import schemas
from app.config import settings
# Module logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(name=__name__)
class PreferenceService:
    """Merge, persist, import/export and verify provider preferences for users.

    Preferences live on ``user.preferences`` as a dict with the sections
    ``"llm"``, ``"tts"`` and ``"stt"`` plus a ``"statuses"`` map. A section is
    either in the current shape ``{"active_provider": str, "providers":
    {name: {field: value}}}`` or in a legacy flat shape where provider dicts
    sit directly under the section (see ``_extract_providers``).
    """

    # Provider names recognised when converting a legacy flat section.
    _LEGACY_PROVIDER_KEYS = (
        "openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs",
    )
    # Keys in a legacy section that are metadata, not provider entries.
    _SECTION_META_KEYS = ("active_provider", "provider", "providers")
    # Field names whose non-empty values export_config_yaml encrypts
    # (unless reveal_secrets is set).
    _SENSITIVE_KEYS = (
        "api_key", "client_secret", "webhook_secret", "password", "key_content", "key_file",
    )
    # Keys up to this length are masked entirely. Longer keys keep their first
    # and last 4 characters, so at least as many characters are hidden as shown.
    _MASK_FULL_MAX_LEN = 16

    def __init__(self, services):
        # ``services`` is expected to expose ``user_service``
        # (get_system_settings / get_or_create_default_group are called on it).
        self.services = services

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_meaningful(val) -> bool:
        """Return True if *val* may override a lower-priority setting.

        ``None``, ``""`` and the string ``"none"`` (any case) count as "not set".
        For example, saving only a model name in the UI must not wipe an
        api_key inherited from config.
        """
        return val is not None and val != "" and str(val).lower() != "none"

    @staticmethod
    def _is_masked_key(value) -> bool:
        """Return True if *value* is a non-empty secret containing ``*``.

        ``mask_key`` always emits ``*``, so such a value is a display
        placeholder, not a usable credential. This assumes real keys never
        contain ``*``.
        """
        return bool(value) and "*" in str(value)

    @classmethod
    def _extract_providers(cls, section) -> Dict[str, Any]:
        """Return the ``{provider_name: provider_dict}`` map of a section.

        New-shape sections return their ``"providers"`` value. Legacy sections
        return the known legacy provider keys they contain. Failing that, they
        return every dict-valued non-metadata entry. Non-dict input yields
        ``{}``. The result may share objects with *section*; copy before mutating.
        """
        if not isinstance(section, dict):
            return {}
        if "providers" in section:
            return section.get("providers") or {}
        providers = {p: section[p] for p in cls._LEGACY_PROVIDER_KEYS if p in section}
        if not providers:
            providers = {
                k: v for k, v in section.items()
                if k not in cls._SECTION_META_KEYS and isinstance(v, dict)
            }
        return providers

    @classmethod
    def _normalize_section(cls, section, default_active) -> Dict[str, Any]:
        """Return a deep copy of *section* in the ``active_provider``/``providers`` shape.

        New-shape sections are copied as-is. Legacy ones take their active
        provider from ``active_provider``, then ``provider``, then
        *default_active*.
        """
        if isinstance(section, dict) and "providers" in section:
            return copy.deepcopy(section)
        if not isinstance(section, dict):
            section = {}
        active = section.get("active_provider") or section.get("provider") or default_active
        return {
            "active_provider": str(active) if active else default_active,
            "providers": copy.deepcopy(cls._extract_providers(section)),
        }

    @classmethod
    def _overlay_providers(cls, base, overrides) -> Dict[str, Any]:
        """Deep-copy *base* and layer *overrides* on top, field by field.

        Providers missing from *base* (or non-dict there) are copied whole.
        For providers present in both, only meaningful override values (see
        ``_is_meaningful``) replace base values. This keeps e.g. a config
        api_key when only the model name is overridden.
        """
        merged = copy.deepcopy(base) if base else {}
        for p_id, p_data in (overrides or {}).items():
            if p_id not in merged or not isinstance(merged[p_id], dict):
                merged[p_id] = copy.deepcopy(p_data)
            elif isinstance(p_data, dict):
                for field, val in p_data.items():
                    if cls._is_meaningful(val):
                        merged[p_id][field] = val
        return merged

    def mask_key(self, k: str) -> Optional[str]:
        """Return a display-safe version of secret *k*, or ``None`` if *k* is empty.

        Keys of up to ``_MASK_FULL_MAX_LEN`` characters become ``"****"``.
        Longer keys keep their first and last 4 characters with ``*`` between.
        """
        if not k:
            return None
        if len(k) <= self._MASK_FULL_MAX_LEN:
            return "****"
        return k[:4] + "*" * (len(k) - 8) + k[-4:]

    # ------------------------------------------------------------------
    # Reading the effective configuration
    # ------------------------------------------------------------------

    def merge_user_config(self, user, db) -> schemas.ConfigResponse:
        """Build *user*'s config view: their own (masked) prefs plus the effective set.

        Per section, lowest to highest priority, the layers are:
        hardcoded settings (config.yaml / env), then admin system settings,
        then the user's own providers.

        Providers are kept only if they have an api_key or a ``"success"``
        status (user or system). They are then filtered by the group policy's
        allow-list. Every api_key in the response is masked.
        """
        prefs_dict = user.preferences or {}

        llm_prefs = self._normalize_section(prefs_dict.get("llm"), "deepseek")
        tts_prefs = self._normalize_section(prefs_dict.get("tts"), settings.TTS_PROVIDER)
        stt_prefs = self._normalize_section(prefs_dict.get("stt"), settings.STT_PROVIDER)

        system_prefs = self.services.user_service.get_system_settings(db) or {}
        system_statuses = system_prefs.get("statuses") or {}
        user_statuses = prefs_dict.get("statuses") or {}

        def is_provider_healthy(section_name, provider_id, p_data):
            # Healthy = verified successfully (by user or admin) or has a real key.
            status_key = f"{section_name}_{provider_id}"
            verified = (
                user_statuses.get(status_key) == "success"
                or system_statuses.get(status_key) == "success"
            )
            api_key = p_data.get("api_key")
            return verified or (bool(api_key) and api_key not in ("None", "none", ""))

        def system_defaults(section_name, hardcoded):
            # Admin UI overrides layered over config.yaml/env defaults, so an
            # admin editing e.g. a model name keeps the system-level key.
            sys_section = system_prefs.get(section_name) or {}
            sys_providers = sys_section.get("providers") if isinstance(sys_section, dict) else None
            return self._overlay_providers(hardcoded, sys_providers)

        def effective_providers(section_name, user_providers, defaults):
            # User overrides on top of system defaults; keep healthy entries, masked.
            visible = {}
            for p_id, p_data in self._overlay_providers(defaults, user_providers).items():
                if not isinstance(p_data, dict) or not p_data:
                    continue
                if not is_provider_healthy(section_name, p_id, p_data):
                    continue
                visible[p_id] = {**p_data, "api_key": self.mask_key(p_data.get("api_key"))}
            return visible

        llm_providers = effective_providers("llm", llm_prefs.get("providers"), system_defaults("llm", {
            "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME},
            "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME},
            "openai": {"api_key": settings.OPENAI_API_KEY},
        }))
        tts_providers = effective_providers("tts", tts_prefs.get("providers"), system_defaults("tts", {
            settings.TTS_PROVIDER: {
                "api_key": settings.TTS_API_KEY,
                "model": settings.TTS_MODEL_NAME,
                "voice": settings.TTS_VOICE_NAME,
            },
        }))
        stt_providers = effective_providers("stt", stt_prefs.get("providers"), system_defaults("stt", {
            settings.STT_PROVIDER: {"api_key": settings.STT_API_KEY, "model": settings.STT_MODEL_NAME},
        }))

        effective = {
            "llm": {
                "active_provider": llm_prefs.get("active_provider") or next(iter(llm_providers), "deepseek"),
                "providers": llm_providers,
            },
            "tts": {
                "active_provider": tts_prefs.get("active_provider") or next(iter(tts_providers), settings.TTS_PROVIDER),
                "providers": tts_providers,
            },
            "stt": {
                "active_provider": stt_prefs.get("active_provider") or next(iter(stt_providers), settings.STT_PROVIDER),
                "providers": stt_providers,
            },
        }

        group = user.group or self.services.user_service.get_or_create_default_group(db)
        if group:
            policy = group.policy or {}
            for section_key in ("llm", "tts", "stt"):
                allowed = policy.get(section_key, [])
                section = effective[section_key]
                if not allowed:
                    # An empty or missing allow-list permits no providers.
                    section["providers"] = {}
                    section["active_provider"] = ""
                    continue
                section["providers"] = {k: v for k, v in section["providers"].items() if k in allowed}
                if section.get("active_provider") not in allowed:
                    section["active_provider"] = next(iter(section["providers"]), None) or ""

        def mask_section(section_dict):
            # Deep copy with every provider api_key masked.
            if not section_dict:
                return {}
            masked = copy.deepcopy(section_dict)
            for p_data in (masked.get("providers") or {}).values():
                if isinstance(p_data, dict) and p_data.get("api_key"):
                    p_data["api_key"] = self.mask_key(p_data["api_key"])
            return masked

        # User statuses win over system statuses for the same key.
        merged_statuses = copy.deepcopy(system_statuses)
        merged_statuses.update(user_statuses)

        return schemas.ConfigResponse(
            preferences=schemas.UserPreferences(
                llm=mask_section(llm_prefs),
                tts=mask_section(tts_prefs),
                stt=mask_section(stt_prefs),
                statuses=merged_statuses,
            ),
            effective=effective,
        )

    # ------------------------------------------------------------------
    # Saving preferences
    # ------------------------------------------------------------------

    def _restore_masked_keys(self, new_section, old_section) -> None:
        """Replace masked api_keys in *new_section* (in place) with stored keys.

        If no real stored key exists for a provider, the placeholder is
        removed, so lower-priority (system/config) keys stay in effect.
        """
        if not new_section or "providers" not in new_section:
            return
        old_providers = self._extract_providers(old_section)
        for p_name, p_data in (new_section.get("providers") or {}).items():
            if not isinstance(p_data, dict) or not self._is_masked_key(p_data.get("api_key")):
                continue
            old_entry = old_providers.get(p_name)
            old_key = old_entry.get("api_key") if isinstance(old_entry, dict) else None
            if old_key and not self._is_masked_key(old_key):
                p_data["api_key"] = old_key
            else:
                p_data.pop("api_key", None)

    @staticmethod
    def _apply_admin_prefs_to_settings(prefs) -> None:
        """Copy admin-saved provider choices onto the live settings object.

        This covers the LLM provider map plus the active TTS/STT provider's
        model, voice and api_key. Empty values keep the current setting.
        """
        # Looked up at call time rather than through the module-level import.
        from app.config import settings as global_settings

        if prefs.llm and "providers" in prefs.llm:
            global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers") or {})

        if prefs.tts and prefs.tts.get("active_provider"):
            p_name = prefs.tts["active_provider"]
            p_data = (prefs.tts.get("providers") or {}).get(p_name, {})
            if p_data:
                global_settings.TTS_PROVIDER = p_name
                global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
                global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
                global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY

        if prefs.stt and prefs.stt.get("active_provider"):
            p_name = prefs.stt["active_provider"]
            p_data = (prefs.stt.get("providers") or {}).get(p_name, {})
            if p_data:
                global_settings.STT_PROVIDER = p_name
                global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
                global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY

    def update_user_config(self, user, prefs: schemas.UserPreferences, db) -> schemas.UserPreferences:
        """Persist *prefs* as *user*'s preferences and return what was stored.

        Masked api_keys sent back by the UI are swapped for the previously
        stored key, or dropped when there is none. A placeholder is therefore
        never saved as a credential.

        Sections that are ``None`` in *prefs* keep their stored value. Other
        top-level preference keys are preserved. For admins, the saved values
        are also pushed into the live settings object.
        """
        from sqlalchemy.orm.attributes import flag_modified

        old_prefs = user.preferences or {}
        sections = ("llm", "tts", "stt")

        for section_name in sections:
            self._restore_masked_keys(getattr(prefs, section_name), old_prefs.get(section_name))

        current_prefs = dict(old_prefs)
        for section_name in sections:
            new_section = getattr(prefs, section_name)
            if new_section is not None:
                current_prefs[section_name] = new_section
        current_prefs["statuses"] = prefs.statuses or {}
        user.preferences = current_prefs
        # Force SQLAlchemy to treat the JSON column as changed.
        flag_modified(user, "preferences")

        if user.role == "admin":
            self._apply_admin_prefs_to_settings(prefs)
            logger.info(f"Saving updated global preferences via admin {user.id}")
        else:
            logger.info(f"Saving personal preferences for user {user.id}")

        db.add(user)
        db.commit()
        db.refresh(user)
        return schemas.UserPreferences(
            llm=user.preferences.get("llm", {}),
            tts=user.preferences.get("tts", {}),
            stt=user.preferences.get("stt", {}),
            statuses=user.preferences.get("statuses", {}),
        )

    # ------------------------------------------------------------------
    # YAML import / export
    # ------------------------------------------------------------------

    def export_config_yaml(self, user, reveal_secrets: bool) -> str:
        """Serialise *user*'s llm/tts/stt preferences to YAML.

        Values of ``_SENSITIVE_KEYS`` fields are written as-is when
        *reveal_secrets* is true, otherwise through ``encrypt_value``. ``None``
        values are omitted. If the llm section contains no providers at all,
        DeepSeek and Gemini entries from settings are added.
        """
        import yaml
        from app.core.grpc.utils.crypto import encrypt_value

        prefs_dict = copy.deepcopy(user.preferences) if user.preferences else {}

        def process_export(obj):
            if isinstance(obj, dict):
                out = {}
                for k, v in obj.items():
                    if v is None:
                        continue
                    if k in self._SENSITIVE_KEYS and v:
                        out[k] = v if reveal_secrets else encrypt_value(v)
                    else:
                        out[k] = process_export(v)
                return out
            if isinstance(obj, list):
                return [process_export(x) for x in obj]
            return obj

        def section_or_default(name, default_active):
            section = prefs_dict.get(name)
            if isinstance(section, dict):
                return section
            return {"providers": {}, "active_provider": default_active}

        export_data = {
            "llm": section_or_default("llm", "deepseek"),
            "tts": section_or_default("tts", settings.TTS_PROVIDER),
            "stt": section_or_default("stt", settings.STT_PROVIDER),
        }

        # Backfill from settings only when the section has no providers in
        # either shape; adding "providers" to a populated legacy section
        # would hide its entries on re-import.
        if not self._extract_providers(export_data["llm"]):
            export_data["llm"]["providers"] = {
                "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME},
                "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME},
            }
        return yaml.dump(process_export(export_data), sort_keys=False, default_flow_style=False)

    async def import_config_yaml(self, db, user, content: bytes) -> schemas.UserPreferences:
        """Replace *user*'s llm/tts/stt preferences with those in YAML *content*.

        Every string value is passed through ``decrypt_value``. Statuses are
        reset to ``{}`` so imported keys must be re-verified. Other top-level
        preference keys are preserved.

        Raises:
            Exception: if *content* is not valid YAML or not a mapping.
        """
        import yaml
        from app.core.grpc.utils.crypto import decrypt_value
        from sqlalchemy.orm.attributes import flag_modified

        # safe_load only: the content comes from a user upload.
        try:
            data = yaml.safe_load(content)
        except Exception as e:
            raise Exception(f"Invalid YAML: {e}") from e
        if not isinstance(data, dict):
            raise Exception("Invalid YAML: expected a mapping at the top level")

        def process_import(obj):
            if isinstance(obj, dict):
                return {k: process_import(v) for k, v in obj.items()}
            if isinstance(obj, str):
                return decrypt_value(obj)
            if isinstance(obj, list):
                return [process_import(x) for x in obj]
            return obj

        data = process_import(data)

        current_prefs = dict(user.preferences or {})
        current_prefs.update({
            "llm": data.get("llm") or {},
            "tts": data.get("tts") or {},
            "stt": data.get("stt") or {},
            "statuses": {},
        })
        user.preferences = current_prefs
        flag_modified(user, "preferences")
        db.commit()
        return schemas.UserPreferences(
            llm=user.preferences["llm"],
            tts=user.preferences["tts"],
            stt=user.preferences["stt"],
            statuses=user.preferences["statuses"],
        )

    # ------------------------------------------------------------------
    # Verification and runtime resolution
    # ------------------------------------------------------------------

    def _lookup_stored_key(self, db, user, section, req):
        """Return the stored api_key to verify with when the request's key is masked.

        Lookup order: the user's own provider entry, then admin system
        settings, then the settings defaults for *section*.
        """
        user_entry = self._extract_providers((user.preferences or {}).get(section)).get(req.provider_name)
        key = user_entry.get("api_key") if isinstance(user_entry, dict) else None
        if key and not self._is_masked_key(key):
            return key

        sys_section = (self.services.user_service.get_system_settings(db) or {}).get(section) or {}
        sys_providers = (sys_section.get("providers") or {}) if isinstance(sys_section, dict) else {}
        sys_entry = sys_providers.get(req.provider_type or req.provider_name)
        key = sys_entry.get("api_key") if isinstance(sys_entry, dict) else None
        if key:
            return key

        if section == "llm":
            # Use the key that belongs to the provider being verified. Unknown
            # providers fall back to the DeepSeek key.
            config_keys = {
                "deepseek": settings.DEEPSEEK_API_KEY,
                "gemini": settings.GEMINI_API_KEY,
                "openai": settings.OPENAI_API_KEY,
            }
            provider = (req.provider_type or req.provider_name or "").split("/")[0]
            return config_keys.get(provider, settings.DEEPSEEK_API_KEY)
        if section == "tts":
            return settings.TTS_API_KEY
        return settings.STT_API_KEY

    async def verify_provider(self, db, user, req: schemas.VerifyProviderRequest, section: str) -> schemas.VerifyProviderResponse:
        """Check that the requested provider can be built and, for llm/tts, called.

        A missing or masked ``req.api_key`` means "use the stored key", which
        only admins may do. Any exception during the check is reported as
        ``success=False`` with the exception text.
        """
        from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider

        is_masked = not req.api_key or self._is_masked_key(req.api_key)
        if is_masked and user.role != "admin":
            return schemas.VerifyProviderResponse(success=False, message="Forbidden: Admin only for masked keys")

        actual_key = self._lookup_stored_key(db, user, section, req) if is_masked else req.api_key

        try:
            if section == "llm":
                llm = get_llm_provider(req.provider_name, model_name=req.model or "", api_key_override=actual_key)
                await llm.acompletion(prompt="Hello")
            elif section == "tts":
                p = get_tts_provider(req.provider_name, api_key=actual_key, model_name=req.model or "", voice_name=req.voice or "")
                await p.generate_speech("Test")
            else:
                # STT is only constructed, not exercised.
                get_stt_provider(req.provider_name, api_key=actual_key, model_name=req.model or "")
            return schemas.VerifyProviderResponse(success=True, message="Success!")
        except Exception as e:
            return schemas.VerifyProviderResponse(success=False, message=str(e))

    def resolve_llm_provider(self, db, user, provider_name: str, model_name: str = None) -> Any:
        """
        Unified resolution for LLM providers with full fallback chain:
        User Preference -> System Override (Admin UI) -> Config Defaults (YAML/Env)

        *provider_name* may be ``"provider"`` or ``"provider/model"``. The model
        after the first ``/`` takes precedence over *model_name*. If it is
        empty and a user is given, the user's active provider is used.

        Masked placeholder keys are never used. If the user has no real key,
        the admin system entry for the provider is merged under the user's
        fields. If that entry has no key either, the system ``active_provider``
        is used instead.

        The last step presumably happens inside ``get_llm_provider``
        (``api_key_override=None``) — TODO confirm.

        Returns:
            ``(provider, resolved_provider_name)``. ``provider`` is ``None``
            if construction raised.
        """
        from app.core.providers.factory import get_llm_provider

        provider_name = provider_name or ""
        user_llm = (user.preferences or {}).get("llm") if user else None
        user_providers = self._extract_providers(user_llm)

        def user_cfg(key):
            # Copy of the user's entry for *key*, minus any masked placeholder key.
            entry = user_providers.get(key)
            cfg = dict(entry) if isinstance(entry, dict) else {}
            if self._is_masked_key(cfg.get("api_key")):
                del cfg["api_key"]
            return cfg

        base_key = provider_name.split("/")[0]
        if not base_key and user:
            base_key = self._normalize_section(user_llm, "deepseek").get("active_provider") or "deepseek"
            provider_name = base_key

        llm_prefs = user_cfg(base_key)

        if not llm_prefs.get("api_key"):
            u_svc = getattr(self.services, "user_service", None)
            if u_svc:
                system_llm = (u_svc.get_system_settings(db) or {}).get("llm") or {}
                if not isinstance(system_llm, dict):
                    system_llm = {}
                system_providers = system_llm.get("providers") or {}
                system_prov = system_providers.get(base_key)
                has_system_key = isinstance(system_prov, dict) and system_prov.get("api_key")
                if not has_system_key and system_llm.get("active_provider"):
                    # Switch to the system's active provider. Take the user's
                    # prefs for *that* provider, not the one originally asked for.
                    base_key = provider_name = system_llm["active_provider"]
                    system_prov = system_providers.get(base_key)
                    llm_prefs = user_cfg(base_key)
                if isinstance(system_prov, dict) and system_prov:
                    merged = dict(system_prov)
                    merged.update({k: v for k, v in llm_prefs.items() if v})
                    llm_prefs = merged

        # "provider/model" syntax: everything after the first "/" is the model.
        name_part, sep, model_part = provider_name.partition("/")
        resolved_provider_name = name_part if sep else provider_name
        resolved_model = model_part if sep else (model_name or llm_prefs.get("model", ""))

        try:
            return get_llm_provider(
                resolved_provider_name,
                model_name=resolved_model,
                api_key_override=llm_prefs.get("api_key"),
                **{k: v for k, v in llm_prefs.items() if k not in ("api_key", "model")}
            ), resolved_provider_name
        except Exception as e:
            logger.error("Failed to resolve LLM provider: %s", e)
            return None, resolved_provider_name