from fastapi import APIRouter, HTTPException, Depends, Header, Query, Request, UploadFile, File
from fastapi.responses import RedirectResponse as redirect
from sqlalchemy.orm import Session
from app.db import models
from typing import Optional, Annotated
import logging
import os
import httpx
import jwt
import urllib.parse
# Correctly import from your application's schemas and dependencies
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.core.services.user import login_required
# Setup logging
# NOTE(review): basicConfig() here configures the root logger for the whole
# process as a side effect of importing this module (per the logging docs it
# does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Minimum OIDC configuration from environment variables (empty string when unset).
OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "")
OIDC_CLIENT_SECRET = os.getenv("OIDC_CLIENT_SECRET", "")
OIDC_SERVER_URL = os.getenv("OIDC_SERVER_URL", "")
OIDC_REDIRECT_URI = os.getenv("OIDC_REDIRECT_URI", "")

# --- Derived OIDC Configuration ---
# Endpoints are derived by appending fixed paths to the server URL. A trailing
# "/" on OIDC_SERVER_URL is dropped first so the results never contain "//".
# NOTE(review): the fixed paths presumably target a Dex-style provider; a
# provider with a different layout would need OIDC discovery
# (/.well-known/openid-configuration) instead.
_OIDC_BASE_URL = OIDC_SERVER_URL.rstrip("/")
OIDC_AUTHORIZATION_URL = f"{_OIDC_BASE_URL}/auth"
OIDC_TOKEN_URL = f"{_OIDC_BASE_URL}/token"
OIDC_USERINFO_URL = f"{_OIDC_BASE_URL}/userinfo"
# FastAPI dependency standing in for real authentication: the caller's user ID
# is read directly from a request header.
def get_current_user_id(x_user_id: Annotated[Optional[str], Header()] = None) -> Optional[str]:
    """
    Return the caller's user ID as sent in the X-User-ID request header.

    The header name comes from the parameter name, so the parameter must stay
    ``x_user_id``. When the header is absent the result is ``None``.

    NOTE(review): the value is trusted without any verification, so any client
    can claim to be any user; replace with a real session/token check.
    """
    header_value = x_user_id
    return header_value
