diff --git a/ai-hub/app/api/routes/admin.py b/ai-hub/app/api/routes/admin.py index 1697015..b574ca5 100644 --- a/ai-hub/app/api/routes/admin.py +++ b/ai-hub/app/api/routes/admin.py @@ -125,6 +125,56 @@ settings.save_to_yaml() return {"message": "Swarm configuration updated successfully"} + @router.get("/config/providers", summary="Get Global Providers Configuration") + async def get_global_providers(admin = Depends(get_current_admin)): + def mask_keys(providers_dict): + import copy + res = copy.deepcopy(providers_dict) if providers_dict else {} + for p_data in res.values(): + if isinstance(p_data, dict) and p_data.get("api_key"): + k = str(p_data["api_key"]) + p_data["api_key"] = k[:4] + "****" + k[-4:] if len(k) > 8 else "****" + return res + + return { + "llm_providers": mask_keys(settings.LLM_PROVIDERS), + "active_llm_provider": settings.ACTIVE_LLM_PROVIDER, + "tts_providers": mask_keys(settings.TTS_PROVIDERS), + "active_tts_provider": settings.TTS_PROVIDER, + "stt_providers": mask_keys(settings.STT_PROVIDERS), + "active_stt_provider": settings.STT_PROVIDER + } + + @router.put("/config/providers", summary="Update Global Providers Configuration") + async def update_global_providers(update: schemas.GlobalProvidersUpdate, admin = Depends(get_current_admin)): + def preserve_masked(new_dict, old_dict): + if not new_dict or not old_dict: return + for p_name, p_data in new_dict.items(): + if isinstance(p_data, dict) and p_data.get("api_key") and "****" in str(p_data["api_key"]): + if p_name in old_dict and isinstance(old_dict[p_name], dict): + p_data["api_key"] = old_dict[p_name].get("api_key") + + if update.llm_providers is not None: + preserve_masked(update.llm_providers, settings.LLM_PROVIDERS) + settings.LLM_PROVIDERS = update.llm_providers + if update.active_llm_provider is not None: + settings.ACTIVE_LLM_PROVIDER = update.active_llm_provider + + if update.tts_providers is not None: + preserve_masked(update.tts_providers, settings.TTS_PROVIDERS) + settings.TTS_PROVIDERS = update.tts_providers + if update.active_tts_provider is not None: + settings.TTS_PROVIDER = update.active_tts_provider + + if update.stt_providers is not None: + preserve_masked(update.stt_providers, settings.STT_PROVIDERS) + settings.STT_PROVIDERS = update.stt_providers + if update.active_stt_provider is not None: + settings.STT_PROVIDER = update.active_stt_provider + + settings.save_to_yaml() + return {"message": "Global providers updated successfully"} + @router.get("/config", summary="Get Admin Configuration") async def get_admin_config( admin = Depends(get_current_admin) diff --git a/ai-hub/app/api/routes/mcp.py b/ai-hub/app/api/routes/mcp.py index 613912e..ec80c2e 100644 --- a/ai-hub/app/api/routes/mcp.py +++ b/ai-hub/app/api/routes/mcp.py @@ -311,7 +311,16 @@ _tool_def("detach_node_from_session", "Detach agent node.", {"session_id": {"type": "integer"}, "node_id": {"type": "string"}}, required=["session_id", "node_id"]), _tool_def("get_session_nodes", "Get all nodes in session.", {"session_id": {"type": "integer"}}, required=["session_id"]), _tool_def("cancel_session_task", "Cancel session task.", {"session_id": {"type": "integer"}}, required=["session_id"]), - _tool_def("get_system_status", "Retrieve full system state.", {}) + _tool_def("get_system_status", "Retrieve full system state.", {}), + _tool_def("get_global_config", "Get global LLM/TTS/STT settings (Admin only).", {}), + _tool_def("update_global_config", "Update global LLM/TTS/STT settings (Admin only).", { + "llm_providers": {"type": "object"}, + "active_llm_provider": {"type": "string"}, + "tts_providers": {"type": "object"}, + "active_tts_provider": {"type": "string"}, + "stt_providers": {"type": "object"}, + "active_stt_provider": {"type": "string"} + }) ] } @@ -369,7 +378,9 @@ "detach_node_from_session": self._detach_node_from_session, "get_session_nodes": self._get_session_nodes, "cancel_session_task": self._cancel_session_task, - "get_system_status": self._get_system_status + "get_system_status": self._get_system_status, + "get_global_config": self._get_global_config, + "update_global_config": self._update_global_config } async def dispatch(self, name: str, args: dict, token: Optional[str]) -> Any: @@ -852,6 +863,72 @@ return {"message": "Cancellation request sent."} return await self.loop.run_in_executor(None, _query) + async def _get_global_config(self, args: dict, token: Optional[str]): + if not token: raise ValueError("Authentication required.") + def _query(): + from app.db.session import get_db_session + from app.db.models import User + with get_db_session() as db: + u = db.query(User).filter(User.id == token).first() + if not u or u.role != "admin": raise ValueError("Forbidden: Admin only.") + + def mask_keys(providers_dict): + import copy + res = copy.deepcopy(providers_dict) if providers_dict else {} + for p_data in res.values(): + if isinstance(p_data, dict) and p_data.get("api_key"): + k = str(p_data["api_key"]) + p_data["api_key"] = k[:4] + "****" + k[-4:] if len(k) > 8 else "****" + return res + + return { + "llm_providers": mask_keys(settings.LLM_PROVIDERS), + "active_llm_provider": settings.ACTIVE_LLM_PROVIDER, + "tts_providers": mask_keys(settings.TTS_PROVIDERS), + "active_tts_provider": settings.TTS_PROVIDER, + "stt_providers": mask_keys(settings.STT_PROVIDERS), + "active_stt_provider": settings.STT_PROVIDER + } + return await self.loop.run_in_executor(None, _query) + + async def _update_global_config(self, args: dict, token: Optional[str]): + if not token: raise ValueError("Authentication required.") + def _query(): + from app.db.session import get_db_session + from app.db.models import User + with get_db_session() as db: + u = db.query(User).filter(User.id == token).first() + if not u or u.role != "admin": raise ValueError("Forbidden: Admin only.") + + def preserve_masked(new_dict, old_dict): + if not new_dict or not old_dict: return + for p_name, p_data in new_dict.items(): + if isinstance(p_data, dict) and p_data.get("api_key") and "****" in str(p_data["api_key"]): + if p_name in old_dict and isinstance(old_dict[p_name], dict): + p_data["api_key"] = old_dict[p_name].get("api_key") + + if args.get("llm_providers") is not None: + preserve_masked(args["llm_providers"], settings.LLM_PROVIDERS) + settings.LLM_PROVIDERS = args["llm_providers"] + if args.get("active_llm_provider") is not None: + settings.ACTIVE_LLM_PROVIDER = args["active_llm_provider"] + + if args.get("tts_providers") is not None: + preserve_masked(args["tts_providers"], settings.TTS_PROVIDERS) + settings.TTS_PROVIDERS = args["tts_providers"] + if args.get("active_tts_provider") is not None: + settings.TTS_PROVIDER = args["active_tts_provider"] + + if args.get("stt_providers") is not None: + preserve_masked(args["stt_providers"], settings.STT_PROVIDERS) + settings.STT_PROVIDERS = args["stt_providers"] + if args.get("active_stt_provider") is not None: + settings.STT_PROVIDER = args["active_stt_provider"] + + settings.save_to_yaml() + return {"message": "Global providers updated successfully"} + return await self.loop.run_in_executor(None, _query) + async def _get_system_status(self, args: dict, token: Optional[str]): def _query(): return {"status": "running", "oidc_enabled": settings.OIDC_ENABLED, "version": "1.0.0"} diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 7ddb820..2f1bfd4 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -347,6 +347,14 @@ success: bool message: str +class GlobalProvidersUpdate(BaseModel): + llm_providers: Optional[dict] = None + active_llm_provider: Optional[str] = None + tts_providers: Optional[dict] = None + active_tts_provider: Optional[str] = None + stt_providers: Optional[dict] = None + active_stt_provider: Optional[str] = None + class ModelInfoResponse(BaseModel): model_name: str max_tokens: Optional[int] = None diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index c72b660..25b7eaf 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -352,6 +352,12 @@ "swarm": { "external_endpoint": self.GRPC_EXTERNAL_ENDPOINT }, + "llm_providers": self.LLM_PROVIDERS, + "active_llm_provider": self.ACTIVE_LLM_PROVIDER, + "tts_providers": self.TTS_PROVIDERS, + "active_tts_provider": self.TTS_PROVIDER, + "stt_providers": self.STT_PROVIDERS, + "active_stt_provider": self.STT_PROVIDER, "journal": { "stream_head_chars": self.STREAM_HEAD_CHARS, "stream_tail_chars": self.STREAM_TAIL_CHARS, diff --git a/ai-hub/app/core/services/preference.py b/ai-hub/app/core/services/preference.py index f58c791..32727d1 100644 --- a/ai-hub/app/core/services/preference.py +++ b/ai-hub/app/core/services/preference.py @@ -233,41 +233,9 @@ }) user.preferences = current_prefs - if user.role == "admin": - from sqlalchemy.orm.attributes import flag_modified - flag_modified(user, "preferences") - - from app.config import settings as global_settings - - if prefs.llm and "providers" in prefs.llm: - global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers", {})) - - if prefs.tts and prefs.tts.get("active_provider"): - p_name = prefs.tts["active_provider"] - p_data = prefs.tts.get("providers", {}).get(p_name, {}) - if p_data: - global_settings.TTS_PROVIDER = p_name - global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME - global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME - global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY - - if prefs.stt and prefs.stt.get("active_provider"): - p_name = prefs.stt["active_provider"] - p_data = prefs.stt.get("providers", {}).get(p_name, {}) - if p_data: - global_settings.STT_PROVIDER = p_name - global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME - global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY - - logger.info(f"Saving updated global preferences via admin {user.id}") - else: - user.preferences["llm"]["active_provider"] = prefs.llm.get("active_provider") - user.preferences["tts"]["active_provider"] = prefs.tts.get("active_provider") - user.preferences["stt"]["active_provider"] = prefs.stt.get("active_provider") - user.preferences["statuses"] = prefs.statuses or {} - from sqlalchemy.orm.attributes import flag_modified - flag_modified(user, "preferences") - logger.info(f"Saving personal preferences for user {user.id}") + from sqlalchemy.orm.attributes import flag_modified + flag_modified(user, "preferences") + logger.info(f"Saving personal preferences for user {user.id}") db.add(user) db.commit()