Newer
Older
cortex-hub / ai-hub / app / api / routes / user.py
from fastapi import APIRouter, HTTPException, Depends, Header, Query, Request, UploadFile, File
from fastapi.responses import RedirectResponse as redirect
from sqlalchemy.orm import Session
from app.db import models
from typing import Optional, Annotated
import logging
import os
import httpx
import jwt
import urllib.parse

# Correctly import from your application's schemas and dependencies
from app.api.dependencies import ServiceContainer, get_db
from app.api import schemas
from app.core.services.user import login_required

# Process-wide logging configuration plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- OIDC client configuration ---
# Read once from application settings when this module is imported.
from app.config import settings
OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_SERVER_URL, OIDC_REDIRECT_URI = (
    settings.OIDC_CLIENT_ID,
    settings.OIDC_CLIENT_SECRET,
    settings.OIDC_SERVER_URL,
    settings.OIDC_REDIRECT_URI,
)

# --- Endpoint URLs derived from the server base URL ---
# Path segments are appended verbatim, so a trailing slash on OIDC_SERVER_URL
# would yield "//" in these URLs. NOTE(review): the fixed paths match Dex-style
# providers; other IdPs may publish different paths in their discovery document.
OIDC_AUTHORIZATION_URL, OIDC_TOKEN_URL, OIDC_USERINFO_URL = (
    f"{OIDC_SERVER_URL}/{endpoint}" for endpoint in ("auth", "token", "userinfo")
)

# FastAPI dependency that stands in for real authentication: the caller's
# identity is taken verbatim from a request header.
def get_current_user_id(x_user_id: Annotated[Optional[str], Header()] = None) -> Optional[str]:
    """
    Return the value of the ``X-User-ID`` request header, or ``None`` when absent.

    NOTE(review): nothing here authenticates or signs the header, so any client
    that can reach the API can claim any user id. Confirm that a trusted gateway
    sets/strips this header before requests reach this service.
    """
    header_value = x_user_id
    return header_value