def create_users_router(services: ServiceContainer) -> APIRouter:
router = APIRouter(prefix="/users", tags=["Users"])
@router.get("/login", summary="Initiate OIDC Login Flow")
async def login_redirect(
request: Request,
# Allow the frontend to provide its callback URL
frontend_callback_uri: Optional[str] = Query(None, description="The frontend URI to redirect back to after OIDC provider.")
):
"""
Initiates the OIDC authentication flow. The `frontend_callback_uri`
specifies where the user should be redirected after successful
authentication with the OIDC provider.
"""
# Store the frontend_callback_uri in a session or a cache,
# linked to the state parameter for security.
# For simplicity, we will pass it as a query parameter in the callback.
# A more robust solution would use a state parameter.
# Use urllib.parse.urlencode to properly encode parameters
params = {
"response_type": "code",
"scope": "openid profile email",
"client_id": OIDC_CLIENT_ID,
"redirect_uri": OIDC_REDIRECT_URI,
"state": frontend_callback_uri or ""
}
auth_url = f"{OIDC_AUTHORIZATION_URL}?{urllib.parse.urlencode(params)}"
logger.info(f"Redirecting to OIDC authorization URL: {auth_url}")
return redirect(url=auth_url)
@router.get("/login/callback", summary="Handle OIDC Login Callback")
async def login_callback(
request: Request,
code: str = Query(..., description="Authorization code from OIDC provider"),
state: str = Query(..., description="The original frontend redirect URI"),
db: Session = Depends(get_db)
):
"""
Handles the callback from the OIDC provider, exchanges the code for
tokens, and then redirects the user back to the frontend with
the user data or a session token.
"""
logger.info(f"Received callback with authorization code: {code[:10]}... and state: {state}")
try:
logger.info(f"Exchanging code for tokens at: {OIDC_TOKEN_URL}")
# Step 1: Exchange the authorization code for an access token and an ID token
token_data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": OIDC_REDIRECT_URI,
"client_id": OIDC_CLIENT_ID,
"client_secret": OIDC_CLIENT_SECRET,
}
async with httpx.AsyncClient() as client:
logger.debug(f"Sending POST to {OIDC_TOKEN_URL} with data keys: {list(token_data.keys())}")
token_response = await client.post(OIDC_TOKEN_URL, data=token_data, timeout=30.0)
token_response.raise_for_status()
response_json = token_response.json()
logger.info("Successfully received tokens from OIDC provider.")
id_token = response_json.get("id_token")
if not id_token:
logger.error("Error: ID token not found in the response.")
raise HTTPException(status_code=400, detail="Failed to get ID token from OIDC provider.")
# Step 2: Decode the ID token to get user information
logger.info("Decoding ID token...")
decoded_id_token = jwt.decode(id_token, options={"verify_signature": False})
oidc_id = decoded_id_token.get("sub")
email = decoded_id_token.get("email")
# Dex and others often use 'name' for the full name, or 'preferred_username'
username = decoded_id_token.get("name") or decoded_id_token.get("preferred_username") or email
logger.info(f"User decoded: email={email}, oidc_id={oidc_id}")
if not all([oidc_id, email]):
logger.error(f"Error: Essential user data missing. oidc_id={oidc_id}, email={email}")
raise HTTPException(status_code=400, detail="Essential user data missing from ID token (sub and email required).")
# Step 3: Save the user and get their unique ID
logger.info("Saving user to database...")
user_id = services.user_service.save_user(
db=db,
oidc_id=oidc_id,
email=email,
username=username
)
logger.info(f"User saved/updated successfully with internal ID: {user_id}")
# Step 4: Redirect back to the frontend
frontend_redirect_url = f"{state}?user_id={user_id}"
logger.info(f"Redirecting back to frontend: {frontend_redirect_url}")
return redirect(url=frontend_redirect_url)
except httpx.HTTPStatusError as e:
logger.error(f"OIDC Token exchange failed with status {e.response.status_code}: {e.response.text}")
raise HTTPException(status_code=500, detail=f"OIDC Token exchange failed: {e.response.text}")
except httpx.RequestError as e:
logger.error(f"OIDC Token exchange request error: {e}")
raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}")
except jwt.JWTDecodeError as e:
logger.error(f"ID token decode error: {e}")
raise HTTPException(status_code=400, detail="Failed to decode ID token from OIDC provider.")
except Exception as e:
logger.exception(f"An unexpected error occurred during OIDC callback: {e}")
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
@router.get("/me", response_model=schemas.UserStatus, summary="Get Current User Status")
async def get_current_status(
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
"""
Checks the login status of the current user.
Requires a valid user_id to be present in the request header.
"""
try:
# In a real-world scenario, you would fetch user details from the DB using user_id
# For this example, we return a mock response based on the presence of user_id
user : Optional[models.User] = services.user_service.get_user_by_id(db=db, user_id=user_id) # Ensure user exists
email = user.email if user else None
is_anonymous = user is None
is_logged_in = user is not None
return schemas.UserStatus(
id=user_id,
email=email,
is_logged_in=is_logged_in,
is_anonymous=is_anonymous
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
@router.get("/me/config", response_model=schemas.ConfigResponse, summary="Get Current User Preferences")
async def get_user_config(
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
"""Gets user specific preferences (LLM, TTS, STT config overrides)."""
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
prefs_dict = user.preferences or {}
# Calculate effective config
from app.config import settings
def mask_key(k):
if not k: return None
if len(k) <= 8: return "****"
return k[:4] + "*" * (len(k)-8) + k[-4:]
llm_prefs = prefs_dict.get("llm", {})
tts_prefs = prefs_dict.get("tts", {})
stt_prefs = prefs_dict.get("stt", {})
user_providers = llm_prefs.get("providers", {})
if not user_providers:
# Day zero: fall back to evaluating defaults from env/yaml
llm_providers_effective = {
"deepseek": {
"api_key": mask_key(settings.DEEPSEEK_API_KEY),
"model": settings.DEEPSEEK_MODEL_NAME
},
"gemini": {
"api_key": mask_key(settings.GEMINI_API_KEY),
"model": settings.GEMINI_MODEL_NAME
},
"openai": {
"api_key": mask_key(settings.OPENAI_API_KEY),
"model": "gpt-4"
}
}
else:
# User has configured providers, only return what they have (plus masking)
llm_providers_effective = {}
for p, prefs in user_providers.items():
llm_providers_effective[p] = {
"api_key": mask_key(prefs.get("api_key")),
"model": prefs.get("model")
}
user_tts_providers = tts_prefs.get("providers", {})
if not user_tts_providers:
tts_providers_effective = {
settings.TTS_PROVIDER: {
"api_key": mask_key(settings.TTS_API_KEY),
"model": settings.TTS_MODEL_NAME,
"voice": settings.TTS_VOICE_NAME
}
}
else:
tts_providers_effective = {}
for p, prefs in user_tts_providers.items():
tts_providers_effective[p] = {
"api_key": mask_key(prefs.get("api_key")),
"model": prefs.get("model"),
"voice": prefs.get("voice")
}
user_stt_providers = stt_prefs.get("providers", {})
if not user_stt_providers:
stt_providers_effective = {
settings.STT_PROVIDER: {
"api_key": mask_key(settings.STT_API_KEY),
"model": settings.STT_MODEL_NAME
}
}
else:
stt_providers_effective = {}
for p, prefs in user_stt_providers.items():
stt_providers_effective[p] = {
"api_key": mask_key(prefs.get("api_key")),
"model": prefs.get("model")
}
effective = {
"llm": {
"active_provider": llm_prefs.get("active_provider") or (next(iter(llm_providers_effective), None)) or "deepseek",
"providers": llm_providers_effective
},
"tts": {
"active_provider": tts_prefs.get("active_provider") or (next(iter(tts_providers_effective), None)) or settings.TTS_PROVIDER,
"providers": tts_providers_effective
},
"stt": {
"active_provider": stt_prefs.get("active_provider") or (next(iter(stt_providers_effective), None)) or settings.STT_PROVIDER,
"providers": stt_providers_effective
}
}
# Ensure we mask the preferences dict we send back to the user
def mask_section_prefs(section_dict):
import copy
masked_dict = copy.deepcopy(section_dict)
providers = masked_dict.get("providers", {})
for p_name, p_data in providers.items():
if p_data.get("api_key"):
p_data["api_key"] = mask_key(p_data["api_key"])
return masked_dict
return schemas.ConfigResponse(
preferences=schemas.UserPreferences(
llm=mask_section_prefs(llm_prefs),
tts=mask_section_prefs(tts_prefs),
stt=mask_section_prefs(stt_prefs)
),
effective=effective
)
@router.put("/me/config", response_model=schemas.UserPreferences, summary="Update Current User Preferences")
async def update_user_config(
prefs: schemas.UserPreferences,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
"""Updates user specific preferences."""
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
# When saving, if the api_key contains ****, we must retain the old one from the DB
old_prefs = user.preferences or {}
def preserve_masked_keys(section_name, new_section):
if not new_section or "providers" not in new_section:
return
old_section = old_prefs.get(section_name, {}).get("providers", {})
for p_name, p_data in new_section["providers"].items():
if p_data.get("api_key") and "***" in p_data["api_key"]:
if p_name in old_section:
p_data["api_key"] = old_section[p_name].get("api_key")
if prefs.llm: preserve_masked_keys("llm", prefs.llm)
if prefs.tts: preserve_masked_keys("tts", prefs.tts)
if prefs.stt: preserve_masked_keys("stt", prefs.stt)
user.preferences = {"llm": prefs.llm, "tts": prefs.tts, "stt": prefs.stt}
from sqlalchemy.orm.attributes import flag_modified
flag_modified(user, "preferences")
# --- Day 2 Sync: Update Global Settings and YAML ---
# Update our global settings object to match the user's new preferences
# and then persist them to config.yaml if this is the "primary" user or admin.
from app.config import settings as global_settings
# Sync LLM
if prefs.llm and prefs.llm.get("providers"):
global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {}))
# Sync TTS
if prefs.tts and prefs.tts.get("active_provider"):
p_name = prefs.tts["active_provider"]
p_data = prefs.tts.get("providers", {}).get(p_name, {})
if p_data:
global_settings.TTS_PROVIDER = p_name
global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY
# Sync STT
if prefs.stt and prefs.stt.get("active_provider"):
p_name = prefs.stt["active_provider"]
p_data = prefs.stt.get("providers", {}).get(p_name, {})
if p_data:
global_settings.STT_PROVIDER = p_name
global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY
# Write to config.yaml
try:
global_settings.save_to_yaml()
except Exception as ey:
logger.error(f"Failed to sync settings to YAML: {ey}")
logger.info(f"Saving updated preferences for user {user_id}: {list(user.preferences.keys())}")
db.add(user)
db.commit()
db.refresh(user)
return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {}))
@router.get("/me/config/models", response_model=list[schemas.ModelInfoResponse], summary="Get Models for Provider")
async def get_provider_models(provider_name: str, section: str = "llm"):
import litellm
from fastapi.concurrency import run_in_threadpool
def fetch_models():
try:
models = litellm.models_by_provider.get(provider_name, [])
out = []
for m in models:
try:
info = litellm.get_model_info(m)
if "error" not in info:
mode = info.get("mode")
is_valid = False
if section == "llm":
is_valid = mode in ["chat", "text-completion", "custom", None]
elif section == "tts":
is_valid = mode == "audio_speech"
elif section == "stt":
is_valid = mode == "audio_transcription"
elif section == "image":
is_valid = mode == "image_generation"
else:
is_valid = True
if is_valid:
out.append({
"model_name": m,
"max_tokens": info.get("max_tokens"),
"max_input_tokens": info.get("max_input_tokens")
})
except Exception:
pass
return out
except Exception:
return []
results = await run_in_threadpool(fetch_models)
return results
@router.get("/me/config/providers", response_model=list[str], summary="Get All Valid Providers per Section")
async def get_all_providers(section: str = "llm"):
import litellm
from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers
if section == "llm":
return ["general"] + [p.value for p in litellm.LlmProviders]
elif section == "tts":
# For TTS, combine registry + possibly litellm providers that support audio_speech
return ["general"] + get_registered_tts_providers() + ["openai"]
elif section == "stt":
return ["general"] + get_registered_stt_providers() + ["openai"]
elif section == "image":
return ["general"] + [p.value for p in litellm.LlmProviders]
return ["general"] + [p.value for p in litellm.LlmProviders]
@router.post("/me/config/verify_llm", response_model=schemas.VerifyProviderResponse)
async def verify_llm(
req: schemas.VerifyProviderRequest,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
from app.core.providers.factory import get_llm_provider
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
actual_key = req.api_key
try:
llm_prefs = {}
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if user and user.preferences:
llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(req.provider_name, {})
kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]}
llm = get_llm_provider(
provider_name=req.provider_name,
model_name=req.model or "",
api_key_override=actual_key,
**kwargs
)
res = llm("Hello")
return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
except Exception as e:
return schemas.VerifyProviderResponse(success=False, message=str(e))
@router.post("/me/config/verify_tts", response_model=schemas.VerifyProviderResponse)
async def verify_tts(
req: schemas.VerifyProviderRequest,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
from app.core.providers.factory import get_tts_provider
from app.config import settings
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
actual_key = req.api_key
try:
tts_prefs = {}
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if user and user.preferences:
tts_prefs = user.preferences.get("tts", {}).get("providers", {}).get(req.provider_name, {})
# Resolve the real key: masked/absent → DB key → system TTS key → system Gemini key
def is_masked(k):
return not k or k in ("None", "none", "") or "*" in str(k)
if is_masked(actual_key):
db_key = tts_prefs.get("api_key")
if not is_masked(db_key):
actual_key = db_key
logger.debug(f"verify_tts: using DB key for {req.provider_name}")
elif not is_masked(settings.TTS_API_KEY):
actual_key = settings.TTS_API_KEY
logger.debug(f"verify_tts: using system TTS_API_KEY for {req.provider_name}")
elif not is_masked(settings.GEMINI_API_KEY):
actual_key = settings.GEMINI_API_KEY
logger.debug(f"verify_tts: using system GEMINI_API_KEY for {req.provider_name}")
logger.info(f"verify_tts: provider={req.provider_name}, model={req.model}, key_prefix={str(actual_key)[:8] if actual_key else 'NONE'}")
kwargs = {k: v for k, v in tts_prefs.items() if k not in ["api_key", "model", "voice"]}
tts = get_tts_provider(
provider_name=req.provider_name,
api_key=actual_key,
model_name=req.model or "",
voice_name=req.voice or "",
**kwargs
)
await tts.generate_speech("Testing.")
return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
except Exception as e:
logger.error(f"TTS verification failed for {req.provider_name}: {e}")
return schemas.VerifyProviderResponse(success=False, message=str(e))
@router.post("/me/config/verify_stt", response_model=schemas.VerifyProviderResponse)
async def verify_stt(
req: schemas.VerifyProviderRequest,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
from app.core.providers.factory import get_stt_provider
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
actual_key = req.api_key
from app.config import settings
def is_masked(k):
return not k or k in ("None", "none", "") or "*" in str(k)
if is_masked(actual_key):
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if user and user.preferences:
prefs = user.preferences.get("stt", {}).get("providers", {}).get(req.provider_name, {})
db_key = prefs.get("api_key")
if not is_masked(db_key):
actual_key = db_key
if is_masked(actual_key) and not is_masked(settings.STT_API_KEY):
actual_key = settings.STT_API_KEY
if is_masked(actual_key) and not is_masked(settings.GEMINI_API_KEY):
actual_key = settings.GEMINI_API_KEY
try:
stt = get_stt_provider(provider_name=req.provider_name, model_name=req.model or "", api_key=actual_key)
import io
import wave
wav_io = io.BytesIO()
with wave.open(wav_io, 'wb') as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(16000)
wav_file.writeframes(b'\x00\x00' * 16000) # 1 second of silence
res = await stt.transcribe_audio(wav_io.getvalue())
# Empty transcript is expected for silent audio — connection works fine
return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
except Exception as e:
return schemas.VerifyProviderResponse(success=False, message=str(e))
@router.post("/logout", summary="Log Out the Current User")
async def logout():
"""
Simulates a user logout. In a real application, this would clear the session token or cookie.
"""
return {"message": "Logged out successfully"}
@router.get("/me/config/export", summary="Export Configurations to YAML")
async def export_user_config_yaml(
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
"""Exports the effective user configuration as a YAML file."""
from fastapi.responses import PlainTextResponse
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
prefs_dict = user.preferences or {}
from app.config import settings
import yaml
llm_prefs = prefs_dict.get("llm", {})
tts_prefs = prefs_dict.get("tts", {})
stt_prefs = prefs_dict.get("stt", {})
llm_providers_export = {}
user_providers = llm_prefs.get("providers", {})
if not user_providers:
# Fallback to system defaults if no user config exists
llm_providers_export = {
"deepseek_api_key": settings.DEEPSEEK_API_KEY,
"deepseek_model_name": settings.DEEPSEEK_MODEL_NAME,
"gemini_api_key": settings.GEMINI_API_KEY,
"gemini_model_name": settings.GEMINI_MODEL_NAME,
"openai_api_key": settings.OPENAI_API_KEY
}
else:
for p, p_data in user_providers.items():
llm_providers_export[f"{p}_api_key"] = p_data.get("api_key")
llm_providers_export[f"{p}_model_name"] = p_data.get("model")
def get_provider_export(section_prefs, fallback_provider, fallback_model, fallback_api_key, fallback_voice=None):
active_p = section_prefs.get("active_provider")
providers = section_prefs.get("providers", {})
if active_p and active_p in providers:
p_data = providers[active_p]
return {
"provider": active_p,
"model_name": p_data.get("model"),
"voice_name": p_data.get("voice"),
"api_key": p_data.get("api_key")
}
# Fallback to system settings
return {
"provider": fallback_provider,
"model_name": fallback_model,
"voice_name": fallback_voice,
"api_key": fallback_api_key
}
yaml_data = {
"application": {
"project_name": settings.PROJECT_NAME,
"version": settings.VERSION,
"log_level": settings.LOG_LEVEL
},
"database": {
"mode": settings.DB_MODE,
},
"llm_providers": llm_providers_export,
"embedding_provider": {
"provider": settings.EMBEDDING_PROVIDER,
"model_name": settings.EMBEDDING_MODEL_NAME,
"api_key": settings.EMBEDDING_API_KEY
},
"tts_provider": get_provider_export(tts_prefs, settings.TTS_PROVIDER, settings.TTS_MODEL_NAME, settings.TTS_API_KEY, settings.TTS_VOICE_NAME),
"stt_provider": get_provider_export(stt_prefs, settings.STT_PROVIDER, settings.STT_MODEL_NAME, settings.STT_API_KEY)
}
# Filter out None values recursively
def remove_none(obj):
if isinstance(obj, dict):
return {k: remove_none(v) for k, v in obj.items() if v is not None}
return obj
clean_yaml_data = remove_none(yaml_data)
yaml_str = yaml.dump(clean_yaml_data, sort_keys=False, default_flow_style=False)
return PlainTextResponse(
content=yaml_str,
media_type="application/x-yaml",
headers={"Content-Disposition": "attachment; filename=\"config.yaml\""}
)
@router.post("/me/config/import", response_model=schemas.UserPreferences, summary="Import Configurations from YAML")
async def import_user_config_yaml(
file: UploadFile = File(...),
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
"""Imports user configuration from a YAML file."""
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
content = await file.read()
try:
import yaml
data = yaml.safe_load(content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid YAML file: {e}")
# Reverse mapping: YAML -> UserPreferences structure
new_llm = { "providers": {}, "active_provider": None }
new_tts = { "providers": {}, "active_provider": None }
new_stt = { "providers": {}, "active_provider": None }
# --- LLM ---
llm_data = data.get("llm_providers", {})
if "providers" in llm_data: # Structured
new_llm["providers"] = llm_data["providers"]
else: # Flattened (as exported)
for k, v in llm_data.items():
if k.endswith("_api_key"):
p = k.replace("_api_key", "")
if p not in new_llm["providers"]: new_llm["providers"][p] = {}
new_llm["providers"][p]["api_key"] = v
elif k.endswith("_model_name"):
p = k.replace("_model_name", "")
if p not in new_llm["providers"]: new_llm["providers"][p] = {}
new_llm["providers"][p]["model"] = v
if new_llm["providers"]:
new_llm["active_provider"] = next(iter(new_llm["providers"]), None)
# --- TTS ---
tts_data = data.get("tts_provider", {})
if tts_data:
p = tts_data.get("provider") or "google_gemini"
new_tts["active_provider"] = p
new_tts["providers"][p] = {
"api_key": tts_data.get("api_key"),
"model": tts_data.get("model_name"),
"voice": tts_data.get("voice_name")
}
# --- STT ---
stt_data = data.get("stt_provider", {})
if stt_data:
p = stt_data.get("provider") or "google_gemini"
new_stt["active_provider"] = p
new_stt["providers"][p] = {
"api_key": stt_data.get("api_key"),
"model": stt_data.get("model_name")
}
user.preferences = { "llm": new_llm, "tts": new_tts, "stt": new_stt }
from sqlalchemy.orm.attributes import flag_modified
flag_modified(user, "preferences")
# --- Day 2 Sync ---
from app.config import settings as global_settings
if new_llm.get("providers"):
global_settings.LLM_PROVIDERS.update(new_llm["providers"])
if new_tts.get("active_provider"):
p = new_tts["active_provider"]
p_data = new_tts["providers"].get(p, {})
if p_data:
global_settings.TTS_PROVIDER = p
global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY
if new_stt.get("active_provider"):
p = new_stt["active_provider"]
p_data = new_stt["providers"].get(p, {})
if p_data:
global_settings.STT_PROVIDER = p
global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY
try:
global_settings.save_to_yaml()
except Exception as ey:
logger.error(f"Failed to sync settings to YAML on import: {ey}")
db.add(user)
db.commit()
db.refresh(user)
return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {}))
return router