import logging
import copy
from typing import List, Dict, Any
from app.config import settings
from app.api import schemas
logger = logging.getLogger(__name__)
class PreferenceService:
def __init__(self, services):
self.services = services
def mask_key(self, k: str) -> str:
if not k: return None
if len(k) <= 8: return "****"
return k[:4] + "*" * (len(k)-8) + k[-4:]
def merge_user_config(self, user, db) -> schemas.ConfigResponse:
prefs_dict = user.preferences or {}
def normalize_section(section_name, default_active):
section = prefs_dict.get(section_name, {})
# If already new style, just return a copy
if isinstance(section, dict) and "providers" in section:
return copy.deepcopy(section)
# Legacy transformation
providers = {}
active = section.get("active_provider") or section.get("provider") or default_active
# Known providers to check for legacy transformation
legacy_keys = ["openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs"]
for p in legacy_keys:
if p in section:
providers[p] = section[p]
# If still no providers found but it's not empty, it might be a flat dict of other providers
if not providers and section and isinstance(section, dict):
for k, v in section.items():
if k not in ["active_provider", "provider", "providers"] and isinstance(v, dict):
providers[k] = v
return {
"active_provider": str(active) if active else default_active,
"providers": providers
}
llm_prefs = normalize_section("llm", "deepseek")
tts_prefs = normalize_section("tts", settings.TTS_PROVIDER)
stt_prefs = normalize_section("stt", settings.STT_PROVIDER)
system_prefs = self.get_global_config(db)
system_statuses = system_prefs.get("statuses", {})
user_statuses = prefs_dict.get("statuses", {})
merged_statuses = copy.deepcopy(system_statuses)
merged_statuses.update(user_statuses)
# M6: Propagate personal statuses to their suffixed IDs (e.g. llm_gemini -> llm_gemini_personal)
# This ensures the UI shows green/red dots for personal accounts correctly.
for k, v in user_statuses.items():
merged_statuses[f"{k}_personal"] = v
def is_provider_healthy(section: str, provider_id: str, p_data: dict = None) -> bool:
# M6: Check health using either the full ID or the base ID (in case of personal suffix)
status_key = f"{section}_{provider_id}"
base_id = provider_id.replace("_personal", "")
base_status_key = f"{section}_{base_id}"
is_success = (
user_statuses.get(status_key) == "success" or
user_statuses.get(base_status_key) == "success" or
system_statuses.get(status_key) == "success" or
system_statuses.get(base_status_key) == "success"
)
has_key = p_data and p_data.get("api_key") and p_data.get("api_key") not in ("None", "none", "")
return is_success or bool(has_key)
# Build effective combined config for processing
def get_effective_providers(section_name, user_section_providers, sys_defaults, user_section_data):
# M6: If the user provides a configuration for a provider that already exists
# in the system defaults, we want to "Add" it as a personal account
# rather than "Overriding" the system default. This allows both to coexist.
effective = copy.deepcopy(sys_defaults)
user_active = user_section_data.get("active_provider")
if user_section_providers:
for p_id, p_data in user_section_providers.items():
target_id = p_id
if p_id in effective:
# Collision: Create a distinct personal resource
target_id = f"{p_id}_personal"
# Update the user's active selection to point to their personal version
if user_active == p_id:
user_section_data["active_provider"] = target_id
effective[target_id] = copy.deepcopy(p_data)
# Filter by health and mask keys for the response
res = {}
for p, p_data in effective.items():
if p_data and is_provider_healthy(section_name, p, p_data):
masked_data = copy.deepcopy(p_data)
masked_data["api_key"] = self.mask_key(p_data.get("api_key"))
res[p] = masked_data
return res
def get_merged_system_defaults(section_name, hardcoded_defaults):
# M6/M3: Merge Admin overrides with hardcoded config.yaml defaults.
# If the admin has configured ANY providers in the DB, we treat the DB
# as the Source of Truth for the LIST of available providers.
# We only use hardcoded_defaults to backfill missing fields (like API keys)
# for the providers that are present in the DB.
sys_prefs_section = system_prefs.get(section_name, {})
sys_providers = sys_prefs_section.get("providers", {})
if not sys_providers:
# If nothing in DB, fall back to everything in config.yaml/env
return hardcoded_defaults
# Start with the list of providers defined in the DB
merged = copy.deepcopy(sys_providers)
# Backfill secrets/defaults from hardcoded config for these specific providers
for p_id, p_data in merged.items():
if p_id in hardcoded_defaults:
for field, val in hardcoded_defaults[p_id].items():
# Only backfill if the DB value is missing or effectively empty
db_val = p_data.get(field)
if db_val is None or db_val == "" or str(db_val).lower() == "none":
p_data[field] = val
return merged
system_llm = get_merged_system_defaults("llm", settings.LLM_PROVIDERS)
llm_providers_effective = get_effective_providers("llm", llm_prefs["providers"], system_llm, llm_prefs)
system_tts = get_merged_system_defaults("tts", {
settings.TTS_PROVIDER: {
"api_key": settings.TTS_API_KEY,
"model": settings.TTS_MODEL_NAME,
"voice": settings.TTS_VOICE_NAME
}
})
tts_providers_effective = get_effective_providers("tts", tts_prefs["providers"], system_tts, tts_prefs)
system_stt = get_merged_system_defaults("stt", {
settings.STT_PROVIDER: {"api_key": settings.STT_API_KEY, "model": settings.STT_MODEL_NAME}
})
stt_providers_effective = get_effective_providers("stt", stt_prefs["providers"], system_stt, stt_prefs)
effective = {
"llm": {
"active_provider": llm_prefs.get("active_provider") or (next(iter(llm_providers_effective), "deepseek")),
"providers": llm_providers_effective
},
"tts": {
"active_provider": tts_prefs.get("active_provider") or (next(iter(tts_providers_effective), settings.TTS_PROVIDER)),
"providers": tts_providers_effective
},
"stt": {
"active_provider": stt_prefs.get("active_provider") or (next(iter(stt_providers_effective), settings.STT_PROVIDER)),
"providers": stt_providers_effective
},
"statuses": merged_statuses
}
group = user.group or self.services.user_service.get_or_create_default_group(db)
if group:
policy = group.policy or {}
def apply_policy(section_key, policy_key):
allowed = policy.get(policy_key, [])
if not allowed:
effective[section_key]["providers"] = {}
effective[section_key]["active_provider"] = ""
return
providers = effective[section_key]["providers"]
# M6: Allow explicitly whitelisted IDs OR their personal suffixed versions
def is_allowed(pid):
if pid in allowed: return True
if pid.endswith("_personal") and pid.replace("_personal", "") in allowed: return True
return False
filtered_eff = {k: v for k, v in providers.items() if is_allowed(k)}
effective[section_key]["providers"] = filtered_eff
curr_active = effective[section_key].get("active_provider")
if curr_active and not is_allowed(curr_active):
effective[section_key]["active_provider"] = next(iter(filtered_eff), None) or ""
apply_policy("llm", "llm")
apply_policy("tts", "tts")
apply_policy("stt", "stt")
def mask_section_prefs(section_dict):
if not section_dict: return {}
masked_dict = copy.deepcopy(section_dict)
providers = masked_dict.get("providers", {})
for p_name, p_data in providers.items():
if p_data.get("api_key"):
p_data["api_key"] = self.mask_key(p_data["api_key"])
return masked_dict
return schemas.ConfigResponse(
preferences=schemas.UserPreferences(
llm=mask_section_prefs(llm_prefs),
tts=mask_section_prefs(tts_prefs),
stt=mask_section_prefs(stt_prefs),
statuses=merged_statuses
),
effective=effective
)
def get_global_config(self, db) -> dict:
"""Read system-wide provider config from system_config table.
Falls back to config.yaml / env vars if no DB entry exists."""
from app.db import models as m
result = {}
for section in ("llm", "tts", "stt", "statuses"):
row = db.query(m.SystemConfig).filter(m.SystemConfig.key == section).first()
if row and row.value:
result[section] = row.value
return result
def update_global_config(self, prefs: schemas.UserPreferences, db, admin_user) -> schemas.UserPreferences:
"""Persist system-wide provider config to system_config table.
Admin-only. Does NOT touch users.preferences."""
import json as _json
from datetime import datetime
from app.db import models as m
from sqlalchemy.orm.attributes import flag_modified
def _restore_masked_keys(section_name, new_section):
"""Preserve existing API keys when the UI sends back masked '****' values."""
if not new_section or "providers" not in new_section:
return
row = db.query(m.SystemConfig).filter(m.SystemConfig.key == section_name).first()
old_providers = (row.value or {}).get("providers", {}) if row else {}
for p_name, p_data in new_section["providers"].items():
if p_data.get("api_key") and "***" in str(p_data["api_key"]):
if p_name in old_providers:
p_data["api_key"] = old_providers[p_name].get("api_key")
if prefs.llm: _restore_masked_keys("llm", prefs.llm)
if prefs.tts: _restore_masked_keys("tts", prefs.tts)
if prefs.stt: _restore_masked_keys("stt", prefs.stt)
for section, value in (("llm", prefs.llm), ("tts", prefs.tts), ("stt", prefs.stt), ("statuses", prefs.statuses)):
if value is None:
continue
row = db.query(m.SystemConfig).filter(m.SystemConfig.key == section).first()
if row:
row.value = value
row.updated_at = datetime.utcnow()
row.updated_by = admin_user.id
flag_modified(row, "value")
else:
db.add(m.SystemConfig(key=section, value=value, updated_by=admin_user.id))
# Keep in-memory settings in sync for the current process lifetime
from app.config import settings as global_settings
if prefs.llm and "providers" in prefs.llm:
global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers", {}))
if prefs.tts and prefs.tts.get("active_provider"):
p_name = prefs.tts["active_provider"]
p_data = prefs.tts.get("providers", {}).get(p_name, {})
if p_data:
global_settings.TTS_PROVIDER = p_name
global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY
if prefs.stt and prefs.stt.get("active_provider"):
p_name = prefs.stt["active_provider"]
p_data = prefs.stt.get("providers", {}).get(p_name, {})
if p_data:
global_settings.STT_PROVIDER = p_name
global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY
db.commit()
logger.info(f"Global provider config updated by admin {admin_user.id}")
return prefs
def update_user_config(self, user, prefs: schemas.UserPreferences, db) -> schemas.UserPreferences:
"""Save a user's personal provider preferences (active_provider choice + optional personal API keys).
Always writes to users.preferences only — never touches system_config."""
from sqlalchemy.orm.attributes import flag_modified
old_prefs = user.preferences or {}
def get_old_providers(section_name):
section = old_prefs.get(section_name, {})
if isinstance(section, dict) and "providers" in section:
return section["providers"]
providers = {}
legacy_keys = ["openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs"]
for p in legacy_keys:
if p in section:
providers[p] = section[p]
if not providers and section and isinstance(section, dict):
for k, v in section.items():
if k not in ["active_provider", "provider", "providers"] and isinstance(v, dict):
providers[k] = v
return providers
def preserve_masked_keys(section_name, new_section):
if not new_section or "providers" not in new_section:
return
old_providers = get_old_providers(section_name)
for p_name, p_data in new_section["providers"].items():
if p_data.get("api_key") and "***" in str(p_data["api_key"]):
if p_name in old_providers:
p_data["api_key"] = old_providers[p_name].get("api_key")
if prefs.llm: preserve_masked_keys("llm", prefs.llm)
if prefs.tts: preserve_masked_keys("tts", prefs.tts)
if prefs.stt: preserve_masked_keys("stt", prefs.stt)
current_prefs = dict(old_prefs)
current_prefs.update({
"llm": prefs.llm,
"tts": prefs.tts,
"stt": prefs.stt,
"statuses": prefs.statuses or {}
})
user.preferences = current_prefs
from sqlalchemy.orm.attributes import flag_modified
flag_modified(user, "preferences")
logger.info(f"Saving personal preferences for user {user.id}")
db.add(user)
db.commit()
db.refresh(user)
logger.info(f"Saved personal preferences for user {user.id}")
return schemas.UserPreferences(
llm=user.preferences.get("llm", {}),
tts=user.preferences.get("tts", {}),
stt=user.preferences.get("stt", {}),
statuses=user.preferences.get("statuses", {})
)
def export_config_yaml(self, user, reveal_secrets: bool) -> str:
import yaml
from app.core.grpc.utils.crypto import encrypt_value
prefs_dict = copy.deepcopy(user.preferences) if user.preferences else {}
sensitive_keys = ["api_key", "client_secret", "webhook_secret", "password", "key_content", "key_file"]
def process_export(obj):
if isinstance(obj, dict):
res = {}
for k, v in obj.items():
if v is None: continue
if k in sensitive_keys and v:
res[k] = v if reveal_secrets else encrypt_value(v)
else:
res[k] = process_export(v)
return res
elif isinstance(obj, list):
return [process_export(x) for x in obj]
return obj
export_data = {
"llm": prefs_dict.get("llm", {"providers": {}, "active_provider": "deepseek"}),
"tts": prefs_dict.get("tts", {"providers": {}, "active_provider": settings.TTS_PROVIDER}),
"stt": prefs_dict.get("stt", {"providers": {}, "active_provider": settings.STT_PROVIDER})
}
# Backfill from settings if empty
if not export_data["llm"].get("providers"):
export_data["llm"]["providers"] = settings.LLM_PROVIDERS
return yaml.dump(process_export(export_data), sort_keys=False, default_flow_style=False)
async def import_config_yaml(self, db, user, content: bytes) -> schemas.UserPreferences:
import yaml
from app.core.grpc.utils.crypto import decrypt_value
from sqlalchemy.orm.attributes import flag_modified
try: data = yaml.safe_load(content)
except Exception as e: raise Exception(f"Invalid YAML: {e}")
def process_import(obj):
if isinstance(obj, dict): return {k: process_import(v) for k, v in obj.items()}
elif isinstance(obj, str): return decrypt_value(obj)
elif isinstance(obj, list): return [process_import(x) for x in obj]
return obj
data = process_import(data)
user.preferences = {
"llm": data.get("llm", {}),
"tts": data.get("tts", {}),
"stt": data.get("stt", {}),
"statuses": {}
}
flag_modified(user, "preferences")
db.commit()
return schemas.UserPreferences(llm=user.preferences["llm"], tts=user.preferences["tts"], stt=user.preferences["stt"])
async def verify_provider(self, db, user, req: schemas.VerifyProviderRequest, section: str) -> schemas.VerifyProviderResponse:
from app.core.providers.factory import get_llm_provider, get_tts_provider, get_stt_provider
# Admin or personal key check
is_masked = not req.api_key or "***" in str(req.api_key)
if is_masked and user.role != "admin":
return schemas.VerifyProviderResponse(success=False, message="Forbidden: Admin only for masked keys")
actual_key = req.api_key
prefs = user.preferences.get(section, {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {}
if is_masked:
actual_key = prefs.get("api_key")
if not actual_key:
s_prefs = self.get_global_config(db)
actual_key = s_prefs.get(section, {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key")
try:
if section == "llm":
llm = get_llm_provider(req.provider_name, model_name=req.model or "", api_key_override=actual_key)
await llm.acompletion(prompt="Hello")
elif section == "tts":
p = get_tts_provider(req.provider_name, api_key=actual_key, model_name=req.model or "", voice_name=req.voice or "")
await p.generate_speech("Test")
else:
p = get_stt_provider(req.provider_name, api_key=actual_key, model_name=req.model or "")
await p.transcribe_audio(b"dummy audio")
return schemas.VerifyProviderResponse(success=True, message="Success!")
except Exception as e:
return schemas.VerifyProviderResponse(success=False, message=str(e))
def resolve_llm_provider(self, db, user, provider_name: str, model_name: str = None) -> Any:
"""
Unified resolution for LLM providers with full fallback chain:
User Preference -> System Override (Admin UI) -> Config Defaults (YAML/Env)
"""
from app.core.providers.factory import get_llm_provider
base_key = provider_name.split("/")[0] if provider_name else ""
if not provider_name:
# Fallback to registered active provider
u_svc = getattr(self.services, "user_service", None)
sys_prefs = self.get_global_config(db) if u_svc else {}
user_active = user.preferences.get("llm", {}).get("active_provider") if user and user.preferences else None
base_key = user_active or sys_prefs.get("llm", {}).get("active_provider")
provider_name = base_key
user_providers = user.preferences.get("llm", {}).get("providers", {}) if user and user.preferences else {}
llm_prefs = user_providers.get(base_key, {})
# Prefix matching: 'gemini' -> 'gemini_gemini-3-flash-preview'
if not llm_prefs and base_key:
for pk, pv in user_providers.items():
if pk.startswith(f"{base_key}_") or pk == base_key:
llm_prefs = pv
provider_name = pk
base_key = pk
# Derive model from the key suffix if the stored model is generic/missing
# e.g. key='gemini_gemini-3-flash-preview' -> derived_model='gemini-3-flash-preview'
parts = pk.split("_", 1)
if len(parts) == 2:
derived_model = parts[1]
stored_model = pv.get("model", "")
# Only use the derived model if stored one looks generic (no dash = no version)
if not stored_model or "-" not in stored_model:
llm_prefs = dict(pv, model=derived_model)
break
logger.info(f"[Preference] Resolved match for '{base_key}': model={llm_prefs.get('model')}, has_key={'api_key' in llm_prefs}")
# Resolve Resolved Model/Provider names
provider_name_str = str(provider_name) if provider_name else ""
has_slash = "/" in provider_name_str
resolved_model = provider_name_str.split("/")[1] if has_slash else (model_name or llm_prefs.get("model", ""))
resolved_provider_name = provider_name_str.split("/")[0] if has_slash else provider_name_str
# 3. Last Resort: Pick the first available provider for the user/tenant if no preference exists
if not resolved_provider_name or not resolved_model:
available_providers = self.get_user_llm_providers(user.id, db)
if available_providers:
# Use the first available provider as the default
default_p = available_providers[0]
resolved_provider_name = default_p.name
resolved_model = default_p.model
logger.info(f"[Preference] No preference set for user {user.id}. Defaulting to first available: {resolved_provider_name}")
else:
# No providers configured for this user/tenant at all
logger.error(f"[Preference] No LLM providers configured for user {user.id}")
raise HTTPException(
status_code=400,
detail="No LLM providers configured. Please configure an LLM provider (e.g. Gemini, OpenAI) in your user settings."
)
logger.info(f"[Preference] Final Resolution for {user.id}: {resolved_provider_name} / {resolved_model}")
# The resolved_provider_name may be a DB-internal composite key like 'gemini_gemini-3-flash-preview'.
# LiteLLM needs the base provider type (e.g. 'gemini'), not the full key.
# The model string (e.g. 'gemini/gemini-3-flash-preview') already encodes the routing info.
litellm_provider = resolved_provider_name.split("_")[0] if "_" in resolved_provider_name else resolved_provider_name
try:
return get_llm_provider(
litellm_provider,
model_name=resolved_model,
api_key_override=llm_prefs.get("api_key"),
**{k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]}
), resolved_provider_name
except Exception as e:
raise ValueError(f"Failed to initialize LLM provider '{litellm_provider}' with model '{resolved_model}': {e}")
def get_user_llm_providers(self, user_id: str, db) -> List[Any]:
"""
Helper method to retrieve all effectively configured LLM providers for a specific user.
Used as a last-resort fallback in resolve_llm_provider.
"""
from app.db.models import User
user = db.query(User).filter(User.id == user_id).first()
if not user:
return []
# Use existing merge logic to get effective config
config = self.merge_user_config(user, db)
llm_providers = config.effective.get("llm", {}).get("providers", {})
results = []
for p_name, p_data in llm_providers.items():
# Create a simple attribute-accessible object for the caller
class ProviderInfo:
def __init__(self, name: str, model: str):
self.name = name
self.model = model
results.append(ProviderInfo(p_name, p_data.get("model", "")))
return results
async def get_provider_models(self, provider_name: str, section: str = "llm") -> List[Dict[str, Any]]:
"""Fetches supported models for a specific provider and section using LiteLLM."""
import litellm
from fastapi.concurrency import run_in_threadpool
def fetch_models():
try:
models = litellm.models_by_provider.get(provider_name, [])
out = []
for m in models:
try:
info = litellm.get_model_info(m)
if "error" not in info:
mode = info.get("mode")
is_valid = False
if section == "llm": is_valid = mode in ["chat", "text-completion", "custom", None]
elif section == "tts": is_valid = mode == "audio_speech"
elif section == "stt": is_valid = mode == "audio_transcription"
elif section == "image": is_valid = mode == "image_generation"
else: is_valid = True
if is_valid:
out.append({"model_name": m, "max_tokens": info.get("max_tokens"), "max_input_tokens": info.get("max_input_tokens")})
except: pass
return out
except: return []
return await run_in_threadpool(fetch_models)
def get_all_providers(self, db, user, section: str = "llm", configured_only: bool = False) -> List[str]:
"""Returns valid providers for a section, optionally filtering by those with configured credentials."""
import litellm
from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers
if configured_only:
system_prefs = self.get_global_config(db)
user_prefs = user.preferences if user else {}
configured = set(system_prefs.get(section, {}).get("providers", {}).keys())
configured.update(user_prefs.get(section, {}).get("providers", {}).keys())
if not configured:
if section == "llm": configured.update(["deepseek", "gemini"])
elif section == "tts": configured.add(settings.TTS_PROVIDER)
elif section == "stt": configured.add(settings.STT_PROVIDER)
return sorted(list(configured))
if section == "llm": return ["general"] + [p.value for p in litellm.LlmProviders]
elif section == "tts": return ["general"] + get_registered_tts_providers() + ["openai"]
elif section == "stt": return ["general"] + get_registered_stt_providers() + ["openai"]
return ["general"] + [p.value for p in litellm.LlmProviders]