Newer
Older
cortex-hub / ai-hub / app / core / services / user.py
import functools
import logging
import uuid
from datetime import datetime
from typing import Optional, Union

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

# Assuming the models are in a file named `models.py` in the `app.db` directory
from app.db import models

class UserService:
    """Database operations on users and user groups.

    Every method receives the SQLAlchemy ``Session`` to use as its ``db``
    argument; the service itself keeps no per-instance state.
    """

    # Errors are routed through logging (with traceback) instead of print().
    _logger = logging.getLogger(__name__)

    # Primary key of the group that new users join and that users of a deleted
    # group are moved to. It is created on demand and can never be deleted.
    _DEFAULT_GROUP_ID = "ungrouped"

    def __init__(self):
        pass

    def save_user(self, db: Session, oidc_id: str, email: str, username: str) -> str:
        """Create or update the user identified by ``oidc_id``; return its ID.

        If the user exists, its ``email``, ``username`` and ``last_login_at``
        are refreshed. It is also promoted to ``"admin"`` when ``email`` is
        listed in ``settings.SUPER_ADMINS``. An existing admin is never
        demoted here.

        Otherwise a new user is created in the default group. Its role is
        ``"admin"`` if ``email`` is in ``settings.SUPER_ADMINS`` and ``"user"``
        otherwise.

        Raises:
            SQLAlchemyError: re-raised after the session has been rolled back.
        """
        # Kept as a function-level import like the original code; presumably
        # this avoids an import cycle with app.config (TODO confirm).
        from app.config import settings

        is_super_admin = email in settings.SUPER_ADMINS
        # One timestamp so created_at and last_login_at agree exactly.
        now = datetime.utcnow()
        try:
            existing_user = (
                db.query(models.User).filter(models.User.oidc_id == oidc_id).first()
            )

            if existing_user:
                existing_user.email = email
                existing_user.username = username
                existing_user.last_login_at = now
                if is_super_admin and existing_user.role != "admin":
                    existing_user.role = "admin"
                db.commit()
                return existing_user.id

            default_group = self.get_or_create_default_group(db)
            new_user = models.User(
                id=str(uuid.uuid4()),
                oidc_id=oidc_id,
                email=email,
                username=username,
                role="admin" if is_super_admin else "user",
                group_id=default_group.id,
                created_at=now,
                last_login_at=now,
            )
            db.add(new_user)
            db.commit()
            db.refresh(new_user)
            return new_user.id
        except SQLAlchemyError:
            db.rollback()
            raise

    def get_user_by_id(self, db: Session, user_id: str) -> Optional[models.User]:
        """Return the user whose primary key is ``user_id``, or ``None``.

        ``None`` is also returned when the query fails. The error is logged
        and the session is rolled back so the caller can keep using it.
        """
        try:
            return db.query(models.User).filter(models.User.id == user_id).first()
        except SQLAlchemyError:
            # A failed statement can leave the transaction aborted; roll back
            # so later queries on the same session do not fail as well.
            db.rollback()
            self._logger.exception("Database error while fetching user by ID %s", user_id)
            return None

    def get_all_users(self, db: Session) -> list[models.User]:
        """Return every registered user (an empty list if the query fails)."""
        try:
            return db.query(models.User).all()
        except SQLAlchemyError:
            db.rollback()
            self._logger.exception("Database error while fetching all users")
            return []

    def update_user_role(self, db: Session, user_id: str, new_role: str) -> bool:
        """Set the role of user ``user_id`` to ``new_role``.

        Returns ``False`` without changing anything in three cases: the user
        does not exist, the change would demote the only remaining admin, or
        a database error occurs (the session is then rolled back). Returns
        ``True`` otherwise.
        """
        try:
            user = db.query(models.User).filter(models.User.id == user_id).first()
            if not user:
                return False

            # Demoting an admin is refused if they are the last one left.
            if user.role == "admin" and new_role != "admin":
                admin_count = db.query(models.User).filter(models.User.role == "admin").count()
                if admin_count <= 1:
                    return False

            user.role = new_role
            db.commit()
            return True
        except SQLAlchemyError:
            db.rollback()
            self._logger.exception("Database error while updating user role")
            return False

    def get_system_settings(self, db: Session) -> dict:
        """Return the ``preferences`` of the first configured super admin.

        Only ``settings.SUPER_ADMINS[0]`` is consulted. Returns ``{}`` in any
        of these cases: no super admin is configured, that user does not
        exist, the user has no preferences, or the query fails (logged, and
        the session rolled back).
        """
        from app.config import settings

        try:
            super_admin_email = settings.SUPER_ADMINS[0] if settings.SUPER_ADMINS else None
            if super_admin_email:
                admin_user = (
                    db.query(models.User).filter(models.User.email == super_admin_email).first()
                )
                if admin_user and admin_user.preferences:
                    return admin_user.preferences
            return {}
        except SQLAlchemyError:
            db.rollback()
            self._logger.exception("Database error while fetching system settings")
            return {}

    # --- Group Management Methods ---

    def get_or_create_default_group(self, db: Session) -> models.Group:
        """Return the default ``"ungrouped"`` group, creating and committing it if missing."""
        default_group = (
            db.query(models.Group).filter(models.Group.id == self._DEFAULT_GROUP_ID).first()
        )
        if not default_group:
            default_group = models.Group(
                id=self._DEFAULT_GROUP_ID,
                name="Ungrouped",
                description="Default group for new users",
                policy={},
            )
            db.add(default_group)
            db.commit()
            db.refresh(default_group)
        return default_group

    def get_all_groups(self, db: Session) -> list[models.Group]:
        """Returns all existing groups, without lazy-loading user relationships."""
        from sqlalchemy.orm import noload

        return db.query(models.Group).options(noload(models.Group.users)).all()

    def get_group_by_id(self, db: Session, group_id: str) -> Optional[models.Group]:
        """Fetches a group by ID, without lazy-loading user relationships."""
        from sqlalchemy.orm import noload

        return (
            db.query(models.Group)
            .options(noload(models.Group.users))
            .filter(models.Group.id == group_id)
            .first()
        )

    def _find_group_by_name(
        self, db: Session, name: str, exclude_id: Optional[str] = None
    ) -> Optional[models.Group]:
        """Return a group whose name equals ``name`` case-insensitively, skipping ``exclude_id``."""
        from sqlalchemy import func

        # Compare lower(name) for equality instead of using ilike(): with
        # ilike, "%" and "_" in a group name act as wildcards, so "a_c" would
        # clash with "abc" and "%" would clash with every group.
        query = db.query(models.Group).filter(
            func.lower(models.Group.name) == func.lower(name)
        )
        if exclude_id is not None:
            query = query.filter(models.Group.id != exclude_id)
        return query.first()

    def create_group(
        self,
        db: Session,
        name: str,
        description: Optional[str] = None,
        policy: Optional[dict] = None,
    ) -> Optional[models.Group]:
        """Create and commit a new group named ``name`` (surrounding whitespace stripped).

        Returns the new group. Returns ``None`` if another group already has
        the same name, compared case-insensitively. ``policy`` defaults to
        ``{}``.
        """
        clean_name = name.strip()
        if self._find_group_by_name(db, clean_name):
            return None  # Signals a name conflict
        group = models.Group(
            id=str(uuid.uuid4()),
            name=clean_name,
            description=description,
            policy=policy or {},
        )
        db.add(group)
        db.commit()
        db.refresh(group)
        return group

    def update_group(
        self,
        db: Session,
        group_id: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        policy: Optional[dict] = None,
    ) -> Union[models.Group, bool, None]:
        """Update the metadata and/or policy of group ``group_id``.

        Arguments left as ``None`` are not changed. A ``name`` that is empty
        or only whitespace is also ignored. Otherwise the name is stripped
        before it is stored.

        Returns:
            The updated group. ``None`` if the group does not exist. ``False``
            if another group already uses ``name``, compared
            case-insensitively.
        """
        group = self.get_group_by_id(db, group_id)
        if not group:
            return None
        # Strip before testing so a whitespace-only name cannot blank the
        # group's existing name.
        new_name = name.strip() if name else ""
        if new_name:
            if self._find_group_by_name(db, new_name, exclude_id=group_id):
                return False  # Signals a name conflict (distinct from None = not found)
            group.name = new_name
        if description is not None:
            group.description = description
        if policy is not None:
            group.policy = policy
        db.commit()
        db.refresh(group)
        return group

    def delete_group(self, db: Session, group_id: str) -> bool:
        """Delete group ``group_id`` after moving its users to the default group.

        Returns ``False`` if ``group_id`` is the default group (which cannot
        be deleted) or does not exist. Returns ``True`` on success.
        """
        if group_id == self._DEFAULT_GROUP_ID:
            return False
        group = self.get_group_by_id(db, group_id)
        if not group:
            return False

        default_group = self.get_or_create_default_group(db)
        # Re-home the members with one bulk UPDATE before removing the group.
        db.query(models.User).filter(models.User.group_id == group_id).update(
            {"group_id": default_group.id}
        )

        db.delete(group)
        db.commit()
        return True

    def assign_user_to_group(self, db: Session, user_id: str, group_id: str) -> bool:
        """Move user ``user_id`` into group ``group_id``.

        Returns ``False`` if either the user or the group does not exist.
        """
        user = self.get_user_by_id(db, user_id)
        group = self.get_group_by_id(db, group_id)
        if not user or not group:
            return False
        user.group_id = group_id
        db.commit()
        return True

