from fastapi import APIRouter, HTTPException, Depends, Header, Query, Request, UploadFile, File
from fastapi.responses import RedirectResponse as redirect
from sqlalchemy.orm import Session
from app.config import settings
from app.db import models
from typing import Optional, Annotated
import logging
import os
import httpx
import jwt
import urllib.parse
# Correctly import from your application's schemas and dependencies
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.core.services.user import login_required
# Setup logging: an INFO-level root configuration, applied at import time,
# plus a logger named after this module for the route handlers below.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(name=__name__)
# --- Derived OIDC Configuration Helpers ---
def get_oidc_urls():
    """Build the OIDC provider endpoint URLs from ``settings.OIDC_SERVER_URL``.

    Trailing slashes on the configured server URL are stripped before each
    endpoint path is appended.

    Returns:
        dict with keys ``"auth"``, ``"token"`` and ``"userinfo"``, mapping to
        ``<server>/auth``, ``<server>/token`` and ``<server>/userinfo``.
    """
    base = settings.OIDC_SERVER_URL.rstrip("/")
    endpoint_names = ("auth", "token", "userinfo")
    return {name: f"{base}/{name}" for name in endpoint_names}
# FastAPI dependency that resolves the caller's user ID from a request header.
def get_current_user_id(x_user_id: Annotated[Optional[str], Header()] = None) -> Optional[str]:
    """Return the value of the ``X-User-ID`` request header, or ``None`` if absent.

    FastAPI maps the ``x_user_id`` parameter to the ``X-User-ID`` header. The
    value is passed through untouched; no verification is performed.

    NOTE(review): the header is client-controlled, so any caller can claim any
    user ID, including an admin's. This stands in for real authentication
    and must not be relied on in production.
    """
    header_value: Optional[str] = x_user_id
    return header_value
def create_users_router(services: ServiceContainer) -> APIRouter:
    """Create the ``/users`` router (tag "Users") and register its routes.

    The route handlers defined in this function close over ``services`` to
    reach the auth, user and preference services. The configured router is
    returned at the end of the function.
    """
    router = APIRouter(prefix="/users", tags=["Users"])
@router.get("/login", summary="Initiate OIDC Login Flow")
async def login_redirect(
    request: Request,
    frontend_callback_uri: Optional[str] = Query(None, description="The frontend URI to redirect back to after OIDC provider.")
):
    """
    Start the OIDC login flow by redirecting the browser to the provider.

    The optional ``frontend_callback_uri`` is handed to the auth service,
    which builds the provider authorization URL that the browser is sent to.
    """
    return redirect(url=services.auth_service.generate_login_url(frontend_callback_uri))
@router.get("/login/callback", summary="Handle OIDC Login Callback")
async def login_callback(
    request: Request,
    code: str = Query(..., description="Authorization code from OIDC provider"),
    state: str = Query(..., description="The original frontend redirect URI"),
    db: Session = Depends(get_db)
):
    """
    Handle the redirect back from the OIDC provider.

    The authorization ``code`` is exchanged through the auth service. The
    browser is then redirected to the frontend URI carried in ``state``, with
    the resolved ``user_id`` added as a URL-encoded query parameter. Any query
    string or fragment already present in ``state`` is preserved.

    NOTE(review): ``state`` is used verbatim as the redirect target. This is
    an open-redirect vector unless something upstream validates it, and OIDC
    ``state`` is normally an opaque anti-CSRF value. Consider checking it
    against an allow-list of frontend origins.
    """
    result = await services.auth_service.handle_callback(code, db)
    user_id = result["user_id"]
    # Merge into any existing query string instead of blindly appending
    # "?user_id=..."; urlsplit/urlunsplit keeps a trailing #fragment last.
    target = urllib.parse.urlsplit(state)
    extra_query = urllib.parse.urlencode({"user_id": user_id})
    query = f"{target.query}&{extra_query}" if target.query else extra_query
    frontend_redirect_url = urllib.parse.urlunsplit(target._replace(query=query))
    return redirect(url=frontend_redirect_url)