def create_users_router(services: ServiceContainer) -> APIRouter:
    router = APIRouter(prefix="/users", tags=["Users"])

    @router.get("/login", summary="Initiate OIDC Login Flow")
    async def login_redirect(
        request: Request,
        # Allow the frontend to provide its callback URL
        frontend_callback_uri: Optional[str] = Query(None, description="The frontend URI to redirect back to after OIDC provider.")
    ):
        """
        Start the OIDC authorization-code flow by redirecting the browser to the
        provider's authorization endpoint.

        ``frontend_callback_uri`` (empty string when omitted) is forwarded verbatim
        as the OAuth ``state`` value; ``/login/callback`` later redirects the
        browser back to it.

        NOTE(review): ``state`` is meant to be an unguessable per-request value
        that protects against CSRF. Carrying the raw frontend URI provides no such
        protection and lets the callback act as an open redirect. A sturdier design
        stores the URI server-side, keyed by a random state token.
        """
        query_string = urllib.parse.urlencode({
            "response_type": "code",
            "scope": "openid profile email",
            "client_id": OIDC_CLIENT_ID,
            "redirect_uri": OIDC_REDIRECT_URI,
            "state": frontend_callback_uri if frontend_callback_uri else "",
        })
        auth_url = OIDC_AUTHORIZATION_URL + "?" + query_string
        logger.info(f"Redirecting to OIDC authorization URL: {auth_url}")
        return redirect(url=auth_url)
    
    @router.get("/login/callback", summary="Handle OIDC Login Callback")
    async def login_callback(
        request: Request,
        code: str = Query(..., description="Authorization code from OIDC provider"),
        state: str = Query(..., description="The original frontend redirect URI"),
        db: Session = Depends(get_db)
    ):
        """
        Complete the OIDC authorization-code flow.

        Exchanges ``code`` for tokens at the provider's token endpoint and reads
        the identity claims (``sub``, ``email``, name) from the returned ID token.
        It then upserts the user via ``user_service.save_user`` and redirects the
        browser to ``state`` (the frontend URI given to ``/login``) with
        ``user_id`` appended as a query parameter.

        Raises:
            HTTPException(400): no ID token in the token response, the ID token
                cannot be decoded, or it lacks the ``sub``/``email`` claims.
            HTTPException(500): the token exchange fails or any other error occurs.
        """
        logger.info(f"Received callback with authorization code: {code[:10]}... and state: {state}")

        try:
            logger.info(f"Exchanging code for tokens at: {OIDC_TOKEN_URL}")
            # Step 1: Exchange the authorization code for an access token and an ID token
            token_data = {
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": OIDC_REDIRECT_URI,
                "client_id": OIDC_CLIENT_ID,
                "client_secret": OIDC_CLIENT_SECRET,
            }

            async with httpx.AsyncClient() as client:
                logger.debug(f"Sending POST to {OIDC_TOKEN_URL} with data keys: {list(token_data.keys())}")
                token_response = await client.post(OIDC_TOKEN_URL, data=token_data, timeout=30.0)
                token_response.raise_for_status()
                response_json = token_response.json()

            logger.info("Successfully received tokens from OIDC provider.")
            id_token = response_json.get("id_token")

            if not id_token:
                logger.error("Error: ID token not found in the response.")
                raise HTTPException(status_code=400, detail="Failed to get ID token from OIDC provider.")

            # Step 2: Read the identity claims from the ID token.
            # NOTE(review): the signature is not verified. OIDC Core 3.1.3.7 allows TLS
            # server validation to stand in for it when the token comes directly from
            # the token endpoint (as here), but `aud`/`iss` are not checked either —
            # consider validating them against OIDC_CLIENT_ID and the issuer URL.
            logger.info("Decoding ID token...")
            decoded_id_token = jwt.decode(id_token, options={"verify_signature": False})
            oidc_id = decoded_id_token.get("sub")
            email = decoded_id_token.get("email")
            # Dex and others often use 'name' for the full name, or 'preferred_username'
            username = decoded_id_token.get("name") or decoded_id_token.get("preferred_username") or email

            logger.info(f"User decoded: email={email}, oidc_id={oidc_id}")

            if not all([oidc_id, email]):
                logger.error(f"Error: Essential user data missing. oidc_id={oidc_id}, email={email}")
                raise HTTPException(status_code=400, detail="Essential user data missing from ID token (sub and email required).")

            # Step 3: Save the user and get their unique ID
            logger.info("Saving user to database...")
            user_id = services.user_service.save_user(
                db=db,
                oidc_id=oidc_id,
                email=email,
                username=username
            )
            logger.info(f"User saved/updated successfully with internal ID: {user_id}")

            # Step 4: Redirect back to the frontend. Use '&' when the URI already
            # carries a query string, so a second '?' does not corrupt it, and
            # percent-encode the id.
            separator = "&" if "?" in state else "?"
            encoded_user_id = urllib.parse.quote(str(user_id), safe="")
            frontend_redirect_url = f"{state}{separator}user_id={encoded_user_id}"
            logger.info(f"Redirecting back to frontend: {frontend_redirect_url}")

            return redirect(url=frontend_redirect_url)

        except HTTPException:
            # Deliberate 4xx responses raised above must reach the client unchanged
            # instead of being rewrapped as a 500 by the catch-all handler below.
            raise
        except httpx.HTTPStatusError as e:
            logger.error(f"OIDC Token exchange failed with status {e.response.status_code}: {e.response.text}")
            raise HTTPException(status_code=500, detail=f"OIDC Token exchange failed: {e.response.text}") from e
        except httpx.RequestError as e:
            logger.error(f"OIDC Token exchange request error: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}") from e
        except jwt.DecodeError as e:
            logger.error(f"ID token decode error: {e}")
            raise HTTPException(status_code=400, detail="Failed to decode ID token from OIDC provider.") from e
        except Exception as e:
            logger.exception(f"An unexpected error occurred during OIDC callback: {e}")
            raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    @router.get("/me", response_model=schemas.UserStatus, summary="Get Current User Status")
    async def get_current_status(
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Report whether the caller is a known, logged-in user.

        ``user_id`` comes from the ``X-User-ID`` header and is looked up through
        ``user_service.get_user_by_id``. When no header is sent or no matching
        user exists, the caller is reported as anonymous. The supplied ``user_id``
        is echoed back in every case.

        Raises:
            HTTPException(500): the lookup or response construction fails.
        """
        try:
            # Without an id there is nothing to look up; skip the DB round-trip.
            user: Optional[models.User] = (
                services.user_service.get_user_by_id(db=db, user_id=user_id) if user_id else None
            )
            return schemas.UserStatus(
                id=user_id,
                email=user.email if user else None,
                is_logged_in=user is not None,
                is_anonymous=user is None,
            )
        except Exception as e:
            logger.exception(f"Failed to resolve user status for {user_id}")
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    @router.get("/me/profile", response_model=schemas.UserProfile, summary="Get Current User Profile")
    async def get_user_profile(
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Return the caller's profile, including the name of their group if they have one.

        Raises:
            HTTPException(401): no ``X-User-ID`` header was supplied.
            HTTPException(404): no user exists with that id.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")

        current_user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not current_user:
            raise HTTPException(status_code=404, detail="User not found")

        profile = schemas.UserProfile.model_validate(current_user)
        # group_name is filled in from the related group object.
        if current_user.group:
            profile.group_name = current_user.group.name
        return profile

    @router.put("/me/profile", response_model=schemas.UserProfile, summary="Update User Profile")
    async def update_user_profile(
        profile_data: schemas.UserProfileUpdate,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Apply the supplied profile fields to the caller and return the saved profile.

        Only truthy values are applied: an omitted, ``None`` or empty field leaves
        the stored value unchanged.

        Raises:
            HTTPException(401): no ``X-User-ID`` header was supplied.
            HTTPException(404): no user exists with that id.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")

        current_user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not current_user:
            raise HTTPException(status_code=404, detail="User not found")

        for field_name in ("username", "full_name", "avatar_url"):
            new_value = getattr(profile_data, field_name)
            if new_value:
                setattr(current_user, field_name, new_value)

        db.add(current_user)
        db.commit()
        db.refresh(current_user)

        profile = schemas.UserProfile.model_validate(current_user)
        if current_user.group:
            profile.group_name = current_user.group.name
        return profile

    @router.get("/me/config", response_model=schemas.ConfigResponse, summary="Get Current User Preferences")
    async def get_user_config(
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Return the caller's stored preferences plus the effective provider config.

        For each section (``llm``, ``tts``, ``stt``) the provider map is chosen in
        this order:

        1. the user's own preferences;
        2. ``user_service.get_system_settings``;
        3. hard-coded ``settings`` defaults.

        API keys are masked throughout the response. For non-admin users, both the
        preferences and the effective config are then filtered by the policy of
        the user's group (or the default group). A section whose policy entry is
        missing or empty exposes no providers.

        Raises:
            HTTPException(401): no ``X-User-ID`` header was supplied.
            HTTPException(404): no user exists with that id.
        """
        import copy
        from app.config import settings

        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        prefs_dict = user.preferences or {}

        # A section may be stored as None, so coerce falsy values to {} instead of
        # relying on a .get() default (which only covers a missing key).
        llm_prefs = prefs_dict.get("llm") or {}
        tts_prefs = prefs_dict.get("tts") or {}
        stt_prefs = prefs_dict.get("stt") or {}

        # System-wide defaults stored in the DB.
        system_prefs = services.user_service.get_system_settings(db)

        def mask_key(k):
            """Keep only the first and last four characters of a secret visible."""
            if not k:
                return None
            if len(k) <= 8:
                return "****"
            return k[:4] + "*" * (len(k) - 8) + k[-4:]

        def resolve_providers(section, section_prefs, default_factory):
            """Provider map for `section`: user prefs, else DB system settings, else settings defaults."""
            providers = section_prefs.get("providers", {})
            if providers:
                return providers
            system_providers = system_prefs.get(section, {}).get("providers", {})
            if system_providers:
                return system_providers
            # Built lazily so settings attributes are only read when actually needed.
            return default_factory()

        def mask_providers(providers, fields):
            """Project each non-empty provider config onto `fields`, masking its api_key."""
            return {
                name: {
                    field: mask_key(cfg.get(field)) if field == "api_key" else cfg.get(field)
                    for field in fields
                }
                for name, cfg in providers.items()
                if cfg
            }

        llm_providers_effective = mask_providers(
            resolve_providers("llm", llm_prefs, lambda: {
                "deepseek": {"api_key": settings.DEEPSEEK_API_KEY, "model": settings.DEEPSEEK_MODEL_NAME},
                "gemini": {"api_key": settings.GEMINI_API_KEY, "model": settings.GEMINI_MODEL_NAME},
            }),
            ("api_key", "model"),
        )
        tts_providers_effective = mask_providers(
            resolve_providers("tts", tts_prefs, lambda: {
                settings.TTS_PROVIDER: {
                    "api_key": settings.TTS_API_KEY,
                    "model": settings.TTS_MODEL_NAME,
                    "voice": settings.TTS_VOICE_NAME,
                },
            }),
            ("api_key", "model", "voice"),
        )
        stt_providers_effective = mask_providers(
            resolve_providers("stt", stt_prefs, lambda: {
                settings.STT_PROVIDER: {
                    "api_key": settings.STT_API_KEY,
                    "model": settings.STT_MODEL_NAME,
                },
            }),
            ("api_key", "model"),
        )

        effective = {
            "llm": {
                "active_provider": llm_prefs.get("active_provider") or next(iter(llm_providers_effective), None) or "deepseek",
                "providers": llm_providers_effective,
            },
            "tts": {
                "active_provider": tts_prefs.get("active_provider") or next(iter(tts_providers_effective), None) or settings.TTS_PROVIDER,
                "providers": tts_providers_effective,
            },
            "stt": {
                "active_provider": stt_prefs.get("active_provider") or next(iter(stt_providers_effective), None) or settings.STT_PROVIDER,
                "providers": stt_providers_effective,
            },
        }

        # --- Group Policy Enforcement ---
        # Only enforce for non-admin users. Admins should see all configured providers.
        # If user has no group, they fall under the 'ungrouped' default group policy.
        group = user.group or services.user_service.get_or_create_default_group(db)
        if group and user.role != "admin":
            policy = group.policy or {}

            def apply_policy(section_key, section_prefs):
                """
                Restrict `effective[section_key]` and `section_prefs` to the providers
                listed in the group policy; return the (possibly new) prefs dict.

                A new dict is returned instead of mutating `section_prefs`, which may be
                the dict held by the ORM object's preferences column.
                """
                allowed = policy.get(section_key, [])
                section_eff = effective[section_key]
                has_providers = bool(section_prefs) and "providers" in section_prefs

                if not allowed:
                    # Explicit empty list or missing key results in NO providers.
                    section_eff["providers"] = {}
                    section_eff["active_provider"] = ""
                    return {**section_prefs, "providers": {}} if has_providers else section_prefs

                filtered_eff = {k: v for k, v in section_eff["providers"].items() if k in allowed}
                section_eff["providers"] = filtered_eff

                # Filter the user preferences too to avoid showing forbidden items.
                if has_providers:
                    section_prefs = {
                        **section_prefs,
                        "providers": {k: v for k, v in section_prefs["providers"].items() if k in allowed},
                    }

                # Ensure the active provider is still valid under the policy.
                if section_eff.get("active_provider") not in allowed:
                    section_eff["active_provider"] = next(iter(filtered_eff), None) or ""

                return section_prefs

            llm_prefs = apply_policy("llm", llm_prefs)
            tts_prefs = apply_policy("tts", tts_prefs)
            stt_prefs = apply_policy("stt", stt_prefs)

        def mask_section_prefs(section_dict):
            """Deep-copy a preferences section and mask every provider api_key in the copy."""
            if not section_dict:
                return {}
            masked_dict = copy.deepcopy(section_dict)
            for p_data in (masked_dict.get("providers") or {}).values():
                if p_data and p_data.get("api_key"):
                    p_data["api_key"] = mask_key(p_data["api_key"])
            return masked_dict

        return schemas.ConfigResponse(
            preferences=schemas.UserPreferences(
                llm=mask_section_prefs(llm_prefs),
                tts=mask_section_prefs(tts_prefs),
                stt=mask_section_prefs(stt_prefs),
                # prefs_dict is never None, unlike user.preferences.
                statuses=prefs_dict.get("statuses", {})
            ),
            effective=effective
        )

    @router.put("/me/config", response_model=schemas.UserPreferences, summary="Update Current User Preferences")
    async def update_user_config(
        prefs: schemas.UserPreferences,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Persist the caller's preferences (``llm``/``tts``/``stt`` sections and ``statuses``).

        Before saving, two kinds of placeholder key are resolved:

        * Masked API keys (containing ``***``) echoed back by the UI are replaced
          with the real key from the user's stored preferences or, failing that,
          from the system settings.
        * Providers submitted with ``_clone_from=<source>`` inherit that source's
          real key, and the marker is removed.

        For admins, the submitted providers are also synced into the global
        ``settings`` object and written out with ``settings.save_to_yaml()``.

        Raises:
            HTTPException(401): no ``X-User-ID`` header was supplied.
            HTTPException(404): no user exists with that id.
        """
        from sqlalchemy.orm.attributes import flag_modified

        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        # Currently stored preferences: the source of real keys for masked values.
        old_prefs = user.preferences or {}
        # Loaded once and shared by both key-resolution passes below.
        system_prefs = services.user_service.get_system_settings(db)
        section_names = ("llm", "tts", "stt")

        def stored_providers(source, section_name):
            """Provider map of a section, treating missing or None levels as empty."""
            return ((source or {}).get(section_name) or {}).get("providers") or {}

        def preserve_masked_keys(section_name, new_section):
            """
            Replace masked api_key values in `new_section` with real stored keys.

            Lookup order: the user's stored preferences, then the system settings
            (the UI may show masked system keys when the user has none of their own).
            """
            if not new_section or "providers" not in new_section:
                return
            old_section = stored_providers(old_prefs, section_name)
            system_section = stored_providers(system_prefs, section_name)
            for p_name, p_data in new_section["providers"].items():
                if not isinstance(p_data, dict):
                    continue
                api_key = p_data.get("api_key")
                if not api_key or "***" not in str(api_key):
                    continue
                if p_name in old_section:
                    p_data["api_key"] = (old_section[p_name] or {}).get("api_key")
                elif (system_section.get(p_name) or {}).get("api_key"):
                    p_data["api_key"] = system_section[p_name]["api_key"]
                elif "_clone_from" not in p_data:
                    # Cloned providers are resolved by resolve_clone_from instead.
                    logger.warning(
                        f"Masked api_key for {p_name} [{section_name}] has no stored counterpart; saving it unchanged."
                    )

        def resolve_clone_from(section_name, new_section):
            """
            If a new provider instance was created with _clone_from=<source_id>,
            copy the real API key from the source provider stored in the DB.
            The _clone_from marker is then removed so it is not persisted.
            """
            if not new_section or "providers" not in new_section:
                return
            # Look in DB-stored prefs first, then fall back to system settings
            old_section = stored_providers(old_prefs, section_name)
            system_section = stored_providers(system_prefs, section_name)

            for p_name, p_data in new_section["providers"].items():
                if not isinstance(p_data, dict):
                    continue
                clone_source = p_data.pop("_clone_from", None)
                if not clone_source:
                    continue
                # Resolve real key: DB prefs > system settings
                real_key = (
                    (old_section.get(clone_source) or {}).get("api_key")
                    or (system_section.get(clone_source) or {}).get("api_key")
                )
                if real_key and "***" not in str(real_key):
                    p_data["api_key"] = real_key
                    logger.info(
                        f"Resolved _clone_from: {p_name} inherited api_key from {clone_source} [{section_name}]"
                    )
                else:
                    logger.warning(
                        f"Could not resolve _clone_from for {p_name}: source '{clone_source}' key not found or masked."
                    )

        for name in section_names:
            section = getattr(prefs, name)
            if section:
                preserve_masked_keys(name, section)

        # Clone resolution runs second. A cloned provider may arrive carrying a
        # masked key that preserve_masked_keys cannot resolve (its name is new);
        # resolve_clone_from then overwrites it with the source's real key.
        for name in section_names:
            section = getattr(prefs, name)
            if section:
                resolve_clone_from(name, section)

        user.preferences = {
            "llm": prefs.llm,
            "tts": prefs.tts,
            "stt": prefs.stt,
            "statuses": prefs.statuses or {}
        }
        flag_modified(user, "preferences")

        # --- Enterprise RBAC Sync ---
        # ONLY admins can sync to Global Settings and persist to config.yaml
        if user.role == "admin":
            from app.config import settings as global_settings

            # Sync LLM
            if prefs.llm and prefs.llm.get("providers"):
                global_settings.LLM_PROVIDERS.update(prefs.llm.get("providers", {}))

            # Sync TTS
            if prefs.tts and prefs.tts.get("active_provider"):
                p_name = prefs.tts["active_provider"]
                p_data = prefs.tts.get("providers", {}).get(p_name, {})
                if p_data:
                    global_settings.TTS_PROVIDER = p_name
                    global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
                    global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
                    global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY

            # Sync STT
            if prefs.stt and prefs.stt.get("active_provider"):
                p_name = prefs.stt["active_provider"]
                p_data = prefs.stt.get("providers", {}).get(p_name, {})
                if p_data:
                    global_settings.STT_PROVIDER = p_name
                    global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
                    global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY

            # Write to config.yaml (best effort: a failure is logged, the DB save proceeds)
            try:
                global_settings.save_to_yaml()
            except Exception as ey:
                logger.error(f"Failed to sync settings to YAML: {ey}")

            logger.info(f"Saving updated global preferences via admin {user_id}")
        else:
            # NOTE(review): normal users are presumably meant to only pick an
            # active_provider, but the submitted "providers" maps are persisted as
            # sent; enforce here if strict admin control is required.
            # Each submitted section gets an explicit "active_provider" key (None if
            # not sent); sections submitted as None are left as None.
            for name in section_names:
                section = user.preferences[name]
                if section is not None:
                    section["active_provider"] = section.get("active_provider")
            logger.info(f"Saving personal preferences for user {user_id}")

        db.add(user)
        db.commit()
        db.refresh(user)

        return schemas.UserPreferences(
            llm=user.preferences.get("llm", {}),
            tts=user.preferences.get("tts", {}),
            stt=user.preferences.get("stt", {}),
            statuses=user.preferences.get("statuses", {})
        )

    @router.get("/me/config/models", response_model=list[schemas.ModelInfoResponse], summary="Get Models for Provider")
    async def get_provider_models(provider_name: str, section: str = "llm"):
        """
        List litellm's known models for ``provider_name`` that suit ``section``.

        Models are filtered by the ``mode`` from ``litellm.get_model_info``:

        * ``llm``: chat, text-completion, custom or unspecified;
        * ``tts``: ``audio_speech``;
        * ``stt``: ``audio_transcription``;
        * ``image``: ``image_generation``;
        * any other section: every mode is accepted.

        Models whose info cannot be loaded are skipped. If the whole lookup fails,
        an empty list is returned. The work runs in a worker thread so the event
        loop is not blocked.
        """
        import litellm
        from fastapi.concurrency import run_in_threadpool

        # Accepted modes per section; None means "no filtering". Tuples keep the
        # equality-based membership test used for `mode` values.
        accepted_modes = {
            "llm": ("chat", "text-completion", "custom", None),
            "tts": ("audio_speech",),
            "stt": ("audio_transcription",),
            "image": ("image_generation",),
        }.get(section)

        def describe(model_name):
            """Response entry for `model_name`, or None if it does not suit the section."""
            info = litellm.get_model_info(model_name)
            if "error" in info:
                return None
            if accepted_modes is not None and info.get("mode") not in accepted_modes:
                return None
            return {
                "model_name": model_name,
                "max_tokens": info.get("max_tokens"),
                "max_input_tokens": info.get("max_input_tokens")
            }

        def fetch_models():
            try:
                out = []
                # Named to avoid shadowing the module-level `models` import.
                for model_name in litellm.models_by_provider.get(provider_name, []):
                    try:
                        entry = describe(model_name)
                    except Exception as exc:
                        # Best effort: get_model_info may raise for unmapped models.
                        logger.debug(f"Skipping model {model_name}: {exc}")
                        continue
                    if entry is not None:
                        out.append(entry)
                return out
            except Exception as exc:
                logger.warning(f"Failed to list models for provider {provider_name}: {exc}")
                return []

        results = await run_in_threadpool(fetch_models)
        return results

    @router.get("/me/config/providers", response_model=list[str], summary="Get All Valid Providers per Section")
    async def get_all_providers(
        section: str = "llm", 
        configured_only: bool = Query(False, description="If true, only returns providers currently configured in preferences or system defaults"),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        List provider ids for a config section (``llm``, ``tts``, ``stt``, ``image``).

        With ``configured_only`` the result is the sorted union of the provider ids
        in the system settings and the caller's preferences for ``section``. If both
        are empty, the hard-coded defaults are used.

        Otherwise the result is ``"general"`` followed by every provider the backend
        knows for the section: registered TTS/STT providers plus ``"openai"``, or
        litellm's provider list for any other section. Duplicates are removed and
        the original order is kept.
        """
        import litellm
        from app.core.providers.factory import get_registered_tts_providers, get_registered_stt_providers

        if configured_only:
            # Union of system settings + user prefs (not the full get_user_config logic).
            user = services.user_service.get_user_by_id(db=db, user_id=user_id) if user_id else None
            system_prefs = services.user_service.get_system_settings(db)
            user_prefs = (user.preferences or {}) if user else {}

            def provider_ids(source):
                """Provider ids stored under `section`, tolerating missing/None levels."""
                return (((source or {}).get(section) or {}).get("providers") or {}).keys()

            configured = set(provider_ids(system_prefs)) | set(provider_ids(user_prefs))

            # If nothing configured, fall back to the hard-coded settings defaults.
            if not configured:
                from app.config import settings
                if section == "llm":
                    configured.update(["deepseek", "gemini"])
                elif section == "tts":
                    configured.add(settings.TTS_PROVIDER)
                elif section == "stt":
                    configured.add(settings.STT_PROVIDER)

            return sorted(configured)

        if section == "tts":
            candidates = ["general", *get_registered_tts_providers(), "openai"]
        elif section == "stt":
            candidates = ["general", *get_registered_stt_providers(), "openai"]
        else:
            # llm, image and unknown sections all list litellm's providers.
            candidates = ["general", *(p.value for p in litellm.LlmProviders)]
        # Drop duplicates (e.g. a registered "openai") while preserving order.
        return list(dict.fromkeys(candidates))

    @router.post("/me/config/verify_llm", response_model=schemas.VerifyProviderResponse)
    async def verify_llm(
        req: schemas.VerifyProviderRequest,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Admin-only: check that an LLM provider configuration can answer a trivial prompt.

        A masked ``req.api_key`` (containing ``***``) is resolved in this order:

        1. the key stored in the admin's own preferences for ``req.provider_name``;
        2. the system settings entry for ``req.provider_type`` (or the name).

        Stored extra fields for that provider (all but ``api_key``/``model``) are
        passed to ``get_llm_provider``. Provider failures are reported in the
        response body, not as HTTP errors.

        Raises:
            HTTPException(401): no ``X-User-ID`` header was supplied.
            HTTPException(403): the caller is not an admin.
        """
        import inspect
        from fastapi.concurrency import run_in_threadpool
        from app.core.providers.factory import get_llm_provider

        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not user or user.role != "admin":
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        actual_key = req.api_key
        try:
            # `user` is already loaded above; None levels are treated as empty.
            stored_llm = ((user.preferences or {}).get("llm") or {}).get("providers") or {}
            llm_prefs = stored_llm.get(req.provider_name) or {}

            # Handle masked keys by backfilling from stored prefs if needed
            if actual_key and "***" in actual_key:
                actual_key = llm_prefs.get("api_key")
                if not actual_key:
                    # Fallback to system defaults
                    system_prefs = services.user_service.get_system_settings(db) or {}
                    system_llm = ((system_prefs.get("llm") or {}).get("providers") or {})
                    actual_key = (system_llm.get(req.provider_type or req.provider_name) or {}).get("api_key")

            kwargs = {k: v for k, v in llm_prefs.items() if k not in ("api_key", "model")}
            if req.provider_type:
                kwargs["provider_type"] = req.provider_type

            llm = get_llm_provider(
                provider_name=req.provider_name,
                model_name=req.model or "",
                api_key_override=actual_key,
                **kwargs
            )
            # The provider is a plain callable, so the call presumably blocks on
            # network I/O; run it in a worker thread to keep the event loop free.
            # Await the result as well in case a provider returns an awaitable.
            result = await run_in_threadpool(llm, "Hello")
            if inspect.isawaitable(result):
                await result
            return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
        except Exception as e:
            logger.error(f"LLM Verification failed for {req.provider_name} ({req.provider_type}): {e}")
            return schemas.VerifyProviderResponse(success=False, message=str(e))

    @router.post("/me/config/verify_tts", response_model=schemas.VerifyProviderResponse)
    async def verify_tts(
        req: schemas.VerifyProviderRequest,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """
        Admin-only: check that a TTS provider configuration can synthesise speech.

        A missing or masked ``req.api_key`` is resolved in this order:

        1. the admin's stored preferences for ``req.provider_name``;
        2. the system settings for ``req.provider_type`` (or the name);
        3. ``settings.TTS_API_KEY`` / ``settings.GEMINI_API_KEY``.

        Provider failures are reported in the response body, not as HTTP errors.

        Raises:
            HTTPException(401): no ``X-User-ID`` header was supplied.
            HTTPException(403): the caller is not an admin.
        """
        from app.core.providers.factory import get_tts_provider
        from app.config import settings

        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not user or user.role != "admin":
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        actual_key = req.api_key
        try:
            # None levels (preferences, section or provider entry) are treated as empty.
            stored_tts = ((user.preferences or {}).get("tts") or {}).get("providers") or {}
            tts_prefs = stored_tts.get(req.provider_name) or {}

            # Key resolution: Masked keys should be replaced with real ones from DB or system config
            if not actual_key or "***" in str(actual_key):
                actual_key = tts_prefs.get("api_key")
                if not actual_key or "***" in str(actual_key):
                    # Try system settings
                    system_prefs = services.user_service.get_system_settings(db) or {}
                    system_tts = ((system_prefs.get("tts") or {}).get("providers") or {})
                    actual_key = (system_tts.get(req.provider_type or req.provider_name) or {}).get("api_key")
                    # Final fallback to settings.py constants
                    if not actual_key:
                        actual_key = settings.TTS_API_KEY or settings.GEMINI_API_KEY

            logger.info(f"verify_tts: instance={req.provider_name}, type={req.provider_type}, model={req.model}")

            kwargs = {k: v for k, v in tts_prefs.items() if k not in ("api_key", "model", "voice")}
            if req.provider_type:
                kwargs["provider_type"] = req.provider_type

            provider = get_tts_provider(
                provider_name=req.provider_name,
                api_key=actual_key,
                model_name=req.model or "",
                voice_name=req.voice or "",
                **kwargs
            )
            await provider.generate_speech("Test")
            return schemas.VerifyProviderResponse(success=True, message="Connection successful!")
        except Exception as e:
            logger.error(f"TTS verification failed for {req.provider_name}: {e}")
            return schemas.VerifyProviderResponse(success=False, message=str(e))

    @router.post("/me/config/verify_stt", response_model=schemas.VerifyProviderResponse)
    async def verify_stt(
        req: schemas.VerifyProviderRequest,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Check that an STT provider can be constructed from the given settings (Admin only).

        Only the factory call ``get_stt_provider(...)`` is exercised. No audio is
        transcribed, so success means "initialised", not "transcription works".

        API key resolution, each stage consulted only if the previous one produced
        nothing usable (empty, or a masked value containing ``***``):
          1. ``req.api_key``;
          2. the caller's stored key under ``preferences["stt"]["providers"][req.provider_name]``;
          3. the system-settings key for ``req.provider_type`` (or ``req.provider_name``);
          4. if step 3 yields nothing: ``settings.STT_API_KEY`` or ``settings.GEMINI_API_KEY``.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
        """
        from app.core.providers.factory import get_stt_provider
        from app.config import settings
        import logging
        log = logging.getLogger(__name__)

        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        caller = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (caller and caller.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        def _needs_fallback(candidate):
            # Empty keys and UI-masked placeholders cannot be sent to a provider.
            return not candidate or "***" in str(candidate)

        resolved_key = req.api_key
        try:
            instance_cfg = {}
            if caller.preferences:
                instance_cfg = caller.preferences.get("stt", {}).get("providers", {}).get(req.provider_name, {})

            # Flat guard chain: later stages only run while the key is still unusable.
            if _needs_fallback(resolved_key):
                resolved_key = instance_cfg.get("api_key")
            if _needs_fallback(resolved_key):
                system_cfg = services.user_service.get_system_settings(db)
                lookup_name = req.provider_type or req.provider_name
                resolved_key = system_cfg.get("stt", {}).get("providers", {}).get(lookup_name, {}).get("api_key")
                if not resolved_key:
                    resolved_key = settings.STT_API_KEY or settings.GEMINI_API_KEY

            extra_options = {
                opt_name: opt_value
                for opt_name, opt_value in instance_cfg.items()
                if opt_name not in ("api_key", "model")
            }
            if req.provider_type:
                extra_options["provider_type"] = req.provider_type

            get_stt_provider(
                provider_name=req.provider_name,
                api_key=resolved_key,
                model_name=req.model or "",
                **extra_options
            )
            return schemas.VerifyProviderResponse(success=True, message="Provider initialized. Full transcription test requires audio payload.")
        except Exception as e:
            log.error(f"STT verification failed for {req.provider_name}: {e}")
            return schemas.VerifyProviderResponse(success=False, message=str(e))

    @router.post("/logout", summary="Log Out the Current User")
    async def logout():
        """
        Simulated logout endpoint.

        No server-side state is cleared here; the handler only returns a
        confirmation payload. A real implementation would invalidate the
        caller's session token or cookie.
        """
        confirmation = "Logged out successfully"
        return {"message": confirmation}

    @router.get("/me/config/export", summary="Export Configurations to YAML")
    async def export_user_config_yaml(
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Export the effective LLM/TTS/STT configuration as a YAML attachment (Admin only).

        - ``llm_providers.providers`` holds the caller's stored LLM providers, or
          ``settings.LLM_PROVIDERS`` when the caller has none.
        - ``tts_provider`` / ``stt_provider`` each describe the caller's active
          provider when it is configured, otherwise the ``settings.TTS_*`` /
          ``settings.STT_*`` system defaults.

        Keys whose value is ``None`` are dropped from the output (recursively).

        NOTE(review): API keys are exported in clear text, exactly as stored, so the
        downloaded file must be treated as a secret.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
        """
        from fastapi.responses import PlainTextResponse
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not user or user.role != "admin":
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        from app.config import settings
        import yaml

        prefs_dict = user.preferences or {}
        # `or {}` also tolerates sections that were stored as explicit nulls.
        tts_prefs = prefs_dict.get("tts") or {}
        stt_prefs = prefs_dict.get("stt") or {}
        user_providers = (prefs_dict.get("llm") or {}).get("providers") or {}

        def get_provider_export(section_prefs, fallback_provider, fallback_model, fallback_api_key, fallback_voice=None):
            """Describe a TTS/STT section's active provider, or the given system fallback."""
            active_p = section_prefs.get("active_provider")
            providers = section_prefs.get("providers") or {}
            if active_p and active_p in providers:
                p_data = providers[active_p] or {}
                return {
                    "provider": active_p,
                    "model_name": p_data.get("model"),
                    "voice_name": p_data.get("voice"),
                    "api_key": p_data.get("api_key")
                }
            # No usable user configuration: fall back to system settings.
            return {
                "provider": fallback_provider,
                "model_name": fallback_model,
                "voice_name": fallback_voice,
                "api_key": fallback_api_key
            }

        # Layer 2 (Day 2) Export: Only LLM, TTS, STT
        yaml_data = {
            "llm_providers": {
                "providers": user_providers or settings.LLM_PROVIDERS
            },
            "tts_provider": get_provider_export(tts_prefs, settings.TTS_PROVIDER, settings.TTS_MODEL_NAME, settings.TTS_API_KEY, settings.TTS_VOICE_NAME),
            "stt_provider": get_provider_export(stt_prefs, settings.STT_PROVIDER, settings.STT_MODEL_NAME, settings.STT_API_KEY)
        }

        def remove_none(obj):
            """Drop None-valued keys from dicts, recursing through nested dicts and lists."""
            if isinstance(obj, dict):
                return {k: remove_none(v) for k, v in obj.items() if v is not None}
            if isinstance(obj, list):
                # List elements are kept (including None); only nested dicts are cleaned.
                return [remove_none(item) for item in obj]
            return obj

        yaml_str = yaml.dump(remove_none(yaml_data), sort_keys=False, default_flow_style=False)

        return PlainTextResponse(
            content=yaml_str,
            media_type="application/x-yaml",
            headers={"Content-Disposition": "attachment; filename=\"day2_config.yaml\""}
        )

    @router.post("/me/config/import", response_model=schemas.UserPreferences, summary="Import Configurations from YAML")
    async def import_user_config_yaml(
        file: UploadFile = File(...),
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Import the caller's LLM/TTS/STT configuration from an uploaded YAML file.

        Accepted layout:
          - ``llm_providers``: either a structured ``{"providers": {...}}`` mapping, or
            the flattened form ``<provider>_api_key`` / ``<provider>_model_name``.
          - ``tts_provider`` / ``stt_provider``: a single provider described by
            ``provider``, ``model_name``, ``api_key`` and (TTS only) ``voice_name``.

        The caller's ``preferences`` are replaced wholesale (``statuses`` is reset
        to ``{}``) and committed. Only when the caller is an admin are the imported
        values also pushed into the process-wide ``settings`` object and persisted
        with ``settings.save_to_yaml()``. A failure of that last step is logged and
        does not fail the request.

        Returns:
            ``schemas.UserPreferences`` built from the stored ``llm``/``tts``/``stt``.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 404: the user does not exist.
            HTTPException 400: the upload is not valid YAML, its top level is not a
                mapping, or a known section is present but is not a mapping.
        """
        import yaml
        from sqlalchemy.orm.attributes import flag_modified
        from app.config import settings as global_settings

        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        user = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        content = await file.read()
        try:
            # safe_load only: the uploaded document is untrusted input.
            data = yaml.safe_load(content)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid YAML file: {e}") from e
        # An empty upload parses to None, and a list/scalar has no `.get`. Both used
        # to surface as a 500 further down.
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="Invalid YAML file: the top level must be a mapping.")

        def _mapping(container, key):
            """Return ``container[key]`` as a dict: missing/null -> ``{}``, non-mapping -> HTTP 400."""
            value = container.get(key)
            if value is None:
                return {}
            if not isinstance(value, dict):
                raise HTTPException(status_code=400, detail=f"Invalid YAML file: '{key}' must be a mapping.")
            return value

        # Reverse mapping: YAML -> UserPreferences structure
        new_llm = {"providers": {}, "active_provider": None}
        new_tts = {"providers": {}, "active_provider": None}
        new_stt = {"providers": {}, "active_provider": None}

        # --- LLM ---
        llm_data = _mapping(data, "llm_providers")
        if "providers" in llm_data:  # Structured layout (what the export endpoint writes)
            new_llm["providers"] = _mapping(llm_data, "providers")
        else:  # Flattened layout
            for raw_key, value in llm_data.items():
                key = str(raw_key)  # YAML mapping keys are not guaranteed to be strings
                # removesuffix (not replace) strips only the trailing marker, so a
                # provider name that itself contains "_api_key" survives intact.
                if key.endswith("_api_key"):
                    new_llm["providers"].setdefault(key.removesuffix("_api_key"), {})["api_key"] = value
                elif key.endswith("_model_name"):
                    new_llm["providers"].setdefault(key.removesuffix("_model_name"), {})["model"] = value

        if new_llm["providers"]:
            # The first provider in iteration order becomes the active one.
            new_llm["active_provider"] = next(iter(new_llm["providers"]))

        # --- TTS ---
        tts_data = _mapping(data, "tts_provider")
        if tts_data:
            p = tts_data.get("provider") or "google_gemini"
            new_tts["active_provider"] = p
            new_tts["providers"][p] = {
                "api_key": tts_data.get("api_key"),
                "model": tts_data.get("model_name"),
                "voice": tts_data.get("voice_name")
            }

        # --- STT ---
        stt_data = _mapping(data, "stt_provider")
        if stt_data:
            p = stt_data.get("provider") or "google_gemini"
            new_stt["active_provider"] = p
            new_stt["providers"][p] = {
                "api_key": stt_data.get("api_key"),
                "model": stt_data.get("model_name")
            }

        is_admin = user.role == "admin"
        user.preferences = {
            "llm": new_llm,
            "tts": new_tts,
            "stt": new_stt,
            "statuses": {}
        }
        # Explicitly mark the JSON column dirty so the new preferences are flushed.
        flag_modified(user, "preferences")
        db.add(user)
        db.commit()
        db.refresh(user)

        # --- Day 2 Sync ---
        # Process-wide settings are changed only after the user's own preferences are
        # committed, and only for admins. The export/verify config routes are
        # admin-only, and letting any user overwrite the system provider keys would
        # be a privilege escalation.
        if is_admin:
            if new_llm["providers"]:
                global_settings.LLM_PROVIDERS.update(new_llm["providers"])
            if new_tts["active_provider"]:
                p = new_tts["active_provider"]
                p_data = new_tts["providers"].get(p, {})
                if p_data:
                    global_settings.TTS_PROVIDER = p
                    global_settings.TTS_MODEL_NAME = p_data.get("model") or global_settings.TTS_MODEL_NAME
                    global_settings.TTS_VOICE_NAME = p_data.get("voice") or global_settings.TTS_VOICE_NAME
                    global_settings.TTS_API_KEY = p_data.get("api_key") or global_settings.TTS_API_KEY
            if new_stt["active_provider"]:
                p = new_stt["active_provider"]
                p_data = new_stt["providers"].get(p, {})
                if p_data:
                    global_settings.STT_PROVIDER = p
                    global_settings.STT_MODEL_NAME = p_data.get("model") or global_settings.STT_MODEL_NAME
                    global_settings.STT_API_KEY = p_data.get("api_key") or global_settings.STT_API_KEY

            try:
                global_settings.save_to_yaml()
            except Exception as ey:
                # Best effort: the user's preferences are already committed.
                logger.error(f"Failed to sync settings to YAML on import: {ey}")

        stored = user.preferences
        return schemas.UserPreferences(llm=stored.get("llm", {}), tts=stored.get("tts", {}), stt=stored.get("stt", {}))

    # --- NEW ADMIN ROUTES ---

    @router.get("/admin/users", response_model=list[schemas.UserProfile], summary="List All Users (Admin Only)")
    async def admin_list_users(
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """List every registered user as a ``UserProfile`` (Admin only).

        ``group_name`` is filled from the user's ``group`` relationship whenever
        one is set.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        def _as_profile(account):
            # Build the profile, then copy the group's display name onto it.
            profile = schemas.UserProfile.model_validate(account)
            if account.group:
                profile.group_name = account.group.name
            return profile

        return [_as_profile(account) for account in services.user_service.get_all_users(db)]

    @router.put("/admin/users/{uid}/role", response_model=schemas.UserProfile, summary="Update User Role (Admin Only)")
    async def admin_update_role(
        uid: str,
        role_req: schemas.UserRoleUpdate,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Change the role of user ``uid`` and return the updated profile (Admin only).

        When ``update_user_role`` reports failure the request is rejected. Per the
        error message, one possible cause is an attempt to demote the last admin.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
            HTTPException 400: the user service refused the role change.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        if services.user_service.update_user_role(db, uid, role_req.role):
            return services.user_service.get_user_by_id(db, uid)
        raise HTTPException(status_code=400, detail="Failed to update role. Maybe this is the last admin?")

    @router.put("/admin/users/{uid}/group", response_model=schemas.UserProfile, summary="Update User Group (Admin Only)")
    async def admin_update_user_group(
        uid: str,
        group_req: schemas.UserGroupUpdate,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Move user ``uid`` into group ``group_req.group_id`` and return the profile (Admin only).

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
            HTTPException 404: the user service reported that the assignment failed
                (the user or the group does not exist).
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        assigned = services.user_service.assign_user_to_group(db, uid, group_req.group_id)
        if assigned:
            return services.user_service.get_user_by_id(db, uid)
        raise HTTPException(status_code=404, detail="User or group not found")

    @router.get("/admin/groups", response_model=list[schemas.GroupInfo], summary="List All Groups (Admin Only)")
    async def admin_list_groups(
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """List every group as a ``GroupInfo`` (Admin only).

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        # Convert ORM rows to Pydantic models here, inside the session scope, so
        # response serialisation never triggers SQLAlchemy lazy loads in async context.
        all_groups = services.user_service.get_all_groups(db)
        return list(map(schemas.GroupInfo.model_validate, all_groups))

    @router.post("/admin/groups", response_model=schemas.GroupInfo, summary="Create Group (Admin Only)")
    async def admin_create_group(
        group_req: schemas.GroupCreate,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Create a group from ``group_req`` and return it (Admin only).

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is unknown or not an admin.
            HTTPException 409: ``create_group`` returned None, i.e. the name is taken.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        created = services.user_service.create_group(db, group_req.name, group_req.description, group_req.policy)
        if created is not None:
            return schemas.GroupInfo.model_validate(created)
        raise HTTPException(status_code=409, detail=f"A group named '{group_req.name}' already exists. Please choose a unique name.")

    @router.put("/admin/groups/{gid}", response_model=schemas.GroupInfo, summary="Update Group (Admin Only)")
    async def admin_update_group(
        gid: str,
        group_req: schemas.GroupUpdate,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Update a group's name, description and/or policy (Admin only).

        The built-in ``ungrouped`` group may have its policy edited but may not be
        renamed. Supplying its own name (case-insensitive, surrounding whitespace
        ignored) is still allowed.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is not an admin, or tried to rename ``ungrouped``.
            HTTPException 404: ``update_group`` returned None (no such group).
            HTTPException 409: ``update_group`` returned False (name already taken).
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        if gid == "ungrouped" and group_req.name:
            if group_req.name.strip().lower() != "ungrouped":
                raise HTTPException(status_code=403, detail="The default 'Ungrouped' group cannot be renamed.")

        updated = services.user_service.update_group(db, gid, group_req.name, group_req.description, group_req.policy)
        # The service signals "not found" with None and "name clash" with False.
        if updated is None:
            raise HTTPException(status_code=404, detail="Group not found")
        if updated is False:
            raise HTTPException(status_code=409, detail=f"A group named '{group_req.name}' already exists. Please choose a unique name.")
        return schemas.GroupInfo.model_validate(updated)

    @router.delete("/admin/groups/{gid}", summary="Delete Group (Admin Only)")
    async def admin_delete_group(
        gid: str,
        db: Session = Depends(get_db),
        user_id: str = Depends(get_current_user_id)
    ):
        """Delete group ``gid`` (Admin only); the built-in ``ungrouped`` group is protected.

        Per the original contract, members of the deleted group end up back in
        ``ungrouped``. That reassignment happens inside ``delete_group``, not here.

        Raises:
            HTTPException 401: no ``X-User-ID`` header was supplied.
            HTTPException 403: the caller is not an admin, or ``gid`` is ``ungrouped``.
            HTTPException 400: ``delete_group`` reported failure.
        """
        if not user_id:
            raise HTTPException(status_code=401, detail="Unauthorized")
        requester = services.user_service.get_user_by_id(db=db, user_id=user_id)
        if not (requester and requester.role == "admin"):
            raise HTTPException(status_code=403, detail="Forbidden: Admin only")

        if gid == "ungrouped":
            raise HTTPException(status_code=403, detail="The default 'Ungrouped' group cannot be deleted.")

        if not services.user_service.delete_group(db, gid):
            raise HTTPException(status_code=400, detail="Failed to delete group.")
        return {"message": "Group deleted successfully"}

    return router