# cortex-hub / ai-hub / app / core / services / auth.py
import httpx
import jwt
import urllib.parse
from fastapi import HTTPException
import logging
from app.config import settings
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)

class AuthService:
    """OIDC authorization-code login flow.

    Builds the provider login URL, exchanges the returned authorization code
    for tokens, reads the user's identity from the ID token and stores the
    user through ``services.user_service.save_user``.
    """

    def __init__(self, services):
        """Keep a reference to the service container.

        Args:
            services: Object exposing ``user_service.save_user(...)``.
        """
        self.services = services

    def get_oidc_urls(self) -> Dict[str, str]:
        """Return the provider endpoint URLs keyed by ``auth``/``token``/``userinfo``.

        The URLs are built from ``settings.OIDC_SERVER_URL`` with any trailing
        slash stripped so the joined paths never contain ``//``.
        """
        server_url = settings.OIDC_SERVER_URL.rstrip("/")
        return {
            "auth": f"{server_url}/auth",
            "token": f"{server_url}/token",
            "userinfo": f"{server_url}/userinfo",
        }

    def generate_login_url(self, frontend_callback_uri: Optional[str]) -> str:
        """Build the URL that sends the browser to the provider's login page.

        Args:
            frontend_callback_uri: Value placed in the ``state`` query
                parameter; ``None`` is sent as an empty string.

        Returns:
            The provider ``auth`` endpoint with URL-encoded query parameters.
        """
        # NOTE(review): ``state`` carries the frontend callback URI instead of
        # an unguessable per-request value; confirm the callback route
        # validates it before using it as a redirect target.
        params = {
            "response_type": "code",
            "scope": "openid profile email",
            "client_id": settings.OIDC_CLIENT_ID,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "state": frontend_callback_uri or "",
        }
        return f"{self.get_oidc_urls()['auth']}?{urllib.parse.urlencode(params)}"

    async def handle_callback(self, code: str, db) -> Dict[str, Any]:
        """Complete the login after the provider redirects back with ``code``.

        Args:
            code: Authorization code received from the provider.
            db: Database handle forwarded unchanged to ``save_user``.

        Returns:
            ``{"user_id": <value returned by save_user>}``.

        Raises:
            HTTPException: 500 when the token exchange fails or the provider
                answers with something that is not a JSON object; 400 when
                the ID token is missing, cannot be decoded, or lacks ``sub``
                or ``email``.
        """
        token_payload = await self._exchange_code(code)

        id_token = token_payload.get("id_token")
        if not id_token:
            raise HTTPException(status_code=400, detail="Failed to get ID token from OIDC provider.")

        try:
            # NOTE(security): the signature is NOT verified here, so the claims
            # are only as trustworthy as the TLS connection to the token
            # endpoint. Verifying against the provider's signing keys should
            # be added once a key source is available to this service.
            claims = jwt.decode(id_token, options={"verify_signature": False})
        except jwt.InvalidTokenError as e:
            # InvalidTokenError is the base class of DecodeError, so every
            # failure the original handler covered is still covered.
            raise HTTPException(
                status_code=400,
                detail="Failed to decode ID token from OIDC provider.",
            ) from e

        identity = self._extract_identity(claims)
        user_id = self.services.user_service.save_user(db=db, **identity)
        return {"user_id": user_id}

    async def _exchange_code(self, code: str) -> Dict[str, Any]:
        """POST the authorization code to the token endpoint and return its JSON object."""
        token_url = self.get_oidc_urls()["token"]
        token_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "client_id": settings.OIDC_CLIENT_ID,
            "client_secret": settings.OIDC_CLIENT_SECRET,
        }

        try:
            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_url, data=token_data, timeout=30.0)
                token_response.raise_for_status()
                response_json = token_response.json()
        except httpx.HTTPStatusError as e:
            logger.error(
                "OIDC Token exchange failed with status %s: %s",
                e.response.status_code,
                e.response.text,
            )
            raise HTTPException(
                status_code=500,
                detail=f"OIDC Token exchange failed: {e.response.text}",
            ) from e
        except httpx.RequestError as e:
            logger.error("OIDC Token exchange request error: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"Failed to communicate with OIDC provider: {e}",
            ) from e
        except ValueError as e:
            # A 2xx reply whose body is not valid JSON (json.JSONDecodeError is
            # a ValueError) used to escape as an unhandled exception.
            logger.error("OIDC Token exchange returned a non-JSON body: %s", e)
            raise HTTPException(
                status_code=500,
                detail="OIDC provider returned an invalid token response.",
            ) from e

        if not isinstance(response_json, dict):
            # Valid JSON that is not an object has no ``id_token`` to look up.
            logger.error(
                "OIDC Token exchange returned unexpected JSON type: %s",
                type(response_json).__name__,
            )
            raise HTTPException(
                status_code=500,
                detail="OIDC provider returned an invalid token response.",
            )
        return response_json

    @staticmethod
    def _extract_identity(claims: Dict[str, Any]) -> Dict[str, Any]:
        """Map ID-token claims onto the keyword arguments of ``save_user``.

        ``username`` falls back from ``name`` to ``preferred_username`` to the
        email address. Raises ``HTTPException`` (400) when ``sub`` or ``email``
        is missing or empty.
        """
        oidc_id = claims.get("sub")
        email = claims.get("email")
        if not (oidc_id and email):
            raise HTTPException(
                status_code=400,
                detail="Essential user data missing from ID token (sub and email required).",
            )

        username = claims.get("name") or claims.get("preferred_username") or email
        return {"oidc_id": oidc_id, "email": email, "username": username}