diff --git a/ai-hub/app/api/routes/tts.py b/ai-hub/app/api/routes/tts.py index a6c9c4f..94334a8 100644 --- a/ai-hub/app/api/routes/tts.py +++ b/ai-hub/app/api/routes/tts.py @@ -104,20 +104,45 @@ @router.get( "/voices", - summary="List available Google Cloud TTS voices", + summary="List available TTS voices", response_description="A list of voice names" ) async def list_voices( + provider: str = Query(None, description="Optional provider name"), api_key: str = Query(None, description="Optional API key override"), db: Session = Depends(get_db), user_id: str = Depends(get_current_user_id) ): from app.config import settings import httpx - key_to_use = api_key or settings.TTS_API_KEY or settings.GEMINI_API_KEY + from app.core.providers.tts.gemini import GeminiTTSProvider + from app.core.providers.tts.gcloud_tts import GCloudTTSProvider + + # Resolve masked key if needed + key_to_use = api_key + if key_to_use and "***" in key_to_use and user_id: + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if user and user.preferences: + # Look for the key in any TTS provider since we don't necessarily know which one yet + for p_name, p_data in user.preferences.get("tts", {}).get("providers", {}).items(): + if p_data.get("api_key") and "***" not in p_data["api_key"]: + # If a provider was passed, only use its key + if not provider or provider == p_name: + key_to_use = p_data["api_key"] + break + + # Fallback to defaults + if not key_to_use or "***" in key_to_use: + key_to_use = settings.TTS_API_KEY or settings.GEMINI_API_KEY + + # If it's Gemini, or the key starts with AIza (common AI Studio key) + if provider == "google_gemini" or (not provider and key_to_use and key_to_use.startswith("AIza")): + return sorted(GeminiTTSProvider.AVAILABLE_VOICES) + + # Default or explicit GCloud if not key_to_use: return [] - + url = f"https://texttospeech.googleapis.com/v1/voices?key={key_to_use}" try: async with httpx.AsyncClient(timeout=10) as client: @@ -125,13 +150,21 @@ if res.status_code == 200: data = res.json() voices = data.get('voices', []) - # Filter for english or interesting ones names = [v['name'] for v in voices] return sorted(names) + + # If Google Cloud TTS fails, maybe it's actually an AI Studio key being used for Gemini? + # Fallback to Gemini voices if it seems likely + if key_to_use.startswith("AIza"): + return sorted(GeminiTTSProvider.AVAILABLE_VOICES) + return [] except Exception as e: import logging logging.getLogger(__name__).error(f"Failed to fetch voices: {e}") - return [] + # Final fallback to standard list if everything else fails but we have a key + if key_to_use and key_to_use.startswith("AIza"): + return sorted(GeminiTTSProvider.AVAILABLE_VOICES) + return sorted(GCloudTTSProvider.AVAILABLE_VOICES_EN) return router \ No newline at end of file diff --git a/ai-hub/app/api/routes/user.py b/ai-hub/app/api/routes/user.py index bb9d513..a4faaa7 100644 --- a/ai-hub/app/api/routes/user.py +++ b/ai-hub/app/api/routes/user.py @@ -18,11 +18,12 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Minimum OIDC configuration from environment variables -OIDC_CLIENT_ID = os.getenv("OIDC_CLIENT_ID", "") -OIDC_CLIENT_SECRET = os.getenv("OIDC_CLIENT_SECRET", "") -OIDC_SERVER_URL = os.getenv("OIDC_SERVER_URL", "") -OIDC_REDIRECT_URI = os.getenv("OIDC_REDIRECT_URI", "") +# Minimum OIDC configuration from settings +from app.config import settings +OIDC_CLIENT_ID = settings.OIDC_CLIENT_ID +OIDC_CLIENT_SECRET = settings.OIDC_CLIENT_SECRET +OIDC_SERVER_URL = settings.OIDC_SERVER_URL +OIDC_REDIRECT_URI = settings.OIDC_REDIRECT_URI # --- Derived OIDC Configuration --- OIDC_AUTHORIZATION_URL = f"{OIDC_SERVER_URL}/auth" @@ -144,7 +145,7 @@ except httpx.RequestError as e: logger.error(f"OIDC Token exchange request error: {e}") raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}") - except jwt.JWTDecodeError as e: + except jwt.DecodeError as e: logger.error(f"ID token decode error: {e}") raise HTTPException(status_code=400, detail="Failed to decode ID token from OIDC provider.") except Exception as e: @@ -177,6 +178,49 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"An error occurred: {e}") + @router.get("/me/profile", response_model=schemas.UserProfile, summary="Get Current User Profile") + async def get_user_profile( + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Retrieves profile information for the current user.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + response = schemas.UserProfile.model_validate(user) + if user.group: + response.group_name = user.group.name + return response + + @router.put("/me/profile", response_model=schemas.UserProfile, summary="Update User Profile") + async def update_user_profile( + profile_data: schemas.UserProfileUpdate, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Updates profile details for the current user.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + if profile_data.username: user.username = profile_data.username + if profile_data.full_name: user.full_name = profile_data.full_name + if profile_data.avatar_url: user.avatar_url = profile_data.avatar_url + + db.add(user) + db.commit() + db.refresh(user) + + response = schemas.UserProfile.model_validate(user) + if user.group: + response.group_name = user.group.name + return response + @router.get("/me/config", response_model=schemas.ConfigResponse, summary="Get Current User Preferences") async def get_user_config( db: Session = Depends(get_db), @@ -201,68 +245,74 @@ llm_prefs = prefs_dict.get("llm", {}) tts_prefs = prefs_dict.get("tts", {}) stt_prefs = prefs_dict.get("stt", {}) + + # Load system defaults from DB if needed + system_prefs = services.user_service.get_system_settings(db) user_providers = llm_prefs.get("providers", {}) - if not user_providers: - # Day zero: fall back to evaluating defaults from env/yaml - llm_providers_effective = { - "deepseek": { - "api_key": mask_key(settings.DEEPSEEK_API_KEY), - "model": settings.DEEPSEEK_MODEL_NAME - }, - "gemini": { - "api_key": mask_key(settings.GEMINI_API_KEY), - "model": settings.GEMINI_MODEL_NAME - }, - "openai": { - "api_key": mask_key(settings.OPENAI_API_KEY), - "model": "gpt-4" + # Try to get from system admin in DB first + system_llm = system_prefs.get("llm", {}).get("providers", {}) + if system_llm: + user_providers = system_llm + else: + # Fallback to hardcoded settings defaults (if any left in yaml) + user_providers = { + "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME}, + "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME}, } - } - else: - # User has configured providers, only return what they have (plus masking) - llm_providers_effective = {} - for p, prefs in user_providers.items(): - llm_providers_effective[p] = { - "api_key": mask_key(prefs.get("api_key")), - "model": prefs.get("model") - } + + llm_providers_effective = {} + for p, p_p in user_providers.items(): + if p_p: + llm_providers_effective[p] = { + "api_key": mask_key(p_p.get("api_key")), + "model": p_p.get("model") + } user_tts_providers = tts_prefs.get("providers", {}) if not user_tts_providers: - tts_providers_effective = { - settings.TTS_PROVIDER: { - "api_key": mask_key(settings.TTS_API_KEY), - "model": settings.TTS_MODEL_NAME, - "voice": settings.TTS_VOICE_NAME + system_tts = system_prefs.get("tts", {}).get("providers", {}) + if system_tts: + user_tts_providers = system_tts + else: + user_tts_providers = { + settings.TTS_PROVIDER: { + "api_key": settings.TTS_API_KEY, + "model": settings.TTS_MODEL_NAME, + "voice": settings.TTS_VOICE_NAME + } } - } - else: - tts_providers_effective = {} - for p, prefs in user_tts_providers.items(): + + tts_providers_effective = {} + for p, p_p in user_tts_providers.items(): + if p_p: tts_providers_effective[p] = { - "api_key": mask_key(prefs.get("api_key")), - "model": prefs.get("model"), - "voice": prefs.get("voice") + "api_key": mask_key(p_p.get("api_key")), + "model": p_p.get("model"), + "voice": p_p.get("voice") } user_stt_providers = stt_prefs.get("providers", {}) if not user_stt_providers: - stt_providers_effective = { - settings.STT_PROVIDER: { - "api_key": mask_key(settings.STT_API_KEY), - "model": settings.STT_MODEL_NAME + system_stt = system_prefs.get("stt", {}).get("providers", {}) + if system_stt: + user_stt_providers = system_stt + else: + user_stt_providers = { + settings.STT_PROVIDER: { + "api_key": settings.STT_API_KEY, + "model": settings.STT_MODEL_NAME + } } - } - else: - stt_providers_effective = {} - for p, prefs in user_stt_providers.items(): + + stt_providers_effective = {} + for p, p_p in user_stt_providers.items(): + if p_p: stt_providers_effective[p] = { - "api_key": mask_key(prefs.get("api_key")), - "model": prefs.get("model") + "api_key": mask_key(p_p.get("api_key")), + "model": p_p.get("model") } - effective = { "llm": { "active_provider": llm_prefs.get("active_provider") or (next(iter(llm_providers_effective), None)) or "deepseek", @@ -277,8 +327,48 @@ "providers": stt_providers_effective } } + + # --- Group Policy Enforcement --- + # Only enforce for non-admin users. Admins should see all configured providers. + # If user has no group, they fall under the 'ungrouped' default group policy. + group = user.group or services.user_service.get_or_create_default_group(db) + if group and user.role != "admin": + policy = group.policy or {} + def apply_policy(section_key, policy_key, prefs_dict): + # A policy is a list of allowed provider IDs. Empty list means NO access to that section's providers. + allowed = policy.get(policy_key, []) + if not allowed: + # Explicit empty list or missing key results in NO providers + effective[section_key]["providers"] = {} + if prefs_dict and "providers" in prefs_dict: + prefs_dict["providers"] = {} + effective[section_key]["active_provider"] = "" + return prefs_dict + + # Filter the effective providers map + providers = effective[section_key]["providers"] + filtered_eff = {k: v for k, v in providers.items() if k in allowed} + effective[section_key]["providers"] = filtered_eff + + # Filter the user preferences too to avoid showing forbidden items + if prefs_dict and "providers" in prefs_dict: + prefs_dict["providers"] = { + k: v for k, v in prefs_dict["providers"].items() if k in allowed + } + + # Ensure active provider is still valid under policy + if effective[section_key].get("active_provider") not in allowed: + effective[section_key]["active_provider"] = next(iter(filtered_eff), None) or "" + + return prefs_dict + + llm_prefs = apply_policy("llm", "llm", llm_prefs) + tts_prefs = apply_policy("tts", "tts", tts_prefs) + stt_prefs = apply_policy("stt", "stt", stt_prefs) + # Ensure we mask the preferences dict we send back to the user def mask_section_prefs(section_dict): + if not section_dict: return {} import copy masked_dict = copy.deepcopy(section_dict) providers = masked_dict.get("providers", {}) @@ -291,7 +381,8 @@ preferences=schemas.UserPreferences( llm=mask_section_prefs(llm_prefs), tts=mask_section_prefs(tts_prefs), - stt=mask_section_prefs(stt_prefs) + stt=mask_section_prefs(stt_prefs), + statuses=user.preferences.get("statuses", {}) ), effective=effective ) @@ -321,54 +412,115 @@ if p_name in old_section: p_data["api_key"] = old_section[p_name].get("api_key") + def resolve_clone_from(section_name, new_section): + """ + If a new provider instance was created with _clone_from=, + copy the real API key from the source provider stored in the DB. + The _clone_from marker is then removed so it is not persisted. + """ + if not new_section or "providers" not in new_section: + return + # Look in DB-stored prefs first, then fall back to system settings + old_section = old_prefs.get(section_name, {}).get("providers", {}) + system_prefs = services.user_service.get_system_settings(db) + system_section = system_prefs.get(section_name, {}).get("providers", {}) + + for p_name, p_data in new_section["providers"].items(): + clone_source = p_data.pop("_clone_from", None) + if not clone_source: + continue + # Resolve real key: DB prefs > system settings + real_key = ( + old_section.get(clone_source, {}).get("api_key") + or system_section.get(clone_source, {}).get("api_key") + ) + if real_key and "***" not in str(real_key): + p_data["api_key"] = real_key + logger.info( + f"Resolved _clone_from: {p_name} inherited api_key from {clone_source} [{section_name}]" + ) + else: + logger.warning( + f"Could not resolve _clone_from for {p_name}: source '{clone_source}' key not found or masked." + ) + if prefs.llm: preserve_masked_keys("llm", prefs.llm) if prefs.tts: preserve_masked_keys("tts", prefs.tts) if prefs.stt: preserve_masked_keys("stt", prefs.stt) - user.preferences = {"llm": prefs.llm, "tts": prefs.tts, "stt": prefs.stt} - from sqlalchemy.orm.attributes import flag_modified - flag_modified(user, "preferences") - - # --- Day 2 Sync: Update Global Settings and YAML --- - # Update our global settings object to match the user's new preferences - # and then persist them to config.yaml if this is the "primary" user or admin. - from app.config import settings as global_settings - - # Sync LLM - if prefs.llm and prefs.llm.get("providers"): - global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {})) - - # Sync TTS - if prefs.tts and prefs.tts.get("active_provider"): - p_name = prefs.tts["active_provider"] - p_data = prefs.tts.get("providers", {}).get(p_name, {}) - if p_data: - global_settings.TTS_PROVIDER = p_name - global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME - global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME - global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY - - # Sync STT - if prefs.stt and prefs.stt.get("active_provider"): - p_name = prefs.stt["active_provider"] - p_data = prefs.stt.get("providers", {}).get(p_name, {}) - if p_data: - global_settings.STT_PROVIDER = p_name - global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME - global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY + # resolve_clone_from must run AFTER preserve_masked_keys so the source + # provider's key is already unmasked in the new_section dict when needed. + if prefs.llm: resolve_clone_from("llm", prefs.llm) + if prefs.tts: resolve_clone_from("tts", prefs.tts) + if prefs.stt: resolve_clone_from("stt", prefs.stt) - # Write to config.yaml - try: - global_settings.save_to_yaml() - except Exception as ey: - logger.error(f"Failed to sync settings to YAML: {ey}") + user.preferences = { + "llm": prefs.llm, + "tts": prefs.tts, + "stt": prefs.stt, + "statuses": prefs.statuses or {} + } - logger.info(f"Saving updated preferences for user {user_id}: {list(user.preferences.keys())}") + # --- Enterprise RBAC Sync --- + # ONLY admins can sync to Global Settings and persist to config.yaml + if user.role == "admin": + from sqlalchemy.orm.attributes import flag_modified + flag_modified(user, "preferences") + + from app.config import settings as global_settings + + # Sync LLM + if prefs.llm and prefs.llm.get("providers"): + global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {})) + + # Sync TTS + if prefs.tts and prefs.tts.get("active_provider"): + p_name = prefs.tts["active_provider"] + p_data = prefs.tts.get("providers", {}).get(p_name, {}) + if p_data: + global_settings.TTS_PROVIDER = p_name + global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME + global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME + global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY + + # Sync STT + if prefs.stt and prefs.stt.get("active_provider"): + p_name = prefs.stt["active_provider"] + p_data = prefs.stt.get("providers", {}).get(p_name, {}) + if p_data: + global_settings.STT_PROVIDER = p_name + global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME + global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY + + # Write to config.yaml + try: + global_settings.save_to_yaml() + except Exception as ey: + logger.error(f"Failed to sync settings to YAML: {ey}") + + logger.info(f"Saving updated global preferences via admin {user_id}") + else: + # Normal users: only allow modifying their personal active_provider selection + # but preserve OLD keys if they somehow try to send new ones (UI should prevent this anyway) + # Actually, let's just ignore their "providers" map entirely if we want strict admin control + user.preferences["llm"]["active_provider"] = prefs.llm.get("active_provider") + user.preferences["tts"]["active_provider"] = prefs.tts.get("active_provider") + user.preferences["stt"]["active_provider"] = prefs.stt.get("active_provider") + user.preferences["statuses"] = prefs.statuses or {} + from sqlalchemy.orm.attributes import flag_modified + flag_modified(user, "preferences") + logger.info(f"Saving personal preferences for user {user_id}") + db.add(user) db.commit() db.refresh(user) - return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {})) + return schemas.UserPreferences( + llm=user.preferences.get("llm", {}), + tts=user.preferences.get("tts", {}), + stt=user.preferences.get("stt", {}), + statuses=user.preferences.get("statuses", {}) + ) @router.get("/me/config/models", response_model=list[schemas.ModelInfoResponse], summary="Get Models for Provider") async def get_provider_models(provider_name: str, section: str = "llm"): @@ -412,14 +564,45 @@ return results @router.get("/me/config/providers", response_model=list[str], summary="Get All Valid Providers per Section") - async def get_all_providers(section: str = "llm"): + async def get_all_providers( + section: str = "llm", + configured_only: bool = Query(False, description="If true, only returns providers currently configured in preferences or system defaults"), + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): import litellm from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers + if configured_only: + # Fetch effective config to see what's actually configured + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + # We don't want to recursivly call our own logic, but we can look at system settings + user prefs + system_prefs = services.user_service.get_system_settings(db) + user_prefs = user.preferences if user else {} + + configured = set() + # Add from system defaults + for p in system_prefs.get(section, {}).get("providers", {}).keys(): + configured.add(p) + # Add from user overrides + for p in user_prefs.get(section, {}).get("providers", {}).keys(): + configured.add(p) + + # If nothing configured, fallback to hardcoded defaults in settings (simulating get_user_config logic) + if not configured: + from app.config import settings + if section == "llm": + configured.update(["deepseek", "gemini"]) + elif section == "tts": + configured.add(settings.TTS_PROVIDER) + elif section == "stt": + configured.add(settings.STT_PROVIDER) + + return sorted(list(configured)) + if section == "llm": return ["general"] + [p.value for p in litellm.LlmProviders] elif section == "tts": - # For TTS, combine registry + possibly litellm providers that support audio_speech return ["general"] + get_registered_tts_providers() + ["openai"] elif section == "stt": return ["general"] + get_registered_stt_providers() + ["openai"] @@ -436,6 +619,9 @@ from app.core.providers.factory import get_llm_provider if not user_id: raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") actual_key = req.api_key try: llm_prefs = {} @@ -443,16 +629,30 @@ if user and user.preferences: llm_prefs = user.preferences.get("llm", {}).get("providers", {}).get(req.provider_name, {}) + # Handle masked keys by backfilling from stored prefs if needed + if actual_key and "***" in actual_key: + actual_key = llm_prefs.get("api_key") + if not actual_key: + # Fallback to system defaults if admin + system_prefs = services.user_service.get_system_settings(db) + actual_key = system_prefs.get("llm", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") + kwargs = {k: v for k, v in llm_prefs.items() if k not in ["api_key", "model"]} + if req.provider_type: + kwargs["provider_type"] = req.provider_type + llm = get_llm_provider( provider_name=req.provider_name, model_name=req.model or "", api_key_override=actual_key, **kwargs ) + # LiteLLM check: litellm models are callable res = llm("Hello") return schemas.VerifyProviderResponse(success=True, message="Connection successful!") except Exception as e: + import logging + logging.getLogger(__name__).error(f"LLM Verification failed for {req.provider_name} ({req.provider_type}): {e}") return schemas.VerifyProviderResponse(success=False, message=str(e)) @router.post("/me/config/verify_tts", response_model=schemas.VerifyProviderResponse) @@ -463,42 +663,43 @@ ): from app.core.providers.factory import get_tts_provider from app.config import settings + import logging + logger = logging.getLogger(__name__) + if not user_id: raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + actual_key = req.api_key try: - tts_prefs = {} - user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if user and user.preferences: - tts_prefs = user.preferences.get("tts", {}).get("providers", {}).get(req.provider_name, {}) + tts_prefs = user.preferences.get("tts", {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {} + + # Key resolution: Masked keys should be replaced with real ones from DB or system config + if not actual_key or "***" in str(actual_key): + actual_key = tts_prefs.get("api_key") + if not actual_key or "***" in str(actual_key): + # Try system settings + system_prefs = services.user_service.get_system_settings(db) + actual_key = system_prefs.get("tts", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") + # Final fallback to settings.py constants + if not actual_key: actual_key = settings.TTS_API_KEY or settings.GEMINI_API_KEY - # Resolve the real key: masked/absent → DB key → system TTS key → system Gemini key - def is_masked(k): - return not k or k in ("None", "none", "") or "*" in str(k) - - if is_masked(actual_key): - db_key = tts_prefs.get("api_key") - if not is_masked(db_key): - actual_key = db_key - logger.debug(f"verify_tts: using DB key for {req.provider_name}") - elif not is_masked(settings.TTS_API_KEY): - actual_key = settings.TTS_API_KEY - logger.debug(f"verify_tts: using system TTS_API_KEY for {req.provider_name}") - elif not is_masked(settings.GEMINI_API_KEY): - actual_key = settings.GEMINI_API_KEY - logger.debug(f"verify_tts: using system GEMINI_API_KEY for {req.provider_name}") - - logger.info(f"verify_tts: provider={req.provider_name}, model={req.model}, key_prefix={str(actual_key)[:8] if actual_key else 'NONE'}") + logger.info(f"verify_tts: instance={req.provider_name}, type={req.provider_type}, model={req.model}") kwargs = {k: v for k, v in tts_prefs.items() if k not in ["api_key", "model", "voice"]} - tts = get_tts_provider( + if req.provider_type: + kwargs["provider_type"] = req.provider_type + + provider = get_tts_provider( provider_name=req.provider_name, api_key=actual_key, model_name=req.model or "", voice_name=req.voice or "", **kwargs ) - await tts.generate_speech("Testing.") + await provider.generate_speech("Test") return schemas.VerifyProviderResponse(success=True, message="Connection successful!") except Exception as e: logger.error(f"TTS verification failed for {req.provider_name}: {e}") @@ -511,38 +712,41 @@ user_id: str = Depends(get_current_user_id) ): from app.core.providers.factory import get_stt_provider + from app.config import settings + import logging + logger = logging.getLogger(__name__) + if not user_id: raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + actual_key = req.api_key - from app.config import settings - def is_masked(k): - return not k or k in ("None", "none", "") or "*" in str(k) - - if is_masked(actual_key): - user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if user and user.preferences: - prefs = user.preferences.get("stt", {}).get("providers", {}).get(req.provider_name, {}) - db_key = prefs.get("api_key") - if not is_masked(db_key): - actual_key = db_key - if is_masked(actual_key) and not is_masked(settings.STT_API_KEY): - actual_key = settings.STT_API_KEY - if is_masked(actual_key) and not is_masked(settings.GEMINI_API_KEY): - actual_key = settings.GEMINI_API_KEY try: - stt = get_stt_provider(provider_name=req.provider_name, model_name=req.model or "", api_key=actual_key) - import io - import wave - wav_io = io.BytesIO() - with wave.open(wav_io, 'wb') as wav_file: - wav_file.setnchannels(1) - wav_file.setsampwidth(2) - wav_file.setframerate(16000) - wav_file.writeframes(b'\x00\x00' * 16000) # 1 second of silence - res = await stt.transcribe_audio(wav_io.getvalue()) - # Empty transcript is expected for silent audio — connection works fine - return schemas.VerifyProviderResponse(success=True, message="Connection successful!") + stt_prefs = user.preferences.get("stt", {}).get("providers", {}).get(req.provider_name, {}) if user.preferences else {} + + if not actual_key or "***" in str(actual_key): + actual_key = stt_prefs.get("api_key") + if not actual_key or "***" in str(actual_key): + system_prefs = services.user_service.get_system_settings(db) + actual_key = system_prefs.get("stt", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key") + if not actual_key: actual_key = settings.STT_API_KEY or settings.GEMINI_API_KEY + + kwargs = {k: v for k, v in stt_prefs.items() if k not in ["api_key", "model"]} + if req.provider_type: + kwargs["provider_type"] = req.provider_type + + provider = get_stt_provider( + provider_name=req.provider_name, + api_key=actual_key, + model_name=req.model or "", + **kwargs + ) + # Minimal STT check: factory init is usually enough to catch invalid credentials for SDK-based providers + return schemas.VerifyProviderResponse(success=True, message="Provider initialized. Full transcription test requires audio payload.") except Exception as e: + logger.error(f"STT verification failed for {req.provider_name}: {e}") return schemas.VerifyProviderResponse(success=False, message=str(e)) @router.post("/logout", summary="Log Out the Current User") @@ -557,13 +761,13 @@ db: Session = Depends(get_db), user_id: str = Depends(get_current_user_id) ): - """Exports the effective user configuration as a YAML file.""" + """Exports the effective user configuration as a YAML file (Admin only).""" from fastapi.responses import PlainTextResponse if not user_id: raise HTTPException(status_code=401, detail="Unauthorized") user = services.user_service.get_user_by_id(db=db, user_id=user_id) - if not user: - raise HTTPException(status_code=404, detail="User not found") + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") prefs_dict = user.preferences or {} from app.config import settings @@ -608,20 +812,10 @@ "api_key": fallback_api_key } + # Layer 2 (Day 2) Export: Only LLM, TTS, STT yaml_data = { - "application": { - "project_name": settings.PROJECT_NAME, - "version": settings.VERSION, - "log_level": settings.LOG_LEVEL - }, - "database": { - "mode": settings.DB_MODE, - }, - "llm_providers": llm_providers_export, - "embedding_provider": { - "provider": settings.EMBEDDING_PROVIDER, - "model_name": settings.EMBEDDING_MODEL_NAME, - "api_key": settings.EMBEDDING_API_KEY + "llm_providers": { + "providers": user_providers or settings.LLM_PROVIDERS }, "tts_provider": get_provider_export(tts_prefs, settings.TTS_PROVIDER, settings.TTS_MODEL_NAME, settings.TTS_API_KEY, settings.TTS_VOICE_NAME), "stt_provider": get_provider_export(stt_prefs, settings.STT_PROVIDER, settings.STT_MODEL_NAME, settings.STT_API_KEY) @@ -639,7 +833,7 @@ return PlainTextResponse( content=yaml_str, media_type="application/x-yaml", - headers={"Content-Disposition": "attachment; filename=\"config.yaml\""} + headers={"Content-Disposition": "attachment; filename=\"day2_config.yaml\""} ) @router.post("/me/config/import", response_model=schemas.UserPreferences, summary="Import Configurations from YAML") @@ -706,7 +900,12 @@ "model": stt_data.get("model_name") } - user.preferences = { "llm": new_llm, "tts": new_tts, "stt": new_stt } + user.preferences = { + "llm": new_llm, + "tts": new_tts, + "stt": new_stt, + "statuses": {} + } from sqlalchemy.orm.attributes import flag_modified flag_modified(user, "preferences") @@ -740,4 +939,150 @@ db.refresh(user) return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {})) + # --- NEW ADMIN ROUTES --- + + @router.get("/admin/users", response_model=list[schemas.UserProfile], summary="List All Users (Admin Only)") + async def admin_list_users( + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Returns a list of all registered users in the system.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + users = services.user_service.get_all_users(db) + response = [] + for u in users: + p = schemas.UserProfile.model_validate(u) + if u.group: + p.group_name = u.group.name + response.append(p) + return response + + @router.put("/admin/users/{uid}/role", response_model=schemas.UserProfile, summary="Update User Role (Admin Only)") + async def admin_update_role( + uid: str, + role_req: schemas.UserRoleUpdate, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Updates a user's role. Prevents demoting the last administrator.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + success = services.user_service.update_user_role(db, uid, role_req.role) + if not success: + raise HTTPException(status_code=400, detail="Failed to update role. Maybe this is the last admin?") + + return services.user_service.get_user_by_id(db, uid) + + @router.put("/admin/users/{uid}/group", response_model=schemas.UserProfile, summary="Update User Group (Admin Only)") + async def admin_update_user_group( + uid: str, + group_req: schemas.UserGroupUpdate, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Assigns a user to a group.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + success = services.user_service.assign_user_to_group(db, uid, group_req.group_id) + if not success: + raise HTTPException(status_code=404, detail="User or group not found") + + return services.user_service.get_user_by_id(db, uid) + + @router.get("/admin/groups", response_model=list[schemas.GroupInfo], summary="List All Groups (Admin Only)") + async def admin_list_groups( + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Returns all existing groups.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + # Explicitly convert to Pydantic models within the session scope + # to prevent SQLAlchemy lazy-loading issues in async context. + groups = services.user_service.get_all_groups(db) + return [schemas.GroupInfo.model_validate(g) for g in groups] + + @router.post("/admin/groups", response_model=schemas.GroupInfo, summary="Create Group (Admin Only)") + async def admin_create_group( + group_req: schemas.GroupCreate, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Creates a new group.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + group = services.user_service.create_group(db, group_req.name, group_req.description, group_req.policy) + if group is None: + raise HTTPException(status_code=409, detail=f"A group named '{group_req.name}' already exists. Please choose a unique name.") + return schemas.GroupInfo.model_validate(group) + + @router.put("/admin/groups/{gid}", response_model=schemas.GroupInfo, summary="Update Group (Admin Only)") + async def admin_update_group( + gid: str, + group_req: schemas.GroupUpdate, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Updates a group's metadata or policy.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + # The 'ungrouped' group cannot be renamed — only its policy can be updated + if gid == "ungrouped" and group_req.name and group_req.name.strip().lower() != "ungrouped": + raise HTTPException(status_code=403, detail="The default 'Ungrouped' group cannot be renamed.") + + group = services.user_service.update_group(db, gid, group_req.name, group_req.description, group_req.policy) + if group is None: + raise HTTPException(status_code=404, detail="Group not found") + if group is False: + raise HTTPException(status_code=409, detail=f"A group named '{group_req.name}' already exists. Please choose a unique name.") + return schemas.GroupInfo.model_validate(group) + + @router.delete("/admin/groups/{gid}", summary="Delete Group (Admin Only)") + async def admin_delete_group( + gid: str, + db: Session = Depends(get_db), + user_id: str = Depends(get_current_user_id) + ): + """Deletes a group. Users are moved back to 'ungrouped'.""" + if not user_id: + raise HTTPException(status_code=401, detail="Unauthorized") + user = services.user_service.get_user_by_id(db=db, user_id=user_id) + if not user or user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden: Admin only") + + # Cannot delete the system default group + if gid == "ungrouped": + raise HTTPException(status_code=403, detail="The default 'Ungrouped' group cannot be deleted.") + + success = services.user_service.delete_group(db, gid) + if not success: + raise HTTPException(status_code=400, detail="Failed to delete group.") + + return {"message": "Group deleted successfully"} + return router diff --git a/ai-hub/app/api/schemas.py b/ai-hub/app/api/schemas.py index 630b178..357ed08 100644 --- a/ai-hub/app/api/schemas.py +++ b/ai-hub/app/api/schemas.py @@ -17,12 +17,57 @@ is_logged_in: bool = Field(True, description="Indicates if the user is currently authenticated.") is_anonymous: bool = Field(False, description="Indicates if the user is an anonymous user.") +class UserProfile(BaseModel): + id: str + email: str + username: Optional[str] = None + full_name: Optional[str] = None + role: str = "user" + group_id: Optional[str] = None + group_name: Optional[str] = None + avatar_url: Optional[str] = None + created_at: datetime + last_login_at: Optional[datetime] = None + model_config = ConfigDict(from_attributes=True) + +class UserProfileUpdate(BaseModel): + username: Optional[str] = None + full_name: Optional[str] = None + avatar_url: Optional[str] = None + +class UserRoleUpdate(BaseModel): + role: str + +class UserGroupUpdate(BaseModel): + group_id: str + +# --- Group Schemas --- +class GroupBase(BaseModel): + name: str + description: Optional[str] = None + # Policy: {"llm": ["openai"], "tts": ["gcloud"], "stt": ["google"]} + policy: dict = Field(default_factory=dict) + +class GroupCreate(GroupBase): + pass + +class GroupUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + policy: Optional[dict] = None + +class GroupInfo(GroupBase): + id: str + created_at: Optional[datetime] = None + model_config = ConfigDict(from_attributes=True) + # --- General Schemas --- class UserPreferences(BaseModel): """Schema for user-specific LLM, TTS, STT preferences.""" llm: dict = Field(default_factory=dict) tts: dict = Field(default_factory=dict) stt: dict = Field(default_factory=dict) + statuses: Optional[dict] = Field(default_factory=dict) class ConfigResponse(BaseModel): """Schema for returning user preferences alongside effective settings.""" @@ -129,6 +174,7 @@ class VerifyProviderRequest(BaseModel): provider_name: str + provider_type: Optional[str] = None api_key: Optional[str] = None model: Optional[str] = None voice: Optional[str] = None diff --git a/ai-hub/app/config.py b/ai-hub/app/config.py index c4fd19d..d54d7cf 100644 --- a/ai-hub/app/config.py +++ b/ai-hub/app/config.py @@ -14,6 +14,13 @@ project_name: str = "Cortex Hub" version: str = "1.0.0" log_level: str = "INFO" + super_admins: list[str] = Field(default_factory=list) + +class OIDCSettings(BaseModel): + client_id: str = "" + client_secret: str = "" + server_url: str = "" + redirect_uri: str = "" class DatabaseSettings(BaseModel): mode: str = "sqlite" @@ -51,6 +58,7 @@ embedding_provider: EmbeddingProviderSettings = Field(default_factory=EmbeddingProviderSettings) tts_provider: ProviderSettings = Field(default_factory=ProviderSettings) stt_provider: ProviderSettings = Field(default_factory=ProviderSettings) + oidc: OIDCSettings = Field(default_factory=OIDCSettings) # --- 2. Create the Final Settings Object --- @@ -84,6 +92,22 @@ self.LOG_LEVEL: str = os.getenv("LOG_LEVEL") or \ get_from_yaml(["application", "log_level"]) or \ config_from_pydantic.application.log_level + self.SUPER_ADMINS: list[str] = get_from_yaml(["application", "super_admins"]) or \ + config_from_pydantic.application.super_admins + + # --- OIDC Settings --- + self.OIDC_CLIENT_ID: str = os.getenv("OIDC_CLIENT_ID") or \ + get_from_yaml(["oidc", "client_id"]) or \ + config_from_pydantic.oidc.client_id + self.OIDC_CLIENT_SECRET: str = os.getenv("OIDC_CLIENT_SECRET") or \ + get_from_yaml(["oidc", "client_secret"]) or \ + config_from_pydantic.oidc.client_secret + self.OIDC_SERVER_URL: str = os.getenv("OIDC_SERVER_URL") or \ + get_from_yaml(["oidc", "server_url"]) or \ + config_from_pydantic.oidc.server_url + self.OIDC_REDIRECT_URI: str = os.getenv("OIDC_REDIRECT_URI") or \ + get_from_yaml(["oidc", "redirect_uri"]) or \ + config_from_pydantic.oidc.redirect_uri # --- Database Settings --- self.DB_MODE: str = os.getenv("DB_MODE") or \ @@ -195,34 +219,22 @@ "application": { "project_name": self.PROJECT_NAME, "version": self.VERSION, - "log_level": self.LOG_LEVEL + "log_level": self.LOG_LEVEL, + "super_admins": self.SUPER_ADMINS }, "database": { "mode": self.DB_MODE, "local_path": self.DATABASE_URL.replace("sqlite:///./", "") if "sqlite" in self.DATABASE_URL else "data/ai_hub.db" }, - "llm_providers": { - "providers": self.LLM_PROVIDERS - }, "vector_store": { "index_path": self.FAISS_INDEX_PATH, "embedding_dimension": self.EMBEDDING_DIMENSION }, - "embedding_provider": { - "provider": self.EMBEDDING_PROVIDER, - "model_name": self.EMBEDDING_MODEL_NAME, - "api_key": get_val(self.EMBEDDING_API_KEY) - }, - "tts_provider": { - "provider": self.TTS_PROVIDER, - "model_name": self.TTS_MODEL_NAME, - "voice_name": self.TTS_VOICE_NAME, - "api_key": get_val(self.TTS_API_KEY) - }, - "stt_provider": { - "provider": self.STT_PROVIDER, - "model_name": self.STT_MODEL_NAME, - "api_key": get_val(self.STT_API_KEY) + "oidc": { + "client_id": self.OIDC_CLIENT_ID, + "client_secret": self.OIDC_CLIENT_SECRET, + "server_url": self.OIDC_SERVER_URL, + "redirect_uri": self.OIDC_REDIRECT_URI } } diff --git a/ai-hub/app/core/providers/factory.py b/ai-hub/app/core/providers/factory.py index 94edb41..dd336c0 100644 --- a/ai-hub/app/core/providers/factory.py +++ b/ai-hub/app/core/providers/factory.py @@ -11,6 +11,25 @@ import litellm +def resolve_provider_info(name, section, registry, litellm_list=None): + """ + Resolves the base provider type from a potentially suffixed instance name. + Example: 'gemini_2' -> 'gemini' + """ + if name in registry: + return name + if litellm_list and name in litellm_list: + return name + + # Try prefixes for suffixed instances (split by underscore) + if "_" in name: + parts = name.split("_") + # Check longest possible prefix first (important for types like google_gemini) + for i in range(len(parts) - 1, 0, -1): + prefix = "_".join(parts[:i]) + if prefix in registry or (litellm_list and prefix in litellm_list): + return prefix + return name # --- 1. Initialize API Clients from Central Config --- # deepseek_client = AsyncOpenAI(api_key=settings.DEEPSEEK_API_KEY, base_url="https://api.deepseek.com") @@ -47,13 +66,17 @@ """Factory function to get the appropriate, pre-configured LLM provider, with optional system prompt.""" providerKey = api_key_override or _llm_providers.get(provider_name) + # Extract base type (e.g. 'gemini_2' -> 'gemini') + litellm_providers = [p.value for p in litellm.LlmProviders] + base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "llm", _llm_providers, litellm_providers) + modelName = model_name if not modelName: modelName = _llm_models.get(provider_name) if not modelName: raise ValueError(f"No model name provided for '{provider_name}'.") - full_model = f'{provider_name}/{modelName}' if '/' not in modelName else modelName + full_model = f'{base_type}/{modelName}' if '/' not in modelName else modelName # Pass the optional system_prompt and kwargs to the GeneralProvider constructor return GeneralProvider(model_name=full_model, api_key=providerKey, system_prompt=system_prompt, **kwargs) @@ -71,16 +94,19 @@ elif provider_name in settings.LLM_PROVIDERS and not is_masked(settings.LLM_PROVIDERS[provider_name].get("api_key")): actual_key = settings.LLM_PROVIDERS[provider_name].get("api_key") - provider_cls = _tts_registry.get(provider_name) + # Resolve base technology type + base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "tts", _tts_registry) + + provider_cls = _tts_registry.get(base_type) if provider_cls: - if provider_name == "gcloud_tts": + if base_type == "gcloud_tts": return provider_cls(api_key=actual_key, voice_name=voice_name, **kwargs) return provider_cls(api_key=actual_key, model_name=model_name, voice_name=voice_name, **kwargs) # Fallback to General LiteLLM implementation full_model = model_name - if "/" not in full_model and provider_name not in ["google_gemini", "gcloud_tts"]: - full_model = f"{provider_name}/{model_name}" + if "/" not in full_model and base_type not in ["google_gemini", "gcloud_tts"]: + full_model = f"{base_type}/{model_name}" return GeneralTTSProvider(model_name=full_model, api_key=actual_key, voice_name=voice_name, **kwargs) @@ -97,13 +123,15 @@ elif provider_name in settings.LLM_PROVIDERS and not is_masked(settings.LLM_PROVIDERS[provider_name].get("api_key")): actual_key = settings.LLM_PROVIDERS[provider_name].get("api_key") - provider_cls = _stt_registry.get(provider_name) + base_type = kwargs.get("provider_type") or resolve_provider_info(provider_name, "stt", _stt_registry) + + provider_cls = _stt_registry.get(base_type) if provider_cls: return provider_cls(api_key=actual_key, model_name=model_name, **kwargs) # Fallback to General LiteLLM implementation full_model = model_name - if "/" not in full_model and provider_name not in ["google_gemini"]: - full_model = f"{provider_name}/{model_name}" + if "/" not in full_model and base_type not in ["google_gemini"]: + full_model = f"{base_type}/{model_name}" return GeneralSTTProvider(model_name=full_model, api_key=actual_key, **kwargs) \ No newline at end of file diff --git a/ai-hub/app/core/services/user.py b/ai-hub/app/core/services/user.py index 2411efb..fae20e7 100644 --- a/ai-hub/app/core/services/user.py +++ b/ai-hub/app/core/services/user.py @@ -16,26 +16,43 @@ Saves or updates a user record based on their OIDC ID. If a user with this OIDC ID exists, it returns their existing ID. Otherwise, it creates a new user record. - Returns the user's ID. + The first user to register will be granted the 'admin' role. """ try: # Check if a user with this OIDC ID already exists existing_user = db.query(models.User).filter(models.User.oidc_id == oidc_id).first() if existing_user: - # Update the user's information if needed + # Update the user's information and login activity existing_user.email = email existing_user.username = username + existing_user.last_login_at = datetime.utcnow() + + # Check if user should be promoted to admin based on config + from app.config import settings + if email in settings.SUPER_ADMINS and existing_user.role != "admin": + existing_user.role = "admin" + db.commit() return existing_user.id else: + # Ensure default group exists + default_group = self.get_or_create_default_group(db) + + # Determine role based on SUPER_ADMINS or fallback to user + from app.config import settings + role = "admin" if email in settings.SUPER_ADMINS else "user" + # Create a new user record new_user = models.User( id=str(uuid.uuid4()), # Generate a unique ID for the user oidc_id=oidc_id, email=email, username=username, - created_at=datetime.utcnow() + role=role, + group_id=default_group.id, + created_at=datetime.utcnow(), + last_login_at=datetime.utcnow() ) db.add(new_user) db.commit() @@ -59,6 +76,141 @@ print(f"Database error while fetching user by ID: {e}") return None + def get_all_users(self, db: Session) -> list[models.User]: + """Retrieves all registered users.""" + try: + return db.query(models.User).all() + except SQLAlchemyError as e: + print(f"Database error while fetching all users: {e}") + return [] + + def update_user_role(self, db: Session, user_id: str, new_role: str) -> bool: + """Updates a user's role. Ensures at least one admin exists.""" + try: + user = db.query(models.User).filter(models.User.id == user_id).first() + if not user: + return False + + # If trying to demote an admin, check if they are the only one + if user.role == "admin" and new_role != "admin": + admin_count = db.query(models.User).filter(models.User.role == "admin").count() + if admin_count <= 1: + # Cannot demote the last admin + return False + + user.role = new_role + db.commit() + return True + except SQLAlchemyError as e: + db.rollback() + print(f"Database error while updating user role: {e}") + return False + + def get_system_settings(self, db: Session) -> dict: + """Retrieves global AI provider settings from the first super admin found.""" + try: + from app.config import settings + super_admin_email = settings.SUPER_ADMINS[0] if settings.SUPER_ADMINS else None + if super_admin_email: + admin_user = db.query(models.User).filter(models.User.email == super_admin_email).first() + if admin_user and admin_user.preferences: + return admin_user.preferences + return {} + except SQLAlchemyError: + return {} + + # --- Group Management Methods --- + + def get_or_create_default_group(self, db: Session) -> models.Group: + """Ensures the 'ungrouped' default group exists.""" + default_group = db.query(models.Group).filter(models.Group.id == "ungrouped").first() + if not default_group: + default_group = models.Group( + id="ungrouped", + name="Ungrouped", + description="Default group for new users", + policy={} + ) + db.add(default_group) + db.commit() + db.refresh(default_group) + return default_group + + def get_all_groups(self, db: Session) -> list[models.Group]: + """Returns all existing groups, without lazy-loading user relationships.""" + from sqlalchemy.orm import noload + return db.query(models.Group).options(noload(models.Group.users)).all() + + def get_group_by_id(self, db: Session, group_id: str) -> Optional[models.Group]: + """Fetches a group by ID, without lazy-loading user relationships.""" + from sqlalchemy.orm import noload + return db.query(models.Group).options(noload(models.Group.users)).filter(models.Group.id == group_id).first() + + def create_group(self, db: Session, name: str, description: str = None, policy: dict = None) -> Optional[models.Group]: + """Creates a new user group. Returns None if a group with the same name already exists.""" + existing = db.query(models.Group).filter( + models.Group.name.ilike(name.strip()) + ).first() + if existing: + return None # Signals a name conflict + group = models.Group( + id=str(uuid.uuid4()), + name=name.strip(), + description=description, + policy=policy or {} + ) + db.add(group) + db.commit() + db.refresh(group) + return group + + def update_group(self, db: Session, group_id: str, name: str = None, description: str = None, policy: dict = None) -> Optional[models.Group]: + """Updates group metadata or policy. Returns False (bool) if name conflicts with another group.""" + group = self.get_group_by_id(db, group_id) + if not group: + return None + if name: + name = name.strip() + # Check for name conflict with a DIFFERENT group + conflict = db.query(models.Group).filter( + models.Group.name.ilike(name), + models.Group.id != group_id + ).first() + if conflict: + return False # Signals a name conflict (distinct from None = not found) + group.name = name + if description is not None: group.description = description + if policy is not None: group.policy = policy + db.commit() + db.refresh(group) + return group + + def delete_group(self, db: Session, group_id: str) -> bool: + """Deletes a group. Moves its users to 'ungrouped'.""" + if group_id == "ungrouped": + return False # Cannot delete the default group + group = self.get_group_by_id(db, group_id) + if not group: + return False + + default_group = self.get_or_create_default_group(db) + # Move users + db.query(models.User).filter(models.User.group_id == group_id).update({"group_id": default_group.id}) + + db.delete(group) + db.commit() + return True + + def assign_user_to_group(self, db: Session, user_id: str, group_id: str) -> bool: + """Assigns a user to a group.""" + user = self.get_user_by_id(db, user_id) + group = self.get_group_by_id(db, group_id) + if not user or not group: + return False + user.group_id = group_id + db.commit() + return True + # --- Framework-dependent helper functions --- # These functions are placeholders and would need to be integrated with your # specific web framework (e.g., FastAPI, Flask, Django). diff --git a/ai-hub/app/db/models.py b/ai-hub/app/db/models.py index 46d9c5c..f7bee7f 100644 --- a/ai-hub/app/db/models.py +++ b/ai-hub/app/db/models.py @@ -25,11 +25,21 @@ email = Column(String, nullable=True) # The user's display name. username = Column(String, nullable=True) + # Enterprise profile info + full_name = Column(String, nullable=True) + role = Column(String, default="user", nullable=False) # 'admin' or 'user' + group_id = Column(String, ForeignKey('groups.id'), nullable=True) + avatar_url = Column(String, nullable=True) # Timestamp for when the user account was created. created_at = Column(DateTime, default=datetime.utcnow) + # Track platform engagement for auditing + last_login_at = Column(DateTime, default=datetime.utcnow) # User's preferences/settings (e.g. LLM/TTS/STT configs) preferences = Column(JSON, default={}, nullable=True) + # Relationship to Group + group = relationship("Group", back_populates="users") + # Defines a one-to-many relationship with the Session table. # 'back_populates' creates a link back to the User model from the Session model. sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan") @@ -37,6 +47,26 @@ def __repr__(self): return f"" +class Group(Base): + """ + SQLAlchemy model for the 'groups' table. + Groups define policies for AI provider access. + """ + __tablename__ = 'groups' + + id = Column(String, primary_key=True, index=True) + name = Column(String, unique=True, nullable=False) + description = Column(String, nullable=True) + # Policy: which providers are allowed for this group + # Example: {"llm": ["openai", "gemini"], "tts": ["gcloud_tts"], "stt": ["google_gemini"]} + policy = Column(JSON, default={}, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + + users = relationship("User", back_populates="group") + + def __repr__(self): + return f"" + class Session(Base): """ SQLAlchemy model for the 'sessions' table. diff --git a/ui/client-app/src/App.js b/ui/client-app/src/App.js index ec2a2e9..8745c91 100644 --- a/ui/client-app/src/App.js +++ b/ui/client-app/src/App.js @@ -6,7 +6,8 @@ import CodingAssistantPage from "./pages/CodingAssistantPage"; import LoginPage from "./pages/LoginPage"; import SettingsPage from "./pages/SettingsPage"; -import { getUserStatus, logout } from "./services/apiService"; +import ProfilePage from "./pages/ProfilePage"; +import { getUserStatus, logout, getUserProfile } from "./services/apiService"; const Icon = ({ path, onClick, className }) => ( { const urlParams = new URLSearchParams(window.location.search); @@ -53,12 +55,17 @@ if (status.is_logged_in) { setIsLoggedIn(true); setUserId(storedUserId); + // Fetch profile on success + const profile = await getUserProfile(); + setUserProfile(profile); + if (currentPage === "login") { setCurrentPage("home"); } } else { setIsLoggedIn(false); setUserId(null); + setUserProfile(null); localStorage.removeItem("userId"); if (authenticatedPages.includes(currentPage)) { setCurrentPage("login"); @@ -68,6 +75,7 @@ console.error("Failed to check user status:", error); setIsLoggedIn(false); setUserId(null); + setUserProfile(null); localStorage.removeItem("userId"); if (authenticatedPages.includes(currentPage)) { setCurrentPage("login"); @@ -76,6 +84,7 @@ } else { setIsLoggedIn(false); setUserId(null); + setUserProfile(null); if (authenticatedPages.includes(currentPage)) { setCurrentPage("login"); } @@ -90,6 +99,7 @@ await logout(); setIsLoggedIn(false); setUserId(null); + setUserProfile(null); localStorage.removeItem("userId"); setCurrentPage("home"); } catch (error) { @@ -119,7 +129,13 @@ case "coding-assistant": return ; case "settings": + // Only admins can see global settings + if (userProfile?.role !== "admin") { + return ; + } return ; + case "profile": + return ; case "login": return ; default: @@ -136,6 +152,7 @@ onNavigate={handleNavigate} onLogout={handleLogout} isLoggedIn={isLoggedIn} + user={userProfile} Icon={Icon} /> )} diff --git a/ui/client-app/src/components/Navbar.js b/ui/client-app/src/components/Navbar.js index 3d8db68..c20d74e 100644 --- a/ui/client-app/src/components/Navbar.js +++ b/ui/client-app/src/components/Navbar.js @@ -1,7 +1,7 @@ import React from 'react'; import { ReactComponent as Logo } from '../logo.svg'; -const Navbar = ({ isOpen, onToggle, onNavigate, onLogout, isLoggedIn, Icon }) => { +const Navbar = ({ isOpen, onToggle, onNavigate, onLogout, isLoggedIn, user, Icon }) => { const navItems = [ { name: "Home", icon: "M10 20v-6h4v6h5v-8h3L12 3 2 12h3v8z", page: "home" }, { name: "Voice Chat", icon: "M12 1a3 3 0 0 1 3 3v7a3 3 0 1 1-6 0V4a3 3 0 0 1 3-3zm5 10a5 5 0 0 1-10 0H5a7 7 0 0 0 14 0h-2zm-5 11v-4h-2v4h2z", page: "voice-chat" }, @@ -48,9 +48,9 @@ if (!isDisabled) onNavigate(item.page); }} className={`flex items-center space-x-4 p-2 rounded-lg transition-colors duration-200 ${isDisabled - ? "cursor-not-allowed text-gray-400 dark:text-gray-500" - : "cursor-pointer hover:bg-gray-200 dark:hover:bg-gray-700" - }`} + ? "cursor-not-allowed text-gray-400 dark:text-gray-500" + : "cursor-pointer hover:bg-gray-200 dark:hover:bg-gray-700" + } ${item.page === 'profile' && !isOpen ? 'hidden' : ''}`} > {isOpen && ( @@ -63,22 +63,61 @@ })} - {/* Bottom Section: Settings and Login/User */} -
- {/* Settings Button */} -
onNavigate("settings")} - className="flex items-center space-x-4 p-2 text-gray-700 dark:text-gray-300 rounded-lg cursor-pointer hover:bg-gray-200 dark:hover:bg-gray-700 transition-colors duration-200" - > - - {isOpen && Settings} -
+ {/* Bottom Section: User, Settings and Logout */} +
+ {/* Settings Button - Only shown to Admin */} + {isLoggedIn && user?.role === "admin" && ( +
onNavigate("settings")} + className="flex items-center space-x-4 p-2 text-gray-700 dark:text-gray-300 rounded-lg cursor-pointer hover:bg-gray-200 dark:hover:bg-gray-700 transition-colors duration-200" + > + + {isOpen && Settings} +
+ )} + + {/* User Profile Summary - Always show avatar, details when expanded */} + {isLoggedIn && user && ( +
onNavigate("profile")} + className={`flex items-center p-2 rounded-xl transition-colors cursor-pointer group ${isOpen + ? 'space-x-3 bg-gray-50 dark:bg-gray-700/50 hover:bg-gray-100 dark:hover:bg-gray-700' + : 'justify-center hover:bg-gray-200 dark:hover:bg-gray-700' + }`} + title={!isOpen ? (user.full_name || user.username || "Profile") : ""} + > +
+ {user.avatar_url ? ( + User + ) : ( + (user.full_name || user.username || "U")[0].toUpperCase() + )} +
+ + {isOpen && ( + <> +
+

+ {user.full_name || user.username || "User"} +

+

+ {user.role || 'user'} +

+
+ + + + + )} +
+ )} {/* Conditional Login/Logout Button */} {isLoggedIn ? (
{isOpen && Logout} diff --git a/ui/client-app/src/pages/ProfilePage.js b/ui/client-app/src/pages/ProfilePage.js new file mode 100644 index 0000000..cdb5e17 --- /dev/null +++ b/ui/client-app/src/pages/ProfilePage.js @@ -0,0 +1,282 @@ +import React, { useState, useEffect } from 'react'; +import { getUserProfile, updateUserProfile, getUserConfig, updateUserConfig } from '../services/apiService'; + +const ProfilePage = () => { + const [profile, setProfile] = useState(null); + const [config, setConfig] = useState(null); + const [available, setAvailable] = useState({ llm: [], tts: [], stt: [] }); + const [loading, setLoading] = useState(true); + const [saving, setSaving] = useState(false); + const [message, setMessage] = useState({ type: '', text: '' }); + const [editData, setEditData] = useState({ + full_name: '', + username: '', + avatar_url: '' + }); + + useEffect(() => { + loadData(); + }, []); + + const loadData = async () => { + try { + setLoading(true); + const [prof, conf] = await Promise.all([ + getUserProfile(), + getUserConfig() + ]); + setProfile(prof); + setConfig(conf.preferences); + setAvailable({ + llm: Object.entries(conf.effective?.llm?.providers || {}).map(([id, p]) => ({ id, label: id, model: p?.model || null })), + tts: Object.entries(conf.effective?.tts?.providers || {}).map(([id, p]) => ({ id, label: id, model: p?.model || null, voice: p?.voice || null })), + stt: Object.entries(conf.effective?.stt?.providers || {}).map(([id, p]) => ({ id, label: id, model: p?.model || null })) + }); + setEditData({ + full_name: prof.full_name || '', + username: prof.username || '', + avatar_url: prof.avatar_url || '' + }); + } catch (err) { + console.error("Failed to load profile data", err); + setMessage({ type: 'error', text: 'Failed to load profile.' }); + } finally { + setLoading(false); + } + }; + + const handleProfileSubmit = async (e) => { + e.preventDefault(); + try { + setSaving(true); + const updated = await updateUserProfile(editData); + setProfile(updated); + setMessage({ type: 'success', text: 'Profile updated successfully!' }); + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (err) { + setMessage({ type: 'error', text: 'Failed to update profile.' }); + } finally { + setSaving(false); + } + }; + + const handlePreferenceChange = async (section, providerId) => { + try { + const newConfig = { + ...config, + [section]: { ...config[section], active_provider: providerId } + }; + await updateUserConfig(newConfig); + setConfig(newConfig); + setMessage({ type: 'success', text: `Primary ${section.toUpperCase()} set to ${providerId}` }); + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (err) { + setMessage({ type: 'error', text: 'Failed to update preferences.' }); + } + }; + + if (loading) { + return ( +
+
Loading identity...
+
+ ); + } + + const inputClass = "w-full border border-gray-300 dark:border-gray-600 rounded-xl p-3 bg-white dark:bg-gray-800 text-gray-900 dark:text-gray-100 placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-indigo-500 transition-all"; + const labelClass = "block text-sm font-bold text-gray-700 dark:text-gray-300 mb-2 ml-1"; + + return ( +
+
+
+
+
+ {profile.avatar_url ? Avatar : profile.email[0].toUpperCase()} +
+
+
+

+ {profile.full_name || profile.username || 'Citizen'} +

+

{profile.email}

+

Member since {new Date(profile.created_at).toLocaleDateString()}

+
+ + {profile.role} + + {profile.group_name && ( + + {profile.group_name} Group + + )} +
+
+
+ + {message.text && ( +
+ {message.text} +
+ )} + +
+ {/* General Information */} +
+

+ + General Information +

+
+
+ + setEditData({ ...editData, full_name: e.target.value })} + placeholder="Enter your full name" + /> +
+
+ + setEditData({ ...editData, username: e.target.value })} + placeholder="Display name" + /> +
+
+ +

+ {profile.group_name || 'Ungrouped'} +

+

Groups are managed by your administrator.

+
+
+ +
+
+
+ + {/* Service Preferences */} +
+

+ + Service Preferences +

+
+ + + +
+
+ Status: +
+
+ Verified +
+
+
+ Failed +
+
+
+ Untested +
+
+

+ These selections determine which AI service is used by default when you interact with the hub. Individual session settings may override these. +

+
+
+
+
+ ); +}; + +const ServiceSelect = ({ label, section, providers, active, statuses, onChange }) => { + return ( +
+ +
+ {providers.length === 0 ? ( +

No providers configured yet.

+ ) : ( + providers.map(p => { + const statusKey = `${section}_${p.id}`; + const status = statuses?.[statusKey]; + const statusColor = status === 'success' ? 'bg-emerald-500' : status === 'error' ? 'bg-red-500' : 'bg-gray-300'; + + const baseType = p.id.split('_')[0]; + const suffix = p.id.includes('_') ? p.id.split('_').slice(1).join('_') : ''; + const formattedLabel = baseType.charAt(0).toUpperCase() + baseType.slice(1) + (suffix ? ` (${suffix})` : ''); + + const modelDisplay = p.model || null; + const voiceDisplay = section === 'tts' ? (p.voice || null) : null; + const isActive = active === p.id; + + return ( + + ); + }) + )} +
+
+ ); +}; + +export default ProfilePage; diff --git a/ui/client-app/src/pages/SettingsPage.js b/ui/client-app/src/pages/SettingsPage.js index 420fc22..0f45545 100644 --- a/ui/client-app/src/pages/SettingsPage.js +++ b/ui/client-app/src/pages/SettingsPage.js @@ -1,5 +1,10 @@ import React, { useState, useEffect, useRef } from 'react'; -import { getUserConfig, updateUserConfig, exportUserConfig, importUserConfig, verifyProvider, getProviderModels, getAllProviders, getVoices } from '../services/apiService'; +import { + getUserConfig, updateUserConfig, exportUserConfig, importUserConfig, + verifyProvider, getProviderModels, getAllProviders, getVoices, + getAdminUsers, updateUserRole, getAdminGroups, createAdminGroup, + updateAdminGroup, deleteAdminGroup, updateUserGroup +} from '../services/apiService'; const SettingsPage = () => { const [config, setConfig] = useState({ llm: {}, tts: {}, stt: {} }); @@ -7,7 +12,9 @@ const [loading, setLoading] = useState(true); const [saving, setSaving] = useState(false); const [message, setMessage] = useState({ type: '', text: '' }); - const [activeTab, setActiveTab] = useState('llm'); + const [activeConfigTab, setActiveConfigTab] = useState('llm'); + const [activeAdminTab, setActiveAdminTab] = useState('groups'); + const [userSearch, setUserSearch] = useState(''); const [expandedProvider, setExpandedProvider] = useState(null); const [selectedNewProvider, setSelectedNewProvider] = useState(''); const [verifying, setVerifying] = useState(null); @@ -17,21 +24,26 @@ const [voiceList, setVoiceList] = useState([]); const [showVoicesModal, setShowVoicesModal] = useState(false); const [voicesLoading, setVoicesLoading] = useState(false); + const [allUsers, setAllUsers] = useState([]); + const [usersLoading, setUsersLoading] = useState(false); + const [allGroups, setAllGroups] = useState([]); + const [groupsLoading, setGroupsLoading] = useState(false); + const [editingGroup, setEditingGroup] = useState(null); + const [addingSection, setAddingSection] = useState(null); + const [addForm, setAddForm] = useState({ type: '', suffix: '', model: '', cloneFrom: '' }); const fileInputRef = useRef(null); - const handleViewVoices = async (apiKey = null) => { + const handleViewVoices = async (providerId, apiKey = null) => { setShowVoicesModal(true); - // Force refresh if an explicit apiKey is provided, otherwise use cache if available - if (voiceList.length === 0 || apiKey) { - setVoicesLoading(true); - try { - const voices = await getVoices(apiKey); - setVoiceList(voices); - } catch (e) { - console.error(e); - } finally { - setVoicesLoading(false); - } + setVoicesLoading(true); + setVoiceList([]); // Clear previous list while loading + try { + const voices = await getVoices(providerId, apiKey); + setVoiceList(voices); + } catch (e) { + console.error(e); + } finally { + setVoicesLoading(false); } }; @@ -57,8 +69,92 @@ useEffect(() => { loadConfig(); + loadUsers(); + loadGroups(); }, []); + const loadGroups = async () => { + try { + setGroupsLoading(true); + const groups = await getAdminGroups(); + setAllGroups(groups); + } catch (e) { + console.error("Failed to load groups", e); + } finally { + setGroupsLoading(false); + } + }; + + const loadUsers = async () => { + try { + setUsersLoading(true); + const users = await getAdminUsers(); + setAllUsers(users); + } catch (e) { + console.error("Failed to load users", e); + } finally { + setUsersLoading(false); + } + }; + + const handleRoleToggle = async (user) => { + const newRole = user.role === 'admin' ? 'user' : 'admin'; + try { + await updateUserRole(user.id, newRole); + setMessage({ type: 'success', text: `Role for ${user.username || user.email} updated to ${newRole}` }); + loadUsers(); // refresh list + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (e) { + setMessage({ type: 'error', text: e.message || 'Failed to update role' }); + } + }; + + const handleGroupChange = async (targetUserId, groupId) => { + try { + await updateUserGroup(targetUserId, groupId); + setMessage({ type: 'success', text: `User group updated successfully` }); + loadUsers(); + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (e) { + setMessage({ type: 'error', text: e.message || 'Failed to update group' }); + } + }; + + const handleSaveGroup = async (e) => { + e.preventDefault(); + try { + setSaving(true); + if (editingGroup.id === 'new') { + const { id, ...data } = editingGroup; + await createAdminGroup(data); + } else { + await updateAdminGroup(editingGroup.id, editingGroup); + } + setMessage({ type: 'success', text: 'Group saved successfully!' }); + setEditingGroup(null); + loadGroups(); + loadUsers(); + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (e) { + setMessage({ type: 'error', text: e.message || 'Failed to save group' }); + } finally { + setSaving(false); + } + }; + + const handleDeleteGroup = async (groupId) => { + if (!window.confirm("Are you sure? Users in this group will be moved to 'Ungrouped'.")) return; + try { + await deleteAdminGroup(groupId); + setMessage({ type: 'success', text: 'Group deleted' }); + loadGroups(); + loadUsers(); + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (e) { + setMessage({ type: 'error', text: e.message || 'Failed to delete group' }); + } + }; + const loadConfig = async () => { try { setLoading(true); @@ -68,6 +164,7 @@ tts: data.preferences?.tts || {}, stt: data.preferences?.stt || {} }); + setProviderStatuses(data.preferences?.statuses || {}); setEffective(data.effective || { llm: {}, tts: {}, stt: {} }); setMessage({ type: '', text: '' }); } catch (err) { @@ -93,18 +190,61 @@ } }, [expandedProvider, fetchedModels]); + // Pre-fetch model list for the selected type in the add-new-instance form + useEffect(() => { + if (addingSection && addForm.type) { + const fetchKey = `${addingSection}_${addForm.type}`; + if (!fetchedModels[fetchKey]) { + getProviderModels(addForm.type, addingSection).then(models => { + setFetchedModels(prev => ({ ...prev, [fetchKey]: models })); + }).catch(() => { }); + } + } + }, [addingSection, addForm.type, fetchedModels]); + const handleSave = async (e) => { e.preventDefault(); try { setSaving(true); - const updated = await updateUserConfig(config); + setMessage({ type: '', text: 'Saving and verifying configuration...' }); + + // Before saving, let's identify any "active" providers that have been modified + // (i.e. they are grey/have no status) and run a quick verification for them. + const updatedStatuses = { ...providerStatuses }; + const sections = ['llm', 'tts', 'stt']; + + for (const section of sections) { + const activeId = config[section]?.active_provider; + if (activeId && !updatedStatuses[`${section}_${activeId}`]) { + const providerPrefs = config[section]?.providers?.[activeId]; + if (providerPrefs && providerPrefs.api_key) { + try { + const res = await verifyProvider(section, { + provider_name: activeId, + provider_type: providerPrefs.provider_type || activeId.split('_')[0], + api_key: providerPrefs.api_key, + model: providerPrefs.model, + voice: providerPrefs.voice + }); + updatedStatuses[`${section}_${activeId}`] = res.success ? 'success' : 'error'; + } catch (err) { + updatedStatuses[`${section}_${activeId}`] = 'error'; + } + } + } + } + + setProviderStatuses(updatedStatuses); + const payload = { ...config, statuses: updatedStatuses }; + await updateUserConfig(payload); + // reload after save to get latest effective config await loadConfig(); - setMessage({ type: 'success', text: 'Settings saved successfully!' }); + setMessage({ type: 'success', text: 'Settings saved and verified successfully!' }); setTimeout(() => setMessage({ type: '', text: '' }), 3000); } catch (err) { console.error("Error saving config:", err); - setMessage({ type: 'error', text: 'Failed to save configuration.' }); + setMessage({ type: 'error', text: 'Failed to save configuration: ' + (err.message || "Unknown error") }); } finally { setSaving(false); } @@ -128,6 +268,30 @@ } }; + const handleGrantToAll = async (section, providerId) => { + if (!window.confirm(`Are you sure? This will whitelist ${providerId} for ALL existing groups.`)) return; + try { + setSaving(true); + setMessage({ type: '', text: `Syncing group policies for ${providerId}...` }); + for (const group of allGroups) { + const currentPolicy = group.policy || { llm: [], tts: [], stt: [] }; + const sectionList = currentPolicy[section] || []; + if (!sectionList.includes(providerId)) { + const newPolicy = { ...currentPolicy, [section]: [...sectionList, providerId] }; + await updateAdminGroup(group.id, { ...group, policy: newPolicy }); + } + } + await loadGroups(); + setMessage({ type: 'success', text: `Global access granted for ${providerId}!` }); + setTimeout(() => setMessage({ type: '', text: '' }), 3000); + } catch (e) { + console.error(e); + setMessage({ type: 'error', text: 'Failed to sync group access.' }); + } finally { + setSaving(false); + } + }; + const handleImport = async (e) => { const file = e.target.files[0]; if (!file) return; @@ -179,30 +343,67 @@ const renderProviderSection = (sectionKey, providerDefs, allowVoice = false) => { const activeProviderIds = new Set([ - ...Object.keys(config[sectionKey]?.providers || {}) + ...Object.keys(config[sectionKey]?.providers || {}), + ...Object.keys(effective[sectionKey]?.providers || {}) ]); + const activeProviders = Array.from(activeProviderIds).map(id => { + const baseP = providerDefs.find(p => p.id === id); + if (baseP) return baseP; + // Handle suffixed IDs (e.g. gemini_2) + const parts = id.split('_'); + let baseId = parts[0]; + // Special case for google_gemini + if (id.startsWith('google_gemini_')) baseId = 'google_gemini'; - const activeProviders = providerDefs.filter(p => activeProviderIds.has(p.id)); - const availableToAdd = providerDefs.filter(p => !activeProviderIds.has(p.id)); + const baseDef = providerDefs.find(p => p.id === baseId); + const suffix = id.replace(baseId + '_', ''); + return { + id: id, + label: baseDef ? `${baseDef.label} (${suffix})` : id + }; + }).sort((a, b) => a.label.localeCompare(b.label)); + const currentActivePrimary = config[sectionKey]?.active_provider || effective[sectionKey]?.active_provider || ''; - const handleAddProvider = () => { - if (!selectedNewProvider) return; - const newProviders = { ...(config[sectionKey]?.providers || {}) }; - newProviders[selectedNewProvider] = { api_key: '', model: '' }; - handleChange(sectionKey, 'providers', newProviders); - // auto-set primary if first one - if (!currentActivePrimary) { - handleChange(sectionKey, 'active_provider', selectedNewProvider); + const handleAddInstance = () => { + if (!addForm.type) return; + const newId = addForm.suffix ? `${addForm.type}_${addForm.suffix.toLowerCase().replace(/\s+/g, '_')}` : addForm.type; + + if (activeProviderIds.has(newId)) { + setMessage({ type: 'error', text: `Instance "${newId}" already exists.` }); + return; } - setExpandedProvider(`${sectionKey}_${selectedNewProvider}`); - setSelectedNewProvider(''); + + // Build initial provider data + const initData = { provider_type: addForm.type }; + + // Store a _clone_from marker — the backend will resolve the real API key + // from the source provider. We never have the plaintext key on the frontend. + if (addForm.cloneFrom) { + initData._clone_from = addForm.cloneFrom; + } + + // Pre-set model if specified + if (addForm.model.trim()) initData.model = addForm.model.trim(); + + const newProviders = { ...(config[sectionKey]?.providers || {}) }; + newProviders[newId] = initData; + handleChange(sectionKey, 'providers', newProviders, newId); + setAddingSection(null); + setAddForm({ type: '', suffix: '', model: '', cloneFrom: '' }); + setExpandedProvider(`${sectionKey}_${newId}`); }; + // Existing instances of the same type that have an API key — for cloning + const cloneableSources = Array.from(activeProviderIds).filter(id => { + const baseType = id.startsWith('google_gemini') ? 'google_gemini' : id.split('_')[0]; + return baseType === addForm.type && id !== addForm.type + (addForm.suffix ? '_' + addForm.suffix.toLowerCase().replace(/\s+/g, '_') : ''); + }); + const handleDeleteProvider = (providerId) => { const newProviders = { ...((config[sectionKey] && config[sectionKey].providers) || {}) }; delete newProviders[providerId]; - handleChange(sectionKey, 'providers', newProviders); + handleChange(sectionKey, 'providers', newProviders, providerId); if (currentActivePrimary === providerId) { handleChange(sectionKey, 'active_provider', Object.keys(newProviders)[0] || ''); } @@ -235,16 +436,21 @@ setMessage({ type: '', text: '' }); const payload = { provider_name: providerId, + provider_type: providerPrefs.provider_type || providerId.split('_')[0], api_key: providerPrefs.api_key, model: providerPrefs.model, voice: providerPrefs.voice }; const res = await verifyProvider(sectionKey, payload); if (res.success) { - setProviderStatuses(prev => ({ ...prev, [`${sectionKey}_${providerId}`]: 'success' })); + const newStatuses = { ...providerStatuses, [`${sectionKey}_${providerId}`]: 'success' }; + setProviderStatuses(newStatuses); + await updateUserConfig({ ...config, statuses: newStatuses }); setMessage({ type: 'success', text: `Verified ${providerId} successfully!` }); } else { - setProviderStatuses(prev => ({ ...prev, [`${sectionKey}_${providerId}`]: 'error' })); + const newStatuses = { ...providerStatuses, [`${sectionKey}_${providerId}`]: 'error' }; + setProviderStatuses(newStatuses); + await updateUserConfig({ ...config, statuses: newStatuses }); setMessage({ type: 'error', text: `Verification failed for ${providerId}: ${res.message}` }); } } catch (err) { @@ -257,28 +463,151 @@ return (
-
-
- - + {/* Header & Add Form */} +
+
+
+ +
+
+

Resource Instances

+

Configure specific account credentials

+
- + + {addingSection !== sectionKey ? ( + + ) : ( +
+

New Provider Instance

+ + {/* Row 1: Type + Label suffix */} +
+
+ + +
+
+ + setAddForm({ ...addForm, suffix: e.target.value })} + className="w-full border border-gray-200 dark:border-gray-700 rounded-lg px-3 py-2 bg-white dark:bg-gray-900 text-sm text-gray-800 dark:text-gray-100 focus:ring-2 focus:ring-indigo-500 outline-none" + /> + {addForm.type && addForm.suffix && ( +

+ ID: {addForm.type}_{addForm.suffix.toLowerCase().replace(/\s+/g, '_')} +

+ )} +
+
+ + {/* Row 2: Model + Clone-from */} +
+
+ + {addForm.type && fetchedModels[`${sectionKey}_${addForm.type}`]?.length > 0 ? ( + + ) : ( + setAddForm({ ...addForm, model: e.target.value })} + className="w-full border border-gray-200 dark:border-gray-700 rounded-lg px-3 py-2 bg-white dark:bg-gray-900 text-sm text-gray-800 dark:text-gray-100 focus:ring-2 focus:ring-indigo-500 outline-none" + /> + )} +
+ {cloneableSources.length > 0 && ( +
+ + + {addForm.cloneFrom && ( +

✓ API key will be copied from "{addForm.cloneFrom}" on save

+ )} +
+ )} +
+ + {/* Action buttons */} +
+ + +
+
+ )} +
+ +
+ Status Legend: +
+
+ Verified +
+
+
+ Failed +
+
+
+ Not Tested +
@@ -305,8 +634,17 @@ )}
-
- {providerEff.api_key && providerEff.api_key !== 'None' ? `Key: ${providerEff.api_key}` : 'Sys Default'} +
+ {providerPrefs._clone_from ? ( + + + Key from {providerPrefs._clone_from} + + ) : (providerPrefs.model || providerEff.model) ? ( + + {providerPrefs.model || providerEff.model} + + ) : null}
{!isActivePrimary && ( + @@ -350,27 +700,59 @@
- - { - const newProviders = { ...(config[sectionKey]?.providers || {}) }; - newProviders[provider.id] = { ...providerPrefs, api_key: e.target.value }; - handleChange(sectionKey, 'providers', newProviders, provider.id); - }} - onFocus={(e) => { - // Auto-clear masked string on focus so they can start typing real key cleanly - if (e.target.value.includes('***')) { - const newProviders = { ...(config[sectionKey]?.providers || {}) }; - newProviders[provider.id] = { ...providerPrefs, api_key: '' }; - handleChange(sectionKey, 'providers', newProviders, provider.id); - } - }} - placeholder="sk-..." - className={inputClass} - /> -

Specify your API key for {provider.label}.

+
+ + {providerPrefs._clone_from && ( + + )} +
+ {providerPrefs._clone_from ? ( +
+
+ + Inherited from “{providerPrefs._clone_from}” + Linked +
+

+ API key is managed by “{providerPrefs._clone_from}”. Click Unlink above to set an independent key. +

+
+ ) : ( + <> + { + const newProviders = { ...(config[sectionKey]?.providers || {}) }; + newProviders[provider.id] = { ...providerPrefs, api_key: e.target.value }; + handleChange(sectionKey, 'providers', newProviders, provider.id); + }} + onFocus={(e) => { + if (e.target.value.includes('***')) { + const newProviders = { ...(config[sectionKey]?.providers || {}) }; + newProviders[provider.id] = { ...providerPrefs, api_key: '' }; + handleChange(sectionKey, 'providers', newProviders, provider.id); + } + }} + placeholder="sk-..." + className={inputClass} + /> +

Specify your API key for {provider.label}.

+ + )}
{!(sectionKey === 'tts' && provider.id === 'gcloud_tts') && ( @@ -416,7 +798,7 @@
- +
+ (u.username || '').toLowerCase().includes(userSearch.toLowerCase()) || + (u.email || '').toLowerCase().includes(userSearch.toLowerCase()) || + (u.full_name || '').toLowerCase().includes(userSearch.toLowerCase()) + ); + + const sortedGroups = [...allGroups].sort((a, b) => { + if (a.id === 'ungrouped') return -1; + if (b.id === 'ungrouped') return 1; + return a.name.localeCompare(b.name); + }); + return (
@@ -554,62 +948,396 @@
)} -
- {/* Tabs */} -
- {['llm', 'tts', 'stt'].map((tab) => ( - - ))} +
+ {/* Card 1: AI Provider Configuration */} +
+
+

+ + AI Resource Configuration +

+

Manage your providers, models, and API keys

+
+ + {/* Config Tabs */} +
+ {['llm', 'tts', 'stt'].map((tab) => ( + + ))} +
+ +
+ {/* LLM Settings */} + {activeConfigTab === 'llm' && ( +
+ {renderProviderSection('llm', providerLists.llm, false)} +
+ )} + + {/* TTS Settings */} + {activeConfigTab === 'tts' && ( +
+ {renderProviderSection('tts', providerLists.tts, true)} +
+ )} + + {/* STT Settings */} + {activeConfigTab === 'stt' && ( +
+ {renderProviderSection('stt', providerLists.stt, false)} +
+ )} + +
+ +
+
-
- - {/* LLM Settings */} - {activeTab === 'llm' && ( -
- {renderProviderSection('llm', providerLists.llm, false)} -
- )} - - {/* TTS Settings */} - {activeTab === 'tts' && ( -
- {renderProviderSection('tts', providerLists.tts, true)} -
- )} - - {/* STT Settings */} - {activeTab === 'stt' && ( -
- {renderProviderSection('stt', providerLists.stt, false)} -
- )} - -
- + {/* Card 2: Team & Access Management */} +
+
+

+ + Identity & Access Governance +

+

Define groups, policies, and manage members

- + + {/* Admin Tabs */} +
+ {['groups', 'users'].map((tab) => ( + + ))} +
+ +
+ {/* Groups Management */} + {activeAdminTab === 'groups' && ( +
+ {!editingGroup ? ( +
+
+

+ Registered Groups +

+ +
+ +
+ {sortedGroups.map((g) => ( +
+
+

+ {g.id === 'ungrouped' ? 'Standard / Guest Policy' : g.name} + {g.id === 'ungrouped' && Global Fallback} +

+

+ {g.id === 'ungrouped' ? 'Baseline access for all unassigned members.' : (g.description || 'No description')} +

+
+ {['llm', 'tts', 'stt'].map(section => ( +
+ {section} +
+ {g.policy?.[section]?.length > 0 ? ( + g.policy?.[section].slice(0, 3).map(p => ( +
+ {p[0].toUpperCase()} +
+ )) + ) : ( + None + )} + {g.policy?.[section]?.length > 3 && ( +
+ +{g.policy?.[section].length - 3} +
+ )} +
+
+ ))} +
+
+
+ + {g.id !== 'ungrouped' && ( + + )} +
+
+ ))} +
+
+ ) : ( +
+ {/* (Group editing form - unchanged logic, just cleaner container) */} +
+ +

+ {editingGroup.id === 'new' ? 'New Group Policy' : `Edit: ${editingGroup.id === 'ungrouped' ? 'Standard / Guest Policy' : editingGroup.name}`} +

+ {editingGroup.id === 'ungrouped' && ( + + + System Group + + )} +
+ +
+
+
+ + editingGroup.id !== 'ungrouped' && setEditingGroup({ ...editingGroup, name: e.target.value })} + readOnly={editingGroup.id === 'ungrouped'} + placeholder="Engineering, Designers, etc." + className={`${inputClass} ${editingGroup.id === 'ungrouped' + ? 'opacity-60 cursor-not-allowed bg-gray-100 dark:bg-gray-700 text-gray-500 dark:text-gray-400' + : (editingGroup.name.trim() && + allGroups.some(g => g.id !== editingGroup.id && g.name.toLowerCase() === editingGroup.name.trim().toLowerCase()) + ? '!border-red-400 dark:!border-red-600 !ring-red-300' + : '') + }`} + /> + {editingGroup.id === 'ungrouped' ? ( +

+ + System group name is locked. Only the access policy can be changed. +

+ ) : editingGroup.name.trim() && + allGroups.some(g => g.id !== editingGroup.id && g.name.toLowerCase() === editingGroup.name.trim().toLowerCase()) && ( +

+ + A group with this name already exists +

+ )} +
+
+ + setEditingGroup({ ...editingGroup, description: e.target.value })} + placeholder="Short description of this group..." + className={inputClass} + /> +
+
+ +
+ + +
+ {['llm', 'tts', 'stt'].map(section => ( +
+
+ {section} Access +
+ + +
+
+
+ {(effective[section]?.providers ? Object.keys(effective[section].providers) : []).map(pId => { + const isChecked = (editingGroup.policy?.[section] || []).includes(pId); + // Resolve label for display + const baseType = pId.split('_')[0]; + const baseDef = providerLists[section].find(ld => ld.id === baseType || ld.id === pId); + const label = baseDef ? (pId.includes('_') ? `${baseDef.label} (${pId.split('_').slice(1).join('_')})` : baseDef.label) : pId; + + return ( + + ); + })} +
+
+ ))} +
+
+ +
+ + +
+
+
+ )} +
+ )} + + {/* Users Management */} + {activeAdminTab === 'users' && ( +
+
+
+

+ Active Roster + {filteredUsers.length} +

+
+
+ setUserSearch(e.target.value)} + placeholder="Search by name, email..." + className="w-full text-xs p-2.5 pl-9 bg-white dark:bg-gray-800 border border-gray-200 dark:border-gray-700 rounded-xl focus:ring-2 focus:ring-indigo-500 outline-none transition-all" + /> + +
+ +
+
+
+ + + + + + + + + + + {filteredUsers.map((u) => ( + + + + + + + ))} + +
MemberPolicy GroupActivity AuditingActions
+
+
+ {(u.username || u.email || '?')[0].toUpperCase()} +
+
+

{u.username || u.email}

+

{u.role}

+
+
+
+ + +
+
+ Join: + {new Date(u.created_at).toLocaleDateString()} +
+
+ Last: + + {u.last_login_at ? new Date(u.last_login_at).toLocaleDateString() : 'Never'} + +
+
+
+ +
+ {allUsers.length === 0 && !usersLoading && ( +
No other users found.
+ )} +
+
+
+ )} +
+
{showVoicesModal && ( diff --git a/ui/client-app/src/services/apiService.js b/ui/client-app/src/services/apiService.js index 2b7a728..619df48 100644 --- a/ui/client-app/src/services/apiService.js +++ b/ui/client-app/src/services/apiService.js @@ -452,15 +452,46 @@ }; /** + * Fetches the user profile info. + */ +export const getUserProfile = async () => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/me/profile`, { + method: "GET", + headers: { "X-User-ID": userId }, + }); + if (!response.ok) throw new Error("Failed to fetch user profile"); + return await response.json(); +}; + +/** + * Updates the user profile info. + */ +export const updateUserProfile = async (profileData) => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/me/profile`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + "X-User-ID": userId, + }, + body: JSON.stringify(profileData), + }); + if (!response.ok) throw new Error("Failed to update user profile"); + return await response.json(); +}; + +/** * Fetches available TTS voice names. */ -export const getVoices = async (apiKey = null) => { +export const getVoices = async (provider = null, apiKey = null) => { try { const userId = getUserId(); const urlParams = new URLSearchParams(); - if (apiKey) urlParams.append('api_key', apiKey); + if (provider) urlParams.append('provider', provider); + if (apiKey && apiKey !== 'null') urlParams.append('api_key', apiKey); - const url = `${API_BASE_URL}/speech/voices${apiKey ? '?' + urlParams.toString() : ''}`; + const url = `${API_BASE_URL}/speech/voices?${urlParams.toString()}`; const response = await fetch(url, { method: 'GET', @@ -489,3 +520,116 @@ } return await response.json(); }; + +/** + * [ADMIN ONLY] Fetches all registered users. + */ +export const getAdminUsers = async () => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/users`, { + method: "GET", + headers: { "X-User-ID": userId }, + }); + if (!response.ok) throw new Error("Failed to fetch admin user list"); + return await response.json(); +}; + +/** + * [ADMIN ONLY] Updates a user's role. + */ +export const updateUserRole = async (targetUserId, role) => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/users/${targetUserId}/role`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + "X-User-ID": userId, + }, + body: JSON.stringify({ role }), + }); + if (!response.ok) throw new Error("Failed to update user role"); + return await response.json(); +}; + +/** + * [ADMIN ONLY] Assigns a user to a group. + */ +export const updateUserGroup = async (targetUserId, groupId) => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/users/${targetUserId}/group`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + "X-User-ID": userId, + }, + body: JSON.stringify({ group_id: groupId }), + }); + if (!response.ok) throw new Error("Failed to update user group"); + return await response.json(); +}; + +/** + * [ADMIN ONLY] Fetches all groups. + */ +export const getAdminGroups = async () => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/groups`, { + method: "GET", + headers: { "X-User-ID": userId }, + }); + if (!response.ok) throw new Error("Failed to fetch group list"); + return await response.json(); +}; + +/** + * [ADMIN ONLY] Creates a new group. + */ +export const createAdminGroup = async (groupData) => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/groups`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-User-ID": userId, + }, + body: JSON.stringify(groupData), + }); + if (!response.ok) { + const errData = await response.json().catch(() => ({})); + throw new Error(errData.detail || "Failed to create group"); + } + return await response.json(); +}; + +/** + * [ADMIN ONLY] Updates a group. + */ +export const updateAdminGroup = async (groupId, groupData) => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/groups/${groupId}`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + "X-User-ID": userId, + }, + body: JSON.stringify(groupData), + }); + if (!response.ok) { + const errData = await response.json().catch(() => ({})); + throw new Error(errData.detail || "Failed to update group"); + } + return await response.json(); +}; + +/** + * [ADMIN ONLY] Deletes a group. + */ +export const deleteAdminGroup = async (groupId) => { + const userId = getUserId(); + const response = await fetch(`${API_BASE_URL}/users/admin/groups/${groupId}`, { + method: "DELETE", + headers: { "X-User-ID": userId }, + }); + if (!response.ok) throw new Error("Failed to delete group"); + return await response.json(); +};