diff --git a/ai-hub/app/api/routes/mcp.py b/ai-hub/app/api/routes/mcp.py index 5df46cd..363d786 100644 --- a/ai-hub/app/api/routes/mcp.py +++ b/ai-hub/app/api/routes/mcp.py @@ -877,62 +877,59 @@ u = db.query(User).filter(User.id == token).first() if not u or u.role != "admin": raise ValueError("Forbidden: Admin only.") - def mask_keys(providers_dict): + config = self.services.preference_service.get_global_config(db) + + def mask_keys(section_name, providers_dict): import copy res = copy.deepcopy(providers_dict) if providers_dict else {} for p_data in res.values(): if isinstance(p_data, dict) and p_data.get("api_key"): - k = str(p_data["api_key"]) - p_data["api_key"] = k[:4] + "****" + k[-4:] if len(k) > 8 else "****" + p_data["api_key"] = self.services.preference_service.mask_key(p_data["api_key"]) return res return { - "llm_providers": mask_keys(settings.LLM_PROVIDERS), - "active_llm_provider": settings.ACTIVE_LLM_PROVIDER, - "tts_providers": mask_keys(settings.TTS_PROVIDERS), - "active_tts_provider": settings.TTS_PROVIDER, - "stt_providers": mask_keys(settings.STT_PROVIDERS), - "active_stt_provider": settings.STT_PROVIDER + "llm_providers": mask_keys("llm", config.get("llm", {}).get("providers") or settings.LLM_PROVIDERS), + "active_llm_provider": config.get("llm", {}).get("active_provider") or settings.ACTIVE_LLM_PROVIDER, + "tts_providers": mask_keys("tts", config.get("tts", {}).get("providers") or { + settings.TTS_PROVIDER: {"api_key": settings.TTS_API_KEY, "model": settings.TTS_MODEL_NAME, "voice": settings.TTS_VOICE_NAME} + }), + "active_tts_provider": config.get("tts", {}).get("active_provider") or settings.TTS_PROVIDER, + "stt_providers": mask_keys("stt", config.get("stt", {}).get("providers") or { + settings.STT_PROVIDER: {"api_key": settings.STT_API_KEY, "model": settings.STT_MODEL_NAME} + }), + "active_stt_provider": config.get("stt", {}).get("active_provider") or settings.STT_PROVIDER } return await self.loop.run_in_executor(None, _query) async def _update_global_config(self, args: dict, token: Optional[str]): if not token: raise ValueError("Authentication required.") - def _query(): + from app.api import schemas + def _execute(): from app.db.session import get_db_session from app.db.models import User with get_db_session() as db: u = db.query(User).filter(User.id == token).first() if not u or u.role != "admin": raise ValueError("Forbidden: Admin only.") - def preserve_masked(new_dict, old_dict): - if not new_dict or not old_dict: return - for p_name, p_data in new_dict.items(): - if isinstance(p_data, dict) and p_data.get("api_key") and "****" in str(p_data["api_key"]): - if p_name in old_dict and isinstance(old_dict[p_name], dict): - p_data["api_key"] = old_dict[p_name].get("api_key") - - if args.get("llm_providers") is not None: - preserve_masked(args["llm_providers"], settings.LLM_PROVIDERS) - settings.LLM_PROVIDERS = args["llm_providers"] - if args.get("active_llm_provider") is not None: - settings.ACTIVE_LLM_PROVIDER = args["active_llm_provider"] - - if args.get("tts_providers") is not None: - preserve_masked(args["tts_providers"], settings.TTS_PROVIDERS) - settings.TTS_PROVIDERS = args["tts_providers"] - if args.get("active_tts_provider") is not None: - settings.TTS_PROVIDER = args["active_tts_provider"] - - if args.get("stt_providers") is not None: - preserve_masked(args["stt_providers"], settings.STT_PROVIDERS) - settings.STT_PROVIDERS = args["stt_providers"] - if args.get("active_stt_provider") is not None: - settings.STT_PROVIDER = args["active_stt_provider"] - - settings.save_to_yaml() - return {"message": "Global providers updated successfully"} - return await self.loop.run_in_executor(None, _query) + # Transform MCP args to UserPreferences schema + prefs = schemas.UserPreferences( + llm={ + "active_provider": args.get("active_llm_provider"), + "providers": args.get("llm_providers") + } if args.get("llm_providers") or args.get("active_llm_provider") else None, + tts={ + "active_provider": args.get("active_tts_provider"), + "providers": args.get("tts_providers") + } if args.get("tts_providers") or args.get("active_tts_provider") else None, + stt={ + "active_provider": args.get("active_stt_provider"), + "providers": args.get("stt_providers") + } if args.get("stt_providers") or args.get("active_stt_provider") else None + ) + + self.services.preference_service.update_global_config(prefs, db, u) + return {"message": "Global providers updated successfully via system_config"} + return await self.loop.run_in_executor(None, _execute) async def _get_system_status(self, args: dict, token: Optional[str]): def _query():