# --- Framework-dependent helper functions ---
# These functions are placeholders and would need to be integrated with your
# specific web framework (e.g., FastAPI, Flask, Django).

def login_required(f):
    """Wrap an async endpoint with a placeholder login check.

    The wrapper looks for a ``user_id`` keyword argument. NOTE: the check is
    not enforced yet. When ``user_id`` is missing or empty, the wrapped
    function is still awaited and its result returned. Replace the marked
    branch with your framework's rejection before relying on this for access
    control. Examples: ``raise HTTPException(status_code=401,
    detail="Unauthorized")`` in FastAPI, or a redirect to the login page.

    ``functools.wraps`` copies ``__name__``, ``__doc__`` and sets
    ``__wrapped__`` on the wrapper. As a result, ``inspect.signature`` reports
    the wrapped function's parameters instead of ``(*args, **kwargs)``.
    """

    @functools.wraps(f)
    async def wrapper(*args, **kwargs):
        # Assumes user_id is injected as a keyword argument by a framework
        # dependency (e.g. FastAPI Depends) — TODO confirm at call sites.
        user_id = kwargs.get("user_id")
        if not user_id:
            # TODO: reject unauthenticated calls here (401 or redirect).
            # Until then, the request is deliberately passed through.
            pass
        return await f(*args, **kwargs)

    return wrapper


def get_current_user_id() -> Optional[str]:
    """Look up the ID of the user bound to the current request.

    Placeholder: no session or authentication backend is wired in yet, so
    the result is always ``None``. With FastAPI, an implementation might
    look like::

        from fastapi import Depends, Request
        from app.auth import get_current_user

        return get_current_user(request).id
    """
    resolved_id: Optional[str] = None  # nothing to resolve until auth is integrated
    return resolved_id