diff --git a/ai-hub/app/api/routes/admin.py b/ai-hub/app/api/routes/admin.py index b574ca5..0d9a876 100644 --- a/ai-hub/app/api/routes/admin.py +++ b/ai-hub/app/api/routes/admin.py @@ -1,9 +1,10 @@ from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session from app.api import schemas -from app.api.dependencies import get_current_admin +from app.api.dependencies import get_current_admin, get_db from app.config import settings -def create_admin_router() -> APIRouter: +def create_admin_router(services=None) -> APIRouter: router = APIRouter() @router.put("/config/oidc", summary="Update OIDC Configuration") @@ -125,55 +126,39 @@ settings.save_to_yaml() return {"message": "Swarm configuration updated successfully"} - @router.get("/config/providers", summary="Get Global Providers Configuration") - async def get_global_providers(admin = Depends(get_current_admin)): - def mask_keys(providers_dict): - import copy - res = copy.deepcopy(providers_dict) if providers_dict else {} - for p_data in res.values(): - if isinstance(p_data, dict) and p_data.get("api_key"): - k = str(p_data["api_key"]) - p_data["api_key"] = k[:4] + "****" + k[-4:] if len(k) > 8 else "****" - return res - - return { - "llm_providers": mask_keys(settings.LLM_PROVIDERS), - "active_llm_provider": settings.ACTIVE_LLM_PROVIDER, - "tts_providers": mask_keys(settings.TTS_PROVIDERS), - "active_tts_provider": settings.TTS_PROVIDER, - "stt_providers": mask_keys(settings.STT_PROVIDERS), - "active_stt_provider": settings.STT_PROVIDER + @router.get("/config/providers", summary="Get Global Provider Configuration") + async def get_global_provider_config( + db: Session = Depends(get_db), + admin=Depends(get_current_admin), + ): + """Returns the system-wide LLM/TTS/STT provider config stored in system_config table.""" + if not services: + raise HTTPException(status_code=503, detail="Services not available.") + global_cfg = services.preference_service.get_global_config(db) + env_fallback = { 
+ "llm": {"active_provider": settings.ACTIVE_LLM_PROVIDER, "providers": dict(settings.LLM_PROVIDERS)}, + "tts": {"active_provider": settings.TTS_PROVIDER, "providers": {}}, + "stt": {"active_provider": settings.STT_PROVIDER, "providers": {}}, } + for section in ("llm", "tts", "stt"): + if section not in global_cfg: + global_cfg[section] = env_fallback[section] + if "statuses" not in global_cfg: + global_cfg["statuses"] = {} + return global_cfg - @router.put("/config/providers", summary="Update Global Providers Configuration") - async def update_global_providers(update: schemas.GlobalProvidersUpdate, admin = Depends(get_current_admin)): - def preserve_masked(new_dict, old_dict): - if not new_dict or not old_dict: return - for p_name, p_data in new_dict.items(): - if isinstance(p_data, dict) and p_data.get("api_key") and "****" in str(p_data["api_key"]): - if p_name in old_dict and isinstance(old_dict[p_name], dict): - p_data["api_key"] = old_dict[p_name].get("api_key") - - if update.llm_providers is not None: - preserve_masked(update.llm_providers, settings.LLM_PROVIDERS) - settings.LLM_PROVIDERS = update.llm_providers - if update.active_llm_provider is not None: - settings.ACTIVE_LLM_PROVIDER = update.active_llm_provider - - if update.tts_providers is not None: - preserve_masked(update.tts_providers, settings.TTS_PROVIDERS) - settings.TTS_PROVIDERS = update.tts_providers - if update.active_tts_provider is not None: - settings.TTS_PROVIDER = update.active_tts_provider - - if update.stt_providers is not None: - preserve_masked(update.stt_providers, settings.STT_PROVIDERS) - settings.STT_PROVIDERS = update.stt_providers - if update.active_stt_provider is not None: - settings.STT_PROVIDER = update.active_stt_provider - - settings.save_to_yaml() - return {"message": "Global providers updated successfully"} + @router.put("/config/providers", summary="Update Global Provider Configuration") + async def update_global_provider_config( + prefs: schemas.UserPreferences, + db: 
Session = Depends(get_db), + admin=Depends(get_current_admin), + ): + """Saves system-wide LLM/TTS/STT config to system_config table. + Decoupled from any user account — applies to all users as their baseline.""" + if not services: + raise HTTPException(status_code=503, detail="Services not available.") + services.preference_service.update_global_config(prefs, db, admin) + return {"message": "Global provider configuration updated."} @router.get("/config", summary="Get Admin Configuration") async def get_admin_config( diff --git a/ai-hub/app/api/routes/api.py b/ai-hub/app/api/routes/api.py index 6aea324..e8a7fad 100644 --- a/ai-hub/app/api/routes/api.py +++ b/ai-hub/app/api/routes/api.py @@ -30,7 +30,7 @@ router.include_router(create_nodes_router(services)) router.include_router(create_skills_router(services)) router.include_router(create_agent_update_router()) - router.include_router(create_admin_router(), prefix="/admin") + router.include_router(create_admin_router(services), prefix="/admin") from .agents import create_agents_router router.include_router(create_agents_router(services), prefix="/agents", tags=["Agents"]) diff --git a/ai-hub/app/api/routes/mcp.py b/ai-hub/app/api/routes/mcp.py index ec80c2e..5df46cd 100644 --- a/ai-hub/app/api/routes/mcp.py +++ b/ai-hub/app/api/routes/mcp.py @@ -92,17 +92,22 @@ from app.db.session import get_db_session with get_db_session() as db: user_id = await _get_authenticated_user(request, token, db) - + if not user_id: logger.info("[MCP] SSE connection opened without initial auth.") + # Preserve the original credential (JWT or plain id) for the messages URL. + # Using user_id here breaks OIDC mode because the messages endpoint requires a JWT. 
+ auth_header = request.headers.get("Authorization", "") + original_token = auth_header[7:] if auth_header.startswith("Bearer ") else token + queue = asyncio.Queue() session_id = str(uuid.uuid4()) _sse_sessions[session_id] = queue - + messages_url = f"{settings.HUB_PUBLIC_URL}/api/v1/mcp/messages?session_id={session_id}" - if user_id: - messages_url += f"&token={user_id}" + if original_token: + messages_url += f"&token={original_token}" origin = request.headers.get("origin") if origin: diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py index 6df16eb..c0409ea 100644 --- a/ai-hub/app/api/routes/user.py +++ b/ai-hub/app/api/routes/user.py @@ -7,10 +7,15 @@ from typing import Optional, Annotated import logging import os +import time +import secrets import httpx import jwt import urllib.parse +# Short-lived CLI auth state store: state_token → {redirect_uri, expires_at} +_cli_state_store: dict = {} + # Correctly import from your application's schemas and dependencies from app.api.dependencies import ServiceContainer, get_db, get_current_user, get_optional_user from app.api import schemas @@ -66,7 +71,20 @@ result = await services.auth_service.handle_callback(code, db) user_id = result["user_id"] linked = result.get("linked", False) - + + # CLI auth flow: state is "cli:" — redirect to the CLI's local server + if state.startswith("cli:"): + state_token = state[4:] + entry = _cli_state_store.pop(state_token, None) + if entry and time.time() < entry["expires_at"]: + session_token = services.auth_service.create_session_token(user_id) + user = db.query(models.User).filter(models.User.id == user_id).first() + email = user.email if user else "" + params = urllib.parse.urlencode({"token": session_token, "email": email, "user_id": user_id}) + return redirect(url=f"{entry['redirect_uri']}?{params}") + # Expired or unknown state — fall through to dashboard + return redirect(url="/dashboard?error=cli_auth_expired") + # SECURITY: Prevent Open Redirect - 
Validate 'state' is a safe URL # Ideally this matches settings.FRONTEND_URL or a whitelist. safe_url = state @@ -76,7 +94,7 @@ allowed_domains = ["ai.jerxie.com", "localhost", "127.0.0.1"] api_domain = urllib.parse.urlparse(str(request.base_url)).netloc allowed_domains.append(api_domain) - + if parsed_url.netloc not in allowed_domains: logger.warning(f"Prevented potentially malicious open redirect to: {state}") safe_url = "/dashboard" @@ -84,14 +102,66 @@ frontend_redirect_url = f"{safe_url}?user_id={user_id}" if linked: frontend_redirect_url += "&linked=true" - + # Include the ID token if available (to allow the frontend to switch to JWT auth) id_token = result.get("id_token") if id_token: frontend_redirect_url += f"&token={id_token}" - + return redirect(url=frontend_redirect_url) + @router.get("/auth/cli-init", summary="Initialize CLI Authentication Flow") + async def cli_auth_init( + port: int = Query(..., ge=1024, le=65535, description="Local CLI server port for the callback"), + ): + """ + Starts a gcloud-style browser-based login for CLI/MCP config generation. + Returns an auth_url the CLI should open in the browser. + For local-auth-only deployments, returns oidc=false so the CLI falls back to password prompt. 
+ """ + # Purge expired entries to avoid unbounded growth + now = time.time() + expired = [k for k, v in _cli_state_store.items() if now >= v["expires_at"]] + for k in expired: + _cli_state_store.pop(k, None) + + state_token = secrets.token_urlsafe(32) + _cli_state_store[state_token] = { + "redirect_uri": f"http://localhost:{port}/callback", + "expires_at": now + 300, + } + + if settings.OIDC_ENABLED: + auth_url = await services.auth_service.generate_login_url(f"cli:{state_token}") + return {"auth_url": auth_url, "oidc": True, "expires_in": 300} + + # Local auth mode: CLI will POST credentials and exchange state_token for a session token + return { + "auth_url": None, + "oidc": False, + "state_token": state_token, + "login_url": f"{settings.HUB_PUBLIC_URL}/api/v1/users/login/local", + "expires_in": 300, + } + + @router.post("/auth/cli-token", summary="Exchange CLI State Token for Session Token") + async def cli_token_exchange( + state_token: str = Query(...), + db: Session = Depends(get_db), + current_user: models.User = Depends(get_current_user), + ): + """ + After a successful local-auth login, the CLI exchanges the state_token for a + session JWT scoped to the authenticated user. One-time use, 5-minute TTL. 
+ """ + if not current_user: + raise HTTPException(status_code=401, detail="Authentication required.") + entry = _cli_state_store.pop(state_token, None) + if not entry or time.time() >= entry["expires_at"]: + raise HTTPException(status_code=400, detail="Invalid or expired state token.") + token = services.auth_service.create_session_token(current_user.id) + return {"token": token, "email": current_user.email, "user_id": current_user.id} + @router.get("/config", summary="Public Auth Configuration") async def get_auth_config(): """Publicly accessible endpoint to check which auth methods are enabled.""" diff --git a/ai-hub/app/app.py b/ai-hub/app/app.py index 0c24aeb..c1cb280 100644 --- a/ai-hub/app/app.py +++ b/ai-hub/app/app.py @@ -74,7 +74,7 @@ logger.error(f"[AgentScheduler] Start fail: {se}") # Launch periodic LLM provider health check - asyncio.create_task(_periodic_provider_health_check()) + asyncio.create_task(_periodic_provider_health_check(app.state.services)) except Exception as e: logger.error(f"[M6] Failed to start gRPC server: {e}") @@ -122,7 +122,7 @@ await asyncio.sleep(600) # Run every 10 minutes to prevent CPU spin-lock (Assignment #5) -async def _periodic_provider_health_check(): +async def _periodic_provider_health_check(services): """Periodically tests actual connectivity to all system-level LLM providers.""" await asyncio.sleep(60) # Initial delay to let the system boot fully from app.core.providers.factory import get_llm_provider @@ -136,46 +136,52 @@ from app.config import settings with get_db_session() as db: - admin_user = db.query(models.User).filter(models.User.role == "admin").first() - if admin_user and admin_user.preferences: - prefs = copy.deepcopy(admin_user.preferences) - statuses = prefs.get("statuses", {}) - - llm_providers = prefs.get("llm", {}).get("providers", {}) + # Read current global config + global_cfg = services.preference_service.get_global_config(db) + statuses = global_cfg.get("statuses", {}) + + # Check LLM providers 
from settings (baseline) and system_config overrides + llm_providers = dict(settings.LLM_PROVIDERS) + if "llm" in global_cfg and "providers" in global_cfg["llm"]: + llm_providers.update(global_cfg["llm"]["providers"]) - changed = False - for p_name, p_data in list(llm_providers.items()): - api_key = p_data.get("api_key") - if not api_key or "***" in api_key or api_key.lower() in ["none", ""]: - continue - - status_key = f"llm_{p_name}" + changed = False + for p_name, p_data in list(llm_providers.items()): + api_key = p_data.get("api_key") + if not api_key or "***" in api_key or api_key.lower() in ["none", ""]: + continue + + status_key = f"llm_{p_name}" + new_status = "error" + try: + llm = get_llm_provider( + provider_name=p_name, + model_name=p_data.get("model") or "", + api_key_override=api_key + ) + # Short timeout to avoid blocking the loop heavily + async def test_llm(): + return await llm.acompletion(prompt="Hello") + await asyncio.wait_for(test_llm(), timeout=10.0) + new_status = "success" + except Exception as e: + logger.error(f"[Health Check] LLM {p_name} background test failed: {e}") new_status = "error" - try: - llm = get_llm_provider( - provider_name=p_name, - model_name=p_data.get("model") or "", - api_key_override=api_key - ) - # Short timeout to avoid blocking the loop heavily - async def test_llm(): - return await llm.acompletion(prompt="Hello") - await asyncio.wait_for(test_llm(), timeout=10.0) - new_status = "success" - except Exception as e: - logger.error(f"[Health Check] LLM {p_name} background test failed: {e}") - new_status = "error" - - if statuses.get(status_key) != new_status: - statuses[status_key] = new_status - changed = True - - if changed: - prefs["statuses"] = statuses - admin_user.preferences = prefs - flag_modified(admin_user, "preferences") - db.commit() - logger.info("[Health Check] System LLM statuses updated.") + + if statuses.get(status_key) != new_status: + statuses[status_key] = new_status + changed = True + + if changed: + # 
Update only the statuses key in system_config + row = db.query(models.SystemConfig).filter(models.SystemConfig.key == "statuses").first() + if row: + row.value = statuses + flag_modified(row, "value") + else: + db.add(models.SystemConfig(key="statuses", value=statuses)) + db.commit() + logger.info("[Health Check] System global provider statuses updated.") except Exception as e: logger.error(f"[Health Check] Background task error: {e}") diff --git a/ai-hub/app/core/services/preference.py b/ai-hub/app/core/services/preference.py index 32727d1..22179c9 100644 --- a/ai-hub/app/core/services/preference.py +++ b/ai-hub/app/core/services/preference.py @@ -50,9 +50,11 @@ tts_prefs = normalize_section("tts", settings.TTS_PROVIDER) stt_prefs = normalize_section("stt", settings.STT_PROVIDER) - system_prefs = self.services.user_service.get_system_settings(db) + system_prefs = self.get_global_config(db) system_statuses = system_prefs.get("statuses", {}) user_statuses = prefs_dict.get("statuses", {}) + merged_statuses = copy.deepcopy(system_statuses) + merged_statuses.update(user_statuses) def is_provider_healthy(section: str, provider_id: str, p_data: dict = None) -> bool: status_key = f"{section}_{provider_id}" @@ -141,7 +143,8 @@ "stt": { "active_provider": stt_prefs.get("active_provider") or (next(iter(stt_providers_effective), settings.STT_PROVIDER)), "providers": stt_providers_effective - } + }, + "statuses": merged_statuses } group = user.group or self.services.user_service.get_or_create_default_group(db) @@ -174,9 +177,6 @@ p_data["api_key"] = self.mask_key(p_data["api_key"]) return masked_dict - merged_statuses = copy.deepcopy(system_statuses) - merged_statuses.update(user_statuses) - return schemas.ConfigResponse( preferences=schemas.UserPreferences( llm=mask_section_prefs(llm_prefs), @@ -188,22 +188,92 @@ ) + def get_global_config(self, db) -> dict: + """Read system-wide provider config from system_config table. 
+ Falls back to config.yaml / env vars if no DB entry exists.""" + from app.db import models as m + result = {} + for section in ("llm", "tts", "stt", "statuses"): + row = db.query(m.SystemConfig).filter(m.SystemConfig.key == section).first() + if row and row.value: + result[section] = row.value + return result + + def update_global_config(self, prefs: schemas.UserPreferences, db, admin_user) -> schemas.UserPreferences: + """Persist system-wide provider config to system_config table. + Admin-only. Does NOT touch users.preferences.""" + import json as _json + from datetime import datetime + from app.db import models as m + from sqlalchemy.orm.attributes import flag_modified + + def _restore_masked_keys(section_name, new_section): + """Preserve existing API keys when the UI sends back masked '****' values.""" + if not new_section or "providers" not in new_section: + return + row = db.query(m.SystemConfig).filter(m.SystemConfig.key == section_name).first() + old_providers = (row.value or {}).get("providers", {}) if row else {} + for p_name, p_data in new_section["providers"].items(): + if p_data.get("api_key") and "***" in str(p_data["api_key"]): + if p_name in old_providers: + p_data["api_key"] = old_providers[p_name].get("api_key") + + if prefs.llm: _restore_masked_keys("llm", prefs.llm) + if prefs.tts: _restore_masked_keys("tts", prefs.tts) + if prefs.stt: _restore_masked_keys("stt", prefs.stt) + + for section, value in (("llm", prefs.llm), ("tts", prefs.tts), ("stt", prefs.stt), ("statuses", prefs.statuses)): + if value is None: + continue + row = db.query(m.SystemConfig).filter(m.SystemConfig.key == section).first() + if row: + row.value = value + row.updated_at = datetime.utcnow() + row.updated_by = admin_user.id + flag_modified(row, "value") + else: + db.add(m.SystemConfig(key=section, value=value, updated_by=admin_user.id)) + + # Keep in-memory settings in sync for the current process lifetime + from app.config import settings as global_settings + if prefs.llm 
and "providers" in prefs.llm: + global_settings.LLM_PROVIDERS = dict(prefs.llm.get("providers", {})) + if prefs.tts and prefs.tts.get("active_provider"): + p_name = prefs.tts["active_provider"] + p_data = prefs.tts.get("providers", {}).get(p_name, {}) + if p_data: + global_settings.TTS_PROVIDER = p_name + global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME + global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME + global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY + if prefs.stt and prefs.stt.get("active_provider"): + p_name = prefs.stt["active_provider"] + p_data = prefs.stt.get("providers", {}).get(p_name, {}) + if p_data: + global_settings.STT_PROVIDER = p_name + global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME + global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY + + db.commit() + logger.info(f"Global provider config updated by admin {admin_user.id}") + return prefs + def update_user_config(self, user, prefs: schemas.UserPreferences, db) -> schemas.UserPreferences: - # When saving, if the api_key contains ****, we must retain the old one from the DB + """Save a user's personal provider preferences (active_provider choice + optional personal API keys). 
+ Always writes to users.preferences only — never touches system_config.""" + from sqlalchemy.orm.attributes import flag_modified + old_prefs = user.preferences or {} def get_old_providers(section_name): section = old_prefs.get(section_name, {}) if isinstance(section, dict) and "providers" in section: return section["providers"] - - # Legacy extraction providers = {} legacy_keys = ["openai", "gemini", "deepseek", "gcloud_tts", "azure", "google", "elevenlabs"] for p in legacy_keys: if p in section: providers[p] = section[p] - if not providers and section and isinstance(section, dict): for k, v in section.items(): if k not in ["active_provider", "provider", "providers"] and isinstance(v, dict): @@ -213,18 +283,17 @@ def preserve_masked_keys(section_name, new_section): if not new_section or "providers" not in new_section: return - old_section_providers = get_old_providers(section_name) + old_providers = get_old_providers(section_name) for p_name, p_data in new_section["providers"].items(): if p_data.get("api_key") and "***" in str(p_data["api_key"]): - if p_name in old_section_providers: - p_data["api_key"] = old_section_providers[p_name].get("api_key") + if p_name in old_providers: + p_data["api_key"] = old_providers[p_name].get("api_key") if prefs.llm: preserve_masked_keys("llm", prefs.llm) if prefs.tts: preserve_masked_keys("tts", prefs.tts) if prefs.stt: preserve_masked_keys("stt", prefs.stt) - - current_prefs = dict(user.preferences or {}) + current_prefs = dict(old_prefs) current_prefs.update({ "llm": prefs.llm, "tts": prefs.tts, @@ -232,15 +301,14 @@ "statuses": prefs.statuses or {} }) user.preferences = current_prefs - from sqlalchemy.orm.attributes import flag_modified flag_modified(user, "preferences") logger.info(f"Saving personal preferences for user {user.id}") - db.add(user) db.commit() db.refresh(user) - + logger.info(f"Saved personal preferences for user {user.id}") + return schemas.UserPreferences( llm=user.preferences.get("llm", {}), 
tts=user.preferences.get("tts", {}), @@ -320,7 +388,7 @@ if is_masked: actual_key = prefs.get("api_key") if not actual_key: - s_prefs = self.services.user_service.get_system_settings(db) + s_prefs = self.get_global_config(db) actual_key = s_prefs.get(section, {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") try: @@ -348,7 +416,7 @@ if not provider_name: # Fallback to registered active provider u_svc = getattr(self.services, "user_service", None) - sys_prefs = u_svc.get_system_settings(db) if u_svc else {} + sys_prefs = self.get_global_config(db) if u_svc else {} user_active = user.preferences.get("llm", {}).get("active_provider") if user and user.preferences else None base_key = user_active or sys_prefs.get("llm", {}).get("active_provider") @@ -477,7 +545,7 @@ from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers if configured_only: - system_prefs = self.services.user_service.get_system_settings(db) + system_prefs = self.get_global_config(db) user_prefs = user.preferences if user else {} configured = set(system_prefs.get(section, {}).get("providers", {}).keys()) diff --git a/ai-hub/app/core/services/user.py b/ai-hub/app/core/services/user.py index 48e46af..585cbd5 100644 --- a/ai-hub/app/core/services/user.py +++ b/ai-hub/app/core/services/user.py @@ -241,10 +241,10 @@ def get_system_settings(self, db: Session) -> dict: """ - Retrieves global AI provider settings. + Retrieves global AI provider settings from the system_config table. Merge strategy controlled by settings.CONFIG_OVERRIDE: - - False (default): DB admin preferences win. Config.yaml/env only fill missing fields. + - False (default): DB system_config table wins. Config.yaml/env only fill missing fields. - True: Config.yaml/env always wins, ignoring any values in the DB. 
""" from app.config import settings @@ -261,11 +261,17 @@ } return config_prefs - # Default: DB admin preferences are authoritative (search for any user with 'admin' role). + # Default: DB system_config entries are authoritative. try: - admin_user = db.query(models.User).filter(models.User.role == "admin").order_by(models.User.created_at).first() - if admin_user and admin_user.preferences and admin_user.preferences.get("llm"): - return admin_user.preferences + # New centralized approach: read from system_config table + res = {} + rows = db.query(models.SystemConfig).all() + for r in rows: + res[r.key] = r.value + + if res.get("llm"): + # If we have at least LLM config, we consider the DB authoritative for providers + return res # Fallback to system-level defaults from config.yaml/env return { @@ -283,7 +289,7 @@ } } except Exception as e: - print(f"Error fetching system settings: {e}") + print(f"Error fetching system settings from system_config: {e}") return {} # --- Group Management Methods --- diff --git a/ai-hub/app/db/migrate.py b/ai-hub/app/db/migrate.py index 95799d6..dc3422f 100644 --- a/ai-hub/app/db/migrate.py +++ b/ai-hub/app/db/migrate.py @@ -376,6 +376,44 @@ except Exception as e: logger.warning(f"Could not assign default user to existing agents: {e}") + # --- system_config: global provider settings, decoupled from user rows --- + if not inspector.has_table("system_config"): + logger.info("Creating table 'system_config'...") + try: + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS system_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL DEFAULT '{}', + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_by TEXT REFERENCES users(id) + ) + """)) + conn.commit() + logger.info("Table 'system_config' created.") + + # One-time data migration: copy first admin user's LLM/TTS/STT prefs + # into the new system_config table so nothing is lost on upgrade. 
+ admin_row = conn.execute( + text("SELECT id, preferences FROM users WHERE role='admin' ORDER BY created_at LIMIT 1") + ).fetchone() + if admin_row: + import json as _json + raw = admin_row[1] + try: + prefs = _json.loads(raw) if isinstance(raw, str) else (raw or {}) + except Exception: + prefs = {} + for section in ("llm", "tts", "stt", "statuses"): + if prefs.get(section): + conn.execute( + text("INSERT OR IGNORE INTO system_config (key, value, updated_by) VALUES (:k, :v, :uid)"), + {"k": section, "v": _json.dumps(prefs[section]), "uid": admin_row[0]} + ) + conn.commit() + logger.info(f"Migrated admin preferences to system_config.") + except Exception as e: + logger.error(f"Failed to create/migrate 'system_config': {e}") + logger.info("Database migrations complete.") diff --git a/ai-hub/app/db/models/__init__.py b/ai-hub/app/db/models/__init__.py index 311f0c7..1494a26 100644 --- a/ai-hub/app/db/models/__init__.py +++ b/ai-hub/app/db/models/__init__.py @@ -1,5 +1,5 @@ from app.db.database import Base -from .user import User, Group +from .user import User, Group, SystemConfig from .session import Session, Message from .document import Document, VectorMetadata from .asset import PromptTemplate, Skill, SkillFile, SkillGroupAccess, MCPServer, AssetPermission @@ -7,7 +7,7 @@ from .agent import AgentTemplate, AgentInstance, AgentTrigger __all__ = [ - "User", "Group", + "User", "Group", "SystemConfig", "Session", "Message", "Document", "VectorMetadata", "PromptTemplate", "Skill", "SkillFile", "SkillGroupAccess", "MCPServer", "AssetPermission", diff --git a/ai-hub/app/db/models/user.py b/ai-hub/app/db/models/user.py index 42a3e3e..65e46e9 100644 --- a/ai-hub/app/db/models/user.py +++ b/ai-hub/app/db/models/user.py @@ -38,3 +38,14 @@ def __repr__(self): return f"" + + +class SystemConfig(Base): + """System-wide configuration, decoupled from any user account. 
+ Admins write here; all users inherit these as their baseline.""" + __tablename__ = "system_config" + + key = Column(String, primary_key=True) # "llm" | "tts" | "stt" | "statuses" + value = Column(JSON, default=dict, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + updated_by = Column(String, ForeignKey("users.id"), nullable=True) # audit trail diff --git a/ai-hub/integration_tests/test_system_config_v2.py b/ai-hub/integration_tests/test_system_config_v2.py new file mode 100644 index 0000000..31ae8da --- /dev/null +++ b/ai-hub/integration_tests/test_system_config_v2.py @@ -0,0 +1,114 @@ +import os +import httpx +import pytest +from conftest import BASE_URL + +def _headers(user_id=None): + uid = user_id or os.getenv("SYNC_TEST_USER_ID", "") + return {"X-User-ID": uid} + +def test_system_config_decoupling_flow(): + """ + E2E Flow: + 1. Admin sets global LLM default to 'openai'. + 2. Admin verifies global config. + 3. Normal User fetches config, sees 'openai' as effective default. + 4. Normal User overrides to 'gemini' personally. + 5. Admin fetches global config, it's still 'openai'. + 6. Admin fetches own personal config, it's 'openai' (decoupled). + """ + admin_id = os.getenv("SYNC_TEST_USER_ID") + with httpx.Client(timeout=10.0) as client: + # 0. Clear personal overrides and update group policy to allow 'openai' + r_groups = client.get(f"{BASE_URL}/users/admin/groups", headers=_headers(admin_id)) + groups = r_groups.json() + for g in groups: + if g["name"] == "Integration Default Group": + client.put(f"{BASE_URL}/users/admin/groups/{g['id']}", json={"name": g["name"], "policy": {"llm": ["gemini", "openai"]}}, headers=_headers(admin_id)) + break + + client.put(f"{BASE_URL}/users/me/config", json={"llm": {}, "tts": {}, "stt": {}, "statuses": {}}, headers=_headers(admin_id)) + + # 1. 
Admin sets global LLM default + global_payload = { + "llm": { + "active_provider": "openai", + "providers": { + "openai": {"model": "gpt-4", "api_key": "sk-global-key"} + } + }, + "tts": {}, "stt": {}, "statuses": {} + } + r = client.put(f"{BASE_URL}/admin/config/providers", json=global_payload, headers=_headers(admin_id)) + assert r.status_code == 200, f"Admin failed to update global config: {r.text}" + + # 2. Verify global config + r = client.get(f"{BASE_URL}/admin/config/providers", headers=_headers(admin_id)) + assert r.status_code == 200 + assert r.json()["llm"]["active_provider"] == "openai" + + # 3. Test: Normal user cannot call admin config endpoint + r = client.get(f"{BASE_URL}/admin/config/providers", headers=_headers("some-random-user")) + assert r.status_code in (401, 403, 404), f"Normal user should be blocked from admin config: {r.status_code}" + + # 4. Normal user fetches their own config (which should now merge from global default) + r = client.get(f"{BASE_URL}/users/me/config", headers=_headers(admin_id)) # Still using admin_id as 'me' + assert r.status_code == 200 + data = r.json() + # If the user has no personal overrides, they get the global default + assert data["effective"]["llm"]["active_provider"] == "openai" + + # 5. Admin sets personal override (even admins have personal prefs decoupled from global) + personal_payload = { + "llm": { + "active_provider": "gemini", + "providers": { + "gemini": {"model": "gemini-pro", "api_key": "sk-personal-key"} + } + }, + "tts": {}, "stt": {}, "statuses": {} + } + r = client.put(f"{BASE_URL}/users/me/config", json=personal_payload, headers=_headers(admin_id)) + assert r.status_code == 200 + + # 6. Verify personal override vs global + r = client.get(f"{BASE_URL}/users/me/config", headers=_headers(admin_id)) + data = r.json() + assert data["preferences"]["llm"]["active_provider"] == "gemini" + assert data["effective"]["llm"]["active_provider"] == "gemini" + + # 7. 
Verify global is still 'openai' + r = client.get(f"{BASE_URL}/admin/config/providers", headers=_headers(admin_id)) + assert r.json()["llm"]["active_provider"] == "openai" + + # 8. Cleanup global config to gemini for other tests + global_payload_reset = { + "llm": { + "active_provider": "gemini", + "providers": { + "gemini": {"model": "gemini-3-flash-preview", "api_key": os.getenv("GEMINI_API_KEY", "")} + } + }, + "tts": {}, "stt": {}, "statuses": {} + } + client.put(f"{BASE_URL}/admin/config/providers", json=global_payload_reset, headers=_headers(admin_id)) + +def test_statuses_persistence(): + """Verifies that provider statuses are correctly handled in global config.""" + admin_id = os.getenv("SYNC_TEST_USER_ID") + with httpx.Client(timeout=10.0) as client: + # 1. Update global statuses + full_payload = { + "llm": {}, "tts": {}, "stt": {}, + "statuses": {"llm_openai": "success", "llm_gemini": "error"} + } + r = client.put(f"{BASE_URL}/admin/config/providers", json=full_payload, headers=_headers(admin_id)) + assert r.status_code == 200 + + # 2. Verify global statuses + r = client.get(f"{BASE_URL}/admin/config/providers", headers=_headers(admin_id)) + assert r.json()["statuses"]["llm_openai"] == "success" + + # 3. 
Verify normal user sees these statuses in effective config + r = client.get(f"{BASE_URL}/users/me/config", headers=_headers(admin_id)) + assert r.json()["effective"]["statuses"]["llm_openai"] == "success" diff --git a/frontend/src/features/settings/components/cards/AIConfigurationCard.js b/frontend/src/features/settings/components/cards/AIConfigurationCard.js index 0bff453..b412882 100644 --- a/frontend/src/features/settings/components/cards/AIConfigurationCard.js +++ b/frontend/src/features/settings/components/cards/AIConfigurationCard.js @@ -27,7 +27,10 @@ verifying, handleVerifyProvider, handleDeleteProvider, - handleRenameProvider + handleRenameProvider, + isGlobalMode, + setIsGlobalMode, + userProfile } = context; const renderConfigSection = (sectionKey, title, description) => { @@ -161,8 +164,27 @@

Global AI providers, model endpoints, and synthesis engines

-
- +
+
+ {userProfile?.role === 'admin' && ( +
e.stopPropagation()}> + + +
+ )} +
+ +
{!collapsedSections.ai && ( @@ -207,9 +229,9 @@ diff --git a/frontend/src/features/settings/pages/SettingsPage.js b/frontend/src/features/settings/pages/SettingsPage.js index a7285dd..c0ab635 100644 --- a/frontend/src/features/settings/pages/SettingsPage.js +++ b/frontend/src/features/settings/pages/SettingsPage.js @@ -17,12 +17,13 @@ updateAdminAppConfig, testAdminOIDCConfig, testAdminSwarmConfig, - getAllProviders, updateUserConfig, getProviderModels, verifyProvider, exportUserConfig, - importUserConfig + importUserConfig, + getAdminProviderConfig, + updateAdminProviderConfig } from '../../../services/apiService'; import SettingsPageContent from '../components/SettingsPageContent'; @@ -56,6 +57,7 @@ const [fetchedModels, setFetchedModels] = useState({}); const [providerStatuses, setProviderStatuses] = useState({}); const [confirmAction, setConfirmAction] = useState(null); // { type, id, sectionKey, label } + const [isGlobalMode, setIsGlobalMode] = useState(false); const fileInputRef = React.useRef(null); useEffect(() => { @@ -313,46 +315,59 @@ if (e) e.preventDefault(); try { setSaving(true); - setMessage({ type: '', text: 'Saving and verifying configuration...' }); + setMessage({ type: '', text: isGlobalMode ? 'Saving global system defaults...' : 'Saving and verifying personal overrides...' }); const updatedStatuses = { ...providerStatuses }; const sections = ['llm', 'tts', 'stt']; - for (const section of sections) { - const activeId = config[section]?.active_provider; - if (activeId && !updatedStatuses[`${section}_${activeId}`]) { - const providerPrefs = config[section]?.providers?.[activeId]; - if (providerPrefs && providerPrefs.api_key) { - try { - const res = await verifyProvider(section, { - provider_name: activeId, - provider_type: providerPrefs.provider_type || activeId.split('_')[0], - api_key: providerPrefs.api_key, - model: providerPrefs.model, - voice: providerPrefs.voice - }); - updatedStatuses[`${section}_${activeId}`] = res.success ? 
'success' : 'error'; - } catch (err) { - updatedStatuses[`${section}_${activeId}`] = 'error'; + // Verification only for personal mode for now, or just skip if no keys + if (!isGlobalMode) { + for (const section of sections) { + const activeId = config[section]?.active_provider; + if (activeId && !updatedStatuses[`${section}_${activeId}`]) { + const providerPrefs = config[section]?.providers?.[activeId]; + if (providerPrefs && providerPrefs.api_key) { + try { + const res = await verifyProvider(section, { + provider_name: activeId, + provider_type: providerPrefs.provider_type || activeId.split('_')[0], + api_key: providerPrefs.api_key, + model: providerPrefs.model, + voice: providerPrefs.voice + }); + updatedStatuses[`${section}_${activeId}`] = res.success ? 'success' : 'error'; + } catch (err) { + updatedStatuses[`${section}_${activeId}`] = 'error'; + } } } } + setProviderStatuses(updatedStatuses); } - setProviderStatuses(updatedStatuses); - const payload = { ...config, statuses: updatedStatuses }; - const data = await updateUserConfig(payload); + if (isGlobalMode) { + await updateAdminProviderConfig({ + llm: config.llm, + tts: config.tts, + stt: config.stt, + statuses: updatedStatuses + }); + setMessage({ type: 'success', text: 'Global system configuration updated successfully!' }); + } else { + const payload = { ...config, statuses: updatedStatuses }; + const data = await updateUserConfig(payload); + setConfig({ + llm: data.llm || {}, + tts: data.tts || {}, + stt: data.stt || {} + }); + setMessage({ type: 'success', text: 'Personal overrides saved successfully!' }); + } - setConfig({ - llm: data.llm || {}, - tts: data.tts || {}, - stt: data.stt || {} - }); - setMessage({ type: 'success', text: 'Settings saved and verified successfully!' }); setTimeout(() => setMessage({ type: '', text: '' }), 3000); } catch (err) { console.error("Error saving config:", err); - setMessage({ type: 'error', text: 'Failed to save configuration.' 
}); + setMessage({ type: 'error', text: `Failed to save configuration: ${err.message}` }); } finally { setSaving(false); } @@ -547,14 +562,26 @@ const loadConfig = async () => { try { setLoading(true); - const data = await getUserConfig(); - setConfig({ - llm: data.preferences?.llm || {}, - tts: data.preferences?.tts || {}, - stt: data.preferences?.stt || {} - }); - if (data.preferences?.statuses) { - setProviderStatuses(data.preferences.statuses); + if (isGlobalMode) { + const data = await getAdminProviderConfig(); + setConfig({ + llm: data.llm || {}, + tts: data.tts || {}, + stt: data.stt || {} + }); + if (data.statuses) { + setProviderStatuses(data.statuses); + } + } else { + const data = await getUserConfig(); + setConfig({ + llm: data.preferences?.llm || {}, + tts: data.preferences?.tts || {}, + stt: data.preferences?.stt || {} + }); + if (data.preferences?.statuses) { + setProviderStatuses(data.preferences.statuses); + } } setMessage({ type: '', text: '' }); } catch (err) { @@ -565,6 +592,10 @@ } }; + useEffect(() => { + loadConfig(); + }, [isGlobalMode]); + if (loading) { return (
@@ -639,6 +670,8 @@ userProfile, handleSaveAdminConfig, handleSaveConfig, + isGlobalMode, + setIsGlobalMode, setConfig, setProviderLists, handleRoleToggle, diff --git a/frontend/src/services/api/adminService.js b/frontend/src/services/api/adminService.js index eedefc5..b225a60 100644 --- a/frontend/src/services/api/adminService.js +++ b/frontend/src/services/api/adminService.js @@ -194,3 +194,20 @@ throw new Error(`Failed to download node bundle: ${e.message}`); } }; + +/** + * [ADMIN ONLY] Fetches global AI provider configuration. + */ +export const getAdminProviderConfig = async () => { + return await fetchWithAuth('/admin/config/providers'); +}; + +/** + * [ADMIN ONLY] Updates global AI provider configuration. + */ +export const updateAdminProviderConfig = async (config) => { + return await fetchWithAuth('/admin/config/providers', { + method: "PUT", + body: config + }); +}; diff --git a/scripts/cortex_mcp_auth.py b/scripts/cortex_mcp_auth.py new file mode 100755 index 0000000..91b627c --- /dev/null +++ b/scripts/cortex_mcp_auth.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +""" +cortex_mcp_auth.py — gcloud-style browser login for Cortex Hub MCP config generation. + +Usage: + python cortex_mcp_auth.py # uses https://ai.jerxie.com + python cortex_mcp_auth.py --hub https://my-hub.com + python cortex_mcp_auth.py --hub http://localhost:8000 --gemini-project antifravity + +What it does: + 1. Checks the hub's auth config (OIDC vs local password) + 2. OIDC path: opens a browser login URL, waits for the hub to redirect back + 3. Local path: prompts for email + password, POSTs to /users/login/local + 4. Writes ~/.gemini//mcp_config.json (Gemini CLI) + 5. 
Writes ~/.claude/mcp.json (Claude Code) — skip with --no-claude +""" + +import argparse +import getpass +import json +import os +import sys +import socket +import threading +import time +import urllib.parse +import urllib.request +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path + +# ── ANSI colours (auto-disable when stdout is not a tty) ───────────────────── +_USE_COLOR = sys.stdout.isatty() +def _c(code, text): return f"\033[{code}m{text}\033[0m" if _USE_COLOR else text +def ok(msg): print(_c("32", "✓ ") + msg) +def err(msg): print(_c("31", "✗ ") + msg, file=sys.stderr) +def info(msg): print(_c("36", " ") + msg) +def bold(msg): return _c("1", msg) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _hub_get(hub_url: str, path: str, params: dict = None) -> dict: + url = f"{hub_url.rstrip('/')}{path}" + if params: + url += "?" 
+ urllib.parse.urlencode(params) + with urllib.request.urlopen(url, timeout=10) as r: + return json.loads(r.read()) + + +def _hub_post(hub_url: str, path: str, payload: dict, token: str = None) -> dict: + data = json.dumps(payload).encode() + req = urllib.request.Request( + f"{hub_url.rstrip('/')}{path}", + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + if token: + req.add_header("Authorization", f"Bearer {token}") + with urllib.request.urlopen(req, timeout=10) as r: + return json.loads(r.read()) + + +# ── One-shot local HTTP server that catches the OAuth redirect ──────────────── + +class _CallbackResult: + def __init__(self): + self.token = None + self.email = None + self.error = None + self._event = threading.Event() + + def set(self, token, email): + self.token, self.email = token, email + self._event.set() + + def set_error(self, msg): + self.error = msg + self._event.set() + + def wait(self, timeout=300): + return self._event.wait(timeout) + + +def _make_handler(result: _CallbackResult): + class Handler(BaseHTTPRequestHandler): + def do_GET(self): + qs = urllib.parse.parse_qs(urllib.parse.urlparse(self.path).query) + token = (qs.get("token") or [None])[0] + email = (qs.get("email") or [""])[0] + error = (qs.get("error") or [None])[0] + + if error: + result.set_error(error) + body = b"

Authentication failed. You can close this tab.

" + elif token: + result.set(token, email) + body = b"

Authenticated! You can close this tab and return to the terminal.

" + else: + result.set_error("no_token") + body = b"

No token received. Please try again.

" + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, *_): + pass # silence request logs + + return Handler + + +def _oidc_flow(hub_url: str) -> tuple[str, str]: + """Opens browser, waits for token via local callback. Returns (token, email).""" + port = _free_port() + result = _CallbackResult() + server = HTTPServer(("127.0.0.1", port), _make_handler(result)) + + t = threading.Thread(target=server.serve_forever, daemon=True) + t.start() + + try: + init = _hub_get(hub_url, "/api/v1/users/auth/cli-init", {"port": port}) + except Exception as e: + server.shutdown() + raise RuntimeError(f"Failed to reach hub at {hub_url}: {e}") + + auth_url = init.get("auth_url") + if not auth_url: + server.shutdown() + raise RuntimeError("Hub returned no auth URL. Is OIDC configured?") + + print() + print(bold("Opening browser for login…")) + info(f"If the browser does not open, visit:\n {auth_url}") + print() + webbrowser.open(auth_url) + + if not result.wait(timeout=300): + server.shutdown() + raise TimeoutError("Authentication timed out (5 min). Please try again.") + + server.shutdown() + + if result.error: + raise RuntimeError(f"Authentication failed: {result.error}") + + return result.token, result.email + + +def _local_auth_flow(hub_url: str) -> tuple[str, str]: + """Prompts for email + password. 
Returns (token, email).""" + print() + info("This hub uses local password authentication.") + email = input(" Email: ").strip() + password = getpass.getpass(" Password: ") + + try: + resp = _hub_post(hub_url, "/api/v1/users/login/local", {"email": email, "password": password}) + except urllib.error.HTTPError as e: + body = e.read().decode(errors="replace") + detail = json.loads(body).get("detail", body) if body.startswith("{") else body + raise RuntimeError(f"Login failed: {detail}") + + return resp["token"], resp["email"] + + +# ── Config writers ───────────────────────────────────────────────────────────── + +def _sse_url(hub_url: str, token: str) -> str: + base = hub_url.rstrip("/") + return f"{base}/api/v1/mcp/sse?token={urllib.parse.quote(token, safe='')}" + + +def _write_gemini_config(hub_url: str, token: str, project: str, output: str = None): + if output: + config_path = Path(output).expanduser() + else: + config_path = Path.home() / ".gemini" / project / "mcp_config.json" + + config_path.parent.mkdir(parents=True, exist_ok=True) + + existing = {} + if config_path.exists(): + try: + existing = json.loads(config_path.read_text()) + except Exception: + pass + + servers = existing.get("mcpServers", {}) + servers["cortex-hub"] = { + "serverURL": _sse_url(hub_url, token), + } + existing["mcpServers"] = servers + + config_path.write_text(json.dumps(existing, indent=2)) + ok(f"Gemini MCP config written → {config_path}") + + +def _write_claude_config(hub_url: str, token: str): + import shutil, subprocess + url = _sse_url(hub_url, token) + + if shutil.which("claude"): + # Use claude CLI so it lands in the right project/user slot + r = subprocess.run( + ["claude", "mcp", "add", "--transport", "sse", "--scope", "user", "cortex-hub", url], + capture_output=True, text=True + ) + if r.returncode == 0: + ok("Claude Code MCP server registered (user scope) via `claude mcp add`") + return + # Scope flag may not exist in older versions — fall back to default scope + r = 
subprocess.run( + ["claude", "mcp", "add", "--transport", "sse", "cortex-hub", url], + capture_output=True, text=True + ) + if r.returncode == 0: + ok("Claude Code MCP server registered (project scope) via `claude mcp add`") + return + + # Fallback: write directly into ~/.claude.json at the top-level mcpServers key + config_path = Path.home() / ".claude.json" + existing = {} + if config_path.exists(): + try: + existing = json.loads(config_path.read_text()) + except Exception: + pass + existing.setdefault("mcpServers", {})["cortex-hub"] = {"type": "sse", "serverURL": url} + config_path.write_text(json.dumps(existing, indent=2)) + ok(f"Claude Code MCP config written → {config_path}") + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser( + description="Authenticate with Cortex Hub and generate MCP config files." + ) + parser.add_argument( + "--hub", + default="https://ai.jerxie.com", + metavar="URL", + help="Cortex Hub base URL (default: https://ai.jerxie.com)", + ) + parser.add_argument( + "--gemini-project", + default="antigravity", + metavar="NAME", + help="Gemini CLI project name (default: antigravity)", + ) + parser.add_argument( + "--output", + default=None, + metavar="PATH", + help="Override output path for Gemini mcp_config.json", + ) + parser.add_argument( + "--no-claude", + action="store_true", + help="Skip writing ~/.claude/mcp.json", + ) + args = parser.parse_args() + + hub_url = args.hub.rstrip("/") + print() + print(bold(f"Cortex Hub MCP Auth")) + info(f"Hub: {hub_url}") + print() + + # 1. Determine auth method + try: + auth_config = _hub_get(hub_url, "/api/v1/users/config") + except Exception as e: + err(f"Cannot reach hub: {e}") + sys.exit(1) + + oidc_enabled = auth_config.get("oidc_configured", False) + + # 2. 
Authenticate + try: + if oidc_enabled: + token, email = _oidc_flow(hub_url) + else: + token, email = _local_auth_flow(hub_url) + except (RuntimeError, TimeoutError) as e: + err(str(e)) + sys.exit(1) + + print() + ok(f"Authenticated as {bold(email)}") + + # 3. Write configs + print() + _write_gemini_config(hub_url, token, args.gemini_project, args.output) + if not args.no_claude: + _write_claude_config(hub_url, token) + + print() + info("Token is valid for 24 hours. Re-run this script to refresh.") + print() + + +if __name__ == "__main__": + main()