@router.get("/me", response_model=schemas.UserStatus, summary="Get Current User Status")
async def get_current_status(
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """
    Report whether the caller, identified by the ``X-User-ID`` header, is logged in.

    If no user ID is supplied, or the ID matches no stored user, the caller is
    reported as anonymous (``is_logged_in=False``, ``email=None``).

    Raises:
        HTTPException(500): if the lookup or response construction fails. The
            underlying error is logged server-side, not echoed to the client.
    """
    try:
        # Skip the database round-trip when no ID was supplied at all.
        user: Optional[models.User] = (
            services.user_service.get_user_by_id(db=db, user_id=user_id) if user_id else None
        )
        return schemas.UserStatus(
            id=user_id,
            email=user.email if user else None,
            is_logged_in=user is not None,
            is_anonymous=user is None,
        )
    except Exception as exc:
        # Don't leak internal error text (e.g. SQL or schema details) to clients.
        logger.exception("Failed to determine status for user %r", user_id)
        raise HTTPException(status_code=500, detail="An error occurred while checking the user status.") from exc
@router.get("/me/profile", response_model=schemas.UserProfile, summary="Get Current User Profile")
async def get_user_profile(
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Return the caller's profile, including their group's name when they belong to one.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    profile_owner = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not profile_owner:
        raise HTTPException(status_code=404, detail="User not found")
    profile = schemas.UserProfile.model_validate(profile_owner)
    owner_group = profile_owner.group
    if owner_group:
        profile.group_name = owner_group.name
    return profile
@router.put("/me/profile", response_model=schemas.UserProfile, summary="Update User Profile")
async def update_user_profile(
    profile_data: schemas.UserProfileUpdate,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Apply profile changes for the caller and return the updated profile.

    Only truthy fields in ``profile_data`` overwrite stored values, so a field
    cannot be cleared to an empty value through this endpoint.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    account = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not account:
        raise HTTPException(status_code=404, detail="User not found")
    for field_name in ("username", "full_name", "avatar_url"):
        new_value = getattr(profile_data, field_name)
        if new_value:
            setattr(account, field_name, new_value)
    db.add(account)
    db.commit()
    db.refresh(account)
    profile = schemas.UserProfile.model_validate(account)
    account_group = account.group
    if account_group:
        profile.group_name = account_group.name
    return profile
@router.get("/me/config", response_model=schemas.ConfigResponse, summary="Get Current User Preferences")
async def get_user_config(
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Return the caller's effective LLM/TTS/STT configuration.

    The result comes from ``preference_service.merge_user_config``.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    current_user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not current_user:
        raise HTTPException(status_code=404, detail="User not found")
    effective_config = services.preference_service.merge_user_config(current_user, db)
    return effective_config
@router.put("/me/config", response_model=schemas.UserPreferences, summary="Update Current User Preferences")
async def update_user_config(
    prefs: schemas.UserPreferences,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Store new preferences for the caller through ``preference_service.update_user_config``.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    current_user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not current_user:
        raise HTTPException(status_code=404, detail="User not found")
    updated_prefs = services.preference_service.update_user_config(current_user, prefs, db)
    return updated_prefs
@router.get("/me/config/models", response_model=list[schemas.ModelInfoResponse], summary="Get Models for Provider")
async def get_provider_models(provider_name: str, section: str = "llm"):
    """
    List the litellm models of ``provider_name`` that suit a config ``section``.

    ``section`` selects the litellm ``mode`` values that are accepted:
    ``llm`` -> chat / text-completion / custom / unspecified (None),
    ``tts`` -> audio_speech, ``stt`` -> audio_transcription,
    ``image`` -> image_generation. Any other section accepts every model.

    Each entry has ``model_name``, ``max_tokens`` and ``max_input_tokens``.
    Models that litellm cannot describe are skipped, with a debug log. The
    litellm lookups are synchronous, so they run in a worker thread. An
    unexpected failure is logged and yields an empty list.
    """
    import litellm
    from fastapi.concurrency import run_in_threadpool

    # Accepted litellm "mode" values per section; None means "accept all".
    accepted_modes = {
        "llm": ("chat", "text-completion", "custom", None),
        "tts": ("audio_speech",),
        "stt": ("audio_transcription",),
        "image": ("image_generation",),
    }.get(section)

    def describe_model(model_name):
        """Return the response entry for ``model_name``, or None if it does not fit ``section``."""
        info = litellm.get_model_info(model_name)
        if "error" in info:
            return None
        if accepted_modes is not None and info.get("mode") not in accepted_modes:
            return None
        return {
            "model_name": model_name,
            "max_tokens": info.get("max_tokens"),
            "max_input_tokens": info.get("max_input_tokens"),
        }

    def fetch_models():
        """Collect the entries for every suitable model of the provider (runs in a thread)."""
        out = []
        try:
            for model_name in litellm.models_by_provider.get(provider_name, []):
                try:
                    entry = describe_model(model_name)
                except Exception as exc:
                    # Best effort: litellm raises for models missing from its model map.
                    logger.debug("Skipping model %r of provider %r: %s", model_name, provider_name, exc)
                    continue
                if entry is not None:
                    out.append(entry)
        except Exception:
            logger.warning("Could not list models for provider %r", provider_name, exc_info=True)
            return []
        return out

    return await run_in_threadpool(fetch_models)
@router.get("/me/config/providers", response_model=list[str], summary="Get All Valid Providers per Section")
async def get_all_providers(
    section: str = "llm",
    configured_only: bool = Query(False, description="If true, only returns providers currently configured in preferences or system defaults"),
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """
    List provider names for a config ``section`` (llm / tts / stt / image).

    With ``configured_only=True``, the result is the sorted union of provider
    keys configured in the system settings and in the caller's preferences.
    If that union is empty, built-in defaults are used instead: deepseek and
    gemini for llm, ``settings.TTS_PROVIDER`` for tts, ``settings.STT_PROVIDER``
    for stt.

    Otherwise the result is ``"general"`` followed by every known provider:
    the registered TTS/STT providers plus ``"openai"`` for tts/stt, and
    litellm's ``LlmProviders`` for llm, image and any other section.
    """
    import litellm
    from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers

    if configured_only:
        # Inspect system settings + user prefs directly rather than calling
        # back into the effective-config logic.
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        system_prefs = services.user_service.get_system_settings(db) or {}
        # A stored user may have preferences=None; sections may also be null.
        user_prefs = (user.preferences if user else None) or {}

        def provider_names(prefs):
            """Return the provider keys of ``prefs[section]['providers']`` (empty if missing/null)."""
            section_prefs = prefs.get(section) or {}
            return (section_prefs.get("providers") or {}).keys()

        configured = set(provider_names(system_prefs)) | set(provider_names(user_prefs))
        if not configured:
            # Nothing configured: fall back to the built-in defaults
            # (mirrors what get_user_config would resolve to).
            if section == "llm":
                configured.update(["deepseek", "gemini"])
            elif section == "tts":
                configured.add(settings.TTS_PROVIDER)
            elif section == "stt":
                configured.add(settings.STT_PROVIDER)
        return sorted(configured)

    if section == "tts":
        return ["general"] + get_registered_tts_providers() + ["openai"]
    if section == "stt":
        return ["general"] + get_registered_stt_providers() + ["openai"]
    # llm, image and any unknown section all use litellm's provider list.
    return ["general"] + [p.value for p in litellm.LlmProviders]
@router.post("/me/config/verify_llm", response_model=schemas.VerifyProviderResponse)
async def verify_llm(
    req: schemas.VerifyProviderRequest,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """
    Check an LLM provider configuration by sending it a short prompt.

    A key is "masked" when it is empty or contains ``"***"``. Verification is
    allowed for admins, or for callers who supply their own unmasked key.

    The API key is resolved in this order, stopping at the first unmasked key
    (the same order used by the TTS/STT verification endpoints):
      1. ``req.api_key``;
      2. the caller's stored LLM key for ``req.provider_name``;
      3. the system LLM key for ``req.provider_type`` (or the provider name).

    Returns:
        ``VerifyProviderResponse``. Provider errors are reported with
        ``success=False`` and the error text, rather than raised.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
        HTTPException(403): a non-admin tried to verify with a masked key.
    """
    from app.core.providers.factory import get_llm_provider
    from fastapi.concurrency import run_in_threadpool

    def is_masked(key) -> bool:
        return not key or "***" in str(key)

    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    # Only admins may verify using stored (masked) keys; others must bring their own.
    if is_masked(req.api_key) and user.role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden: Admin only for masked keys")

    actual_key = req.api_key
    try:
        llm_prefs = (user.preferences or {}).get("llm", {}).get("providers", {}).get(req.provider_name, {})
        # Backfill masked/empty keys: caller's stored key first, then system defaults.
        if is_masked(actual_key):
            actual_key = llm_prefs.get("api_key")
        if is_masked(actual_key):
            system_prefs = services.user_service.get_system_settings(db)
            actual_key = system_prefs.get("llm", {}).get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key")

        kwargs = {k: v for k, v in llm_prefs.items() if k not in ("api_key", "model")}
        if req.provider_type:
            kwargs["provider_type"] = req.provider_type
        llm = get_llm_provider(
            provider_name=req.provider_name,
            model_name=req.model or "",
            api_key_override=actual_key,
            **kwargs
        )
        # The provider call is blocking network I/O; keep it off the event loop.
        await run_in_threadpool(llm, "Hello")
        return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
    except Exception as e:
        logger.error(f"LLM Verification failed for {req.provider_name} ({req.provider_type}): {e}")
        return schemas.VerifyProviderResponse(success=False, message=str(e))
@router.post("/me/config/verify_tts", response_model=schemas.VerifyProviderResponse)
async def verify_tts(
    req: schemas.VerifyProviderRequest,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """
    Check a TTS provider configuration by synthesizing a short test phrase.

    A key is "masked" when it is empty or contains ``"***"``; only admins may
    verify with one. A masked key is replaced, in order, by the caller's
    stored TTS key for ``req.provider_name``, then the system TTS key for
    ``req.provider_type`` (or the name). If that still leaves no key, it falls
    back to ``settings.TTS_API_KEY`` or ``settings.GEMINI_API_KEY``.

    Returns:
        ``VerifyProviderResponse``. Provider errors are reported with
        ``success=False`` rather than raised.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
        HTTPException(403): a non-admin tried to verify with a masked key.
    """
    from app.core.providers.factory import get_tts_provider

    def key_is_masked(candidate) -> bool:
        return not candidate or "***" in str(candidate)

    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    caller = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not caller:
        raise HTTPException(status_code=404, detail="User not found")
    if key_is_masked(req.api_key) and caller.role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden: Admin only for masked keys")

    resolved_key = req.api_key
    try:
        stored_prefs = caller.preferences
        if stored_prefs:
            saved_prefs = stored_prefs.get("tts", {}).get("providers", {}).get(req.provider_name, {})
        else:
            saved_prefs = {}

        # Masked keys are swapped for real ones: user prefs -> system prefs -> settings.py.
        if key_is_masked(resolved_key):
            resolved_key = saved_prefs.get("api_key")
        if key_is_masked(resolved_key):
            system_tts = services.user_service.get_system_settings(db).get("tts", {})
            resolved_key = system_tts.get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key")
        if not resolved_key:
            resolved_key = settings.TTS_API_KEY or settings.GEMINI_API_KEY

        logger.info(f"verify_tts: instance={req.provider_name}, type={req.provider_type}, model={req.model}")

        provider_kwargs = {
            key: value
            for key, value in saved_prefs.items()
            if key not in ["api_key", "model", "voice"]
        }
        if req.provider_type:
            provider_kwargs["provider_type"] = req.provider_type
        tts = get_tts_provider(
            provider_name=req.provider_name,
            api_key=resolved_key,
            model_name=req.model or "",
            voice_name=req.voice or "",
            **provider_kwargs
        )
        await tts.generate_speech("Hello there. We are testing this thing.")
        return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
    except Exception as e:
        logger.error(f"TTS verification failed for {req.provider_name}: {e}")
        return schemas.VerifyProviderResponse(success=False, message=str(e))
@router.post("/me/config/verify_stt", response_model=schemas.VerifyProviderResponse)
async def verify_stt(
    req: schemas.VerifyProviderRequest,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """
    Check that an STT provider can be constructed with the resolved credentials.

    Only provider construction is exercised; no audio is transcribed. A key is
    "masked" when it is empty or contains ``"***"``; only admins may verify
    with one. A masked key is replaced, in order, by the caller's stored STT
    key for ``req.provider_name``, then the system STT key for
    ``req.provider_type`` (or the name). If that still leaves no key, it falls
    back to ``settings.STT_API_KEY`` or ``settings.GEMINI_API_KEY``.

    Returns:
        ``VerifyProviderResponse``. Provider errors are reported with
        ``success=False`` rather than raised.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
        HTTPException(403): a non-admin tried to verify with a masked key.
    """
    from app.core.providers.factory import get_stt_provider

    def looks_masked(candidate) -> bool:
        return not candidate or "***" in str(candidate)

    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    account = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not account:
        raise HTTPException(status_code=404, detail="User not found")
    if looks_masked(req.api_key) and account.role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden: Admin only for masked keys")

    api_key = req.api_key
    try:
        stored = account.preferences
        section_prefs = stored.get("stt", {}).get("providers", {}).get(req.provider_name, {}) if stored else {}

        if looks_masked(api_key):
            api_key = section_prefs.get("api_key")
        if looks_masked(api_key):
            system_stt = services.user_service.get_system_settings(db).get("stt", {})
            api_key = system_stt.get("providers", {}).get(req.provider_type or req.provider_name, {}).get("api_key")
        if not api_key:
            api_key = settings.STT_API_KEY or settings.GEMINI_API_KEY

        extra_kwargs = {
            name: val
            for name, val in section_prefs.items()
            if name not in ["api_key", "model"]
        }
        if req.provider_type:
            extra_kwargs["provider_type"] = req.provider_type
        # Constructing the provider is the whole check: SDK-based providers
        # usually reject invalid credentials at init time.
        stt = get_stt_provider(
            provider_name=req.provider_name,
            api_key=api_key,
            model_name=req.model or "",
            **extra_kwargs
        )
        return schemas.VerifyProviderResponse(success=True, message="Provider initialized. Full transcription test requires audio payload.")
    except Exception as e:
        logger.error(f"STT verification failed for {req.provider_name}: {e}")
        return schemas.VerifyProviderResponse(success=False, message=str(e))
@router.post("/logout", summary="Log Out the Current User")
async def logout():
    """
    Acknowledge a logout request.

    No server-side state is cleared here. A real deployment would invalidate
    the session token or cookie at this point.
    """
    response_body = {"message": "Logged out successfully"}
    return response_body
@router.get("/me/config/export", summary="Export Configurations to YAML")
async def export_user_config_yaml(
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Exports the effective user configuration as a YAML file (Admin only).

    The document has three top-level keys:
      * ``llm_providers.providers``: the caller's LLM provider map, or
        ``settings.LLM_PROVIDERS`` when the caller has none;
      * ``tts_provider`` / ``stt_provider``: the caller's active provider for
        that section, or the system settings when no valid active provider is
        configured.
    Dict entries whose value is None are dropped. API keys are exported in
    plain text.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is unknown or not an admin.
    """
    from fastapi.responses import PlainTextResponse
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not user or user.role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")
    import yaml

    prefs_dict = user.preferences or {}
    # Sections (and their provider maps) may be absent or stored as null.
    llm_prefs = prefs_dict.get("llm") or {}
    tts_prefs = prefs_dict.get("tts") or {}
    stt_prefs = prefs_dict.get("stt") or {}
    user_providers = llm_prefs.get("providers") or {}

    def get_provider_export(section_prefs, fallback_provider, fallback_model, fallback_api_key, fallback_voice=None):
        """Return the export entry for one voice section (TTS or STT)."""
        active_p = section_prefs.get("active_provider")
        providers = section_prefs.get("providers") or {}
        if active_p and active_p in providers:
            p_data = providers[active_p] or {}
            return {
                "provider": active_p,
                "model_name": p_data.get("model"),
                "voice_name": p_data.get("voice"),
                "api_key": p_data.get("api_key")
            }
        # No usable active provider: fall back to system settings.
        return {
            "provider": fallback_provider,
            "model_name": fallback_model,
            "voice_name": fallback_voice,
            "api_key": fallback_api_key
        }

    # Layer 2 (Day 2) export: only the LLM, TTS and STT sections.
    yaml_data = {
        "llm_providers": {
            "providers": user_providers or settings.LLM_PROVIDERS
        },
        "tts_provider": get_provider_export(tts_prefs, settings.TTS_PROVIDER, settings.TTS_MODEL_NAME, settings.TTS_API_KEY, settings.TTS_VOICE_NAME),
        "stt_provider": get_provider_export(stt_prefs, settings.STT_PROVIDER, settings.STT_MODEL_NAME, settings.STT_API_KEY)
    }

    def remove_none(obj):
        """Recursively drop dict entries whose value is None."""
        if isinstance(obj, dict):
            return {k: remove_none(v) for k, v in obj.items() if v is not None}
        return obj

    yaml_str = yaml.dump(remove_none(yaml_data), sort_keys=False, default_flow_style=False)
    return PlainTextResponse(
        content=yaml_str,
        media_type="application/x-yaml",
        headers={"Content-Disposition": "attachment; filename=\"day2_config.yaml\""}
    )
@router.post("/me/config/import", response_model=schemas.UserPreferences, summary="Import Configurations from YAML")
async def import_user_config_yaml(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Imports user configuration from a YAML file.

    Input format:
      * ``llm_providers``: either structured (``providers: {name: {...}}``) or
        flattened (``<name>_api_key`` / ``<name>_model_name`` keys). The first
        provider listed becomes the active one.
      * ``tts_provider`` / ``stt_provider``: one active provider each. A
        missing provider name defaults to ``google_gemini``.

    Effects:
      * The caller's stored LLM/TTS/STT preferences are replaced, and
        ``statuses`` is reset.
      * Only if the caller is an admin, the imported values are also applied
        to the process-wide ``settings`` and persisted with
        ``settings.save_to_yaml()``. A failure of that write is logged, not
        raised.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(404): the user ID matches no stored user.
        HTTPException(400): the file is not valid YAML or not shaped as above.
    """
    from sqlalchemy.orm.attributes import flag_modified
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    import yaml
    content = await file.read()
    try:
        # safe_load, never yaml.load: the uploaded file is untrusted input.
        data = yaml.safe_load(content)
    except yaml.YAMLError as e:
        raise HTTPException(status_code=400, detail=f"Invalid YAML file: {e}")
    if not isinstance(data, dict):
        # Empty files (None) and top-level lists/scalars would otherwise crash
        # below, or silently wipe the stored preferences.
        raise HTTPException(status_code=400, detail="Invalid YAML file: expected a mapping at the top level.")

    def section(key):
        """Return ``data[key]`` as a dict: absent/null -> {}, any other non-dict -> HTTP 400."""
        value = data.get(key)
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise HTTPException(status_code=400, detail=f"Invalid YAML file: '{key}' must be a mapping.")
        return value

    # Reverse mapping: YAML -> UserPreferences structure
    new_llm = {"providers": {}, "active_provider": None}
    new_tts = {"providers": {}, "active_provider": None}
    new_stt = {"providers": {}, "active_provider": None}

    # --- LLM ---
    llm_data = section("llm_providers")
    if "providers" in llm_data:  # Structured (as exported)
        providers = llm_data["providers"] or {}
        if not isinstance(providers, dict):
            raise HTTPException(status_code=400, detail="Invalid YAML file: 'llm_providers.providers' must be a mapping.")
        new_llm["providers"] = providers
    else:  # Flattened: <name>_api_key / <name>_model_name
        for k, v in llm_data.items():
            if not isinstance(k, str):
                continue
            if k.endswith("_api_key"):
                new_llm["providers"].setdefault(k.removesuffix("_api_key"), {})["api_key"] = v
            elif k.endswith("_model_name"):
                new_llm["providers"].setdefault(k.removesuffix("_model_name"), {})["model"] = v
    if new_llm["providers"]:
        new_llm["active_provider"] = next(iter(new_llm["providers"]))

    # --- TTS ---
    tts_data = section("tts_provider")
    if tts_data:
        p = tts_data.get("provider") or "google_gemini"
        new_tts["active_provider"] = p
        new_tts["providers"][p] = {
            "api_key": tts_data.get("api_key"),
            "model": tts_data.get("model_name"),
            "voice": tts_data.get("voice_name")
        }

    # --- STT ---
    stt_data = section("stt_provider")
    if stt_data:
        p = stt_data.get("provider") or "google_gemini"
        new_stt["active_provider"] = p
        new_stt["providers"][p] = {
            "api_key": stt_data.get("api_key"),
            "model": stt_data.get("model_name")
        }

    user.preferences = {
        "llm": new_llm,
        "tts": new_tts,
        "stt": new_stt,
        "statuses": {}
    }
    # Force SQLAlchemy to treat the preferences column as changed.
    flag_modified(user, "preferences")
    is_admin = user.role == "admin"
    db.add(user)
    db.commit()
    db.refresh(user)

    # --- Day 2 Sync ---
    # Process-wide settings are shared by every user, so only an admin may
    # overwrite them (including system API keys). This runs after the commit,
    # so a failed commit leaves the global settings untouched.
    if is_admin:
        if new_llm["providers"]:
            settings.LLM_PROVIDERS.update(new_llm["providers"])
        if new_tts["active_provider"]:
            tts_entry = new_tts["providers"][new_tts["active_provider"]]
            settings.TTS_PROVIDER = new_tts["active_provider"]
            settings.TTS_MODEL_NAME = tts_entry.get("model") or settings.TTS_MODEL_NAME
            settings.TTS_VOICE_NAME = tts_entry.get("voice") or settings.TTS_VOICE_NAME
            settings.TTS_API_KEY = tts_entry.get("api_key") or settings.TTS_API_KEY
        if new_stt["active_provider"]:
            stt_entry = new_stt["providers"][new_stt["active_provider"]]
            settings.STT_PROVIDER = new_stt["active_provider"]
            settings.STT_MODEL_NAME = stt_entry.get("model") or settings.STT_MODEL_NAME
            settings.STT_API_KEY = stt_entry.get("api_key") or settings.STT_API_KEY
        try:
            settings.save_to_yaml()
        except Exception as ey:
            logger.error(f"Failed to sync settings to YAML on import: {ey}")

    return schemas.UserPreferences(llm=user.preferences.get("llm", {}), tts=user.preferences.get("tts", {}), stt=user.preferences.get("stt", {}))
# --- Admin routes: every handler below requires the caller's role to be "admin" ---
@router.get("/admin/users", response_model=list[schemas.UserProfile], summary="List All Users (Admin Only)")
async def admin_list_users(
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Return the profile of every registered user, with group names filled in.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is unknown or not an admin.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    caller = services.user_service.get_user_by_id(db=db, user_id=user_id)
    caller_is_admin = bool(caller) and caller.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")

    def as_profile(account):
        profile = schemas.UserProfile.model_validate(account)
        if account.group:
            profile.group_name = account.group.name
        return profile

    return [as_profile(account) for account in services.user_service.get_all_users(db)]
@router.put("/admin/users/{uid}/role", response_model=schemas.UserProfile, summary="Update User Role (Admin Only)")
async def admin_update_role(
    uid: str,
    role_req: schemas.UserRoleUpdate,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Updates a user's role. Prevents demoting the last administrator.

    Returns the updated user's profile, including the group name, like the
    other profile endpoints.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is unknown or not an admin.
        HTTPException(400): the user service refused the change.
        HTTPException(404): the target user could not be reloaded.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not user or user.role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")
    success = services.user_service.update_user_role(db, uid, role_req.role)
    if not success:
        raise HTTPException(status_code=400, detail="Failed to update role. Maybe this is the last admin?")
    target = services.user_service.get_user_by_id(db=db, user_id=uid)
    if not target:
        raise HTTPException(status_code=404, detail="User not found")
    # Build the response explicitly so group_name is populated, as elsewhere.
    profile = schemas.UserProfile.model_validate(target)
    if target.group:
        profile.group_name = target.group.name
    return profile
@router.put("/admin/users/{uid}/group", response_model=schemas.UserProfile, summary="Update User Group (Admin Only)")
async def admin_update_user_group(
    uid: str,
    group_req: schemas.UserGroupUpdate,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Assigns a user to a group.

    Returns the updated user's profile, with ``group_name`` set to the newly
    assigned group's name when the user has a group.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is unknown or not an admin.
        HTTPException(404): the user or group does not exist.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    user = services.user_service.get_user_by_id(db=db, user_id=user_id)
    if not user or user.role != "admin":
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")
    success = services.user_service.assign_user_to_group(db, uid, group_req.group_id)
    if not success:
        raise HTTPException(status_code=404, detail="User or group not found")
    target = services.user_service.get_user_by_id(db=db, user_id=uid)
    if not target:
        raise HTTPException(status_code=404, detail="User or group not found")
    # Build the response explicitly so group_name reflects the new assignment.
    profile = schemas.UserProfile.model_validate(target)
    if target.group:
        profile.group_name = target.group.name
    return profile
@router.get("/admin/groups", response_model=list[schemas.GroupInfo], summary="List All Groups (Admin Only)")
async def admin_list_groups(
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Return every existing group.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is unknown or not an admin.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    caller = services.user_service.get_user_by_id(db=db, user_id=user_id)
    caller_is_admin = bool(caller) and caller.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")
    # Convert to Pydantic models while the session is still open, so that no
    # lazy loading is attempted after the handler returns.
    all_groups = services.user_service.get_all_groups(db)
    group_infos = [schemas.GroupInfo.model_validate(group) for group in all_groups]
    return group_infos
@router.post("/admin/groups", response_model=schemas.GroupInfo, summary="Create Group (Admin Only)")
async def admin_create_group(
    group_req: schemas.GroupCreate,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Create a group from the request's name, description and policy.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is unknown or not an admin.
        HTTPException(409): the user service reports the name as taken
            (``create_group`` returned None).
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    caller = services.user_service.get_user_by_id(db=db, user_id=user_id)
    caller_is_admin = bool(caller) and caller.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")
    created = services.user_service.create_group(db, group_req.name, group_req.description, group_req.policy)
    if created is not None:
        return schemas.GroupInfo.model_validate(created)
    raise HTTPException(status_code=409, detail=f"A group named '{group_req.name}' already exists. Please choose a unique name.")
@router.put("/admin/groups/{gid}", response_model=schemas.GroupInfo, summary="Update Group (Admin Only)")
async def admin_update_group(
    gid: str,
    group_req: schemas.GroupUpdate,
    db: Session = Depends(get_db),
    user_id: str = Depends(get_current_user_id)
):
    """Update a group's name, description or policy.

    The built-in ``ungrouped`` group keeps its name; only its other fields may
    change.

    Raises:
        HTTPException(401): no user ID was supplied.
        HTTPException(403): the caller is not an admin, or the request would
            rename the ``ungrouped`` group.
        HTTPException(404): ``update_group`` returned None (group not found).
        HTTPException(409): ``update_group`` returned False (name taken).
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="Unauthorized")
    caller = services.user_service.get_user_by_id(db=db, user_id=user_id)
    caller_is_admin = bool(caller) and caller.role == "admin"
    if not caller_is_admin:
        raise HTTPException(status_code=403, detail="Forbidden: Admin only")

    renames_default_group = (
        gid == "ungrouped"
        and group_req.name
        and group_req.name.strip().lower() != "ungrouped"
    )
    if renames_default_group:
        raise HTTPException(status_code=403, detail="The default 'Ungrouped' group cannot be renamed.")

    outcome = services.user_service.update_group(db, gid, group_req.name, group_req.description, group_req.policy)
    # The service distinguishes "missing" (None) from "name clash" (False).
    if outcome is None:
        raise HTTPException(status_code=404, detail="Group not found")
    if outcome is False:
        raise HTTPException(status_code=409, detail=f"A group named '{group_req.name}' already exists. Please choose a unique name.")
    return schemas.GroupInfo.model_validate(outcome)
@router.delete("/admin/groups/{gid}", summary="Delete Group (Admin Only)")
async def admin_delete_group(
gid: str,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id)
):
"""Deletes a group. Users are moved back to 'ungrouped'."""
if not user_id:
raise HTTPException(status_code=401, detail="Unauthorized")
user = services.user_service.get_user_by_id(db=db, user_id=user_id)
if not user or user.role != "admin":
raise HTTPException(status_code=403, detail="Forbidden: Admin only")
# Cannot delete the system default group
if gid == "ungrouped":
raise HTTPException(status_code=403, detail="The default 'Ungrouped' group cannot be deleted.")
success = services.user_service.delete_group(db, gid)
if not success:
raise HTTPException(status_code=400, detail="Failed to delete group.")
return {"message": "Group deleted successfully"}
return router