Newer
Older
cortex-hub / ai-hub / app / core / services / auth.py
import asyncio
import logging
import time
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import httpx
import jwt
from fastapi import HTTPException
from sqlalchemy.orm import Session

from app.config import settings
from app.db import models

# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

class AuthService:
    """Authentication helper for OIDC login and locally issued session tokens.

    Responsibilities:
      * fetch and cache the OIDC provider's discovery document,
      * build the provider login URL,
      * exchange an authorization code for tokens and verify the ID token,
      * sync the authenticated identity into the local user table,
      * mint HS256 session JWTs for local (password-based) users.

    ``services`` is the application's service container; this class relies on
    ``services.user_service`` exposing ``sync_oidc_user`` and ``get_user_by_id``.
    """

    # Discovery document cache. These are deliberately *class* attributes and
    # are written through ``type(self)`` so all AuthService instances share one
    # cache; assigning via ``self`` would create per-instance copies and lose
    # the cache whenever a new instance is constructed.
    _discovery_cache: Optional[Dict[str, Any]] = None
    _discovery_expiry: float = 0
    # Seconds a successfully fetched discovery document is reused.
    _DISCOVERY_TTL_SECONDS = 3600
    # One PyJWKClient per JWKS URL, shared across instances, so whatever key
    # caching PyJWKClient does is reused instead of starting cold every login.
    _jwks_clients: Dict[str, Any] = {}

    def __init__(self, services):
        # Service container; only ``services.user_service`` is used here.
        self.services = services

    async def get_discovery(self) -> Dict[str, Any]:
        """Return the OIDC provider's discovery document.

        A successfully fetched document is cached (shared by all instances) for
        ``_DISCOVERY_TTL_SECONDS``. If fetching or parsing fails, a fallback
        dict built from ``settings.OIDC_SERVER_URL`` with hard-coded endpoint
        paths is returned instead; the fallback is not cached, so the next call
        retries discovery.
        """
        cls = type(self)
        if cls._discovery_cache and time.time() < cls._discovery_expiry:
            return cls._discovery_cache

        server_url = settings.OIDC_SERVER_URL.rstrip("/")
        discovery_url = f"{server_url}/.well-known/openid-configuration"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(discovery_url, timeout=10.0)
                response.raise_for_status()
                discovery = response.json()
        except Exception as e:
            logger.error(f"Failed to fetch OIDC discovery from {discovery_url}: {e}")
            # Fallback to defaults if discovery fails, for resiliency.
            # NOTE(review): these paths are provider-specific guesses — confirm
            # they match the deployed OIDC server.
            return {
                "authorization_endpoint": f"{server_url}/auth",
                "token_endpoint": f"{server_url}/token",
                "jwks_uri": f"{server_url}/keys",
                "userinfo_endpoint": f"{server_url}/userinfo"
            }

        cls._discovery_cache = discovery
        cls._discovery_expiry = time.time() + cls._DISCOVERY_TTL_SECONDS
        return discovery

    async def generate_login_url(self, frontend_callback_uri: Optional[str]) -> str:
        """Build the provider authorization URL for the authorization-code flow.

        Args:
            frontend_callback_uri: Passed through to the provider as the OAuth
                ``state`` parameter (empty string when ``None``).

        Returns:
            The authorization endpoint URL with the request parameters appended.

        Raises:
            HTTPException(500): discovery has no ``authorization_endpoint``.
        """
        discovery = await self.get_discovery()
        auth_endpoint = discovery.get("authorization_endpoint")
        if not auth_endpoint:
            # Previously this silently produced a URL starting with "None?".
            raise HTTPException(
                status_code=500,
                detail="OIDC discovery document has no authorization_endpoint.",
            )

        params = {
            "response_type": "code",
            "scope": "openid profile email",
            "client_id": settings.OIDC_CLIENT_ID,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            # NOTE(review): ``state`` is meant to be an unguessable CSRF token;
            # using a client-supplied redirect URI here offers no CSRF
            # protection and may enable open redirects — consider signing it.
            "state": frontend_callback_uri or ""
        }
        # Respect an endpoint that already carries a query string.
        separator = "&" if "?" in auth_endpoint else "?"
        return f"{auth_endpoint}{separator}{urllib.parse.urlencode(params)}"

    def create_session_token(self, user_id: str) -> str:
        """Generate a session token (HS256 JWT) for local/password-based users.

        Used to maintain security parity with OIDC ID tokens. The token is
        signed with ``settings.SECRET_KEY``, issued by ``cortex-hub-internal``
        and expires 24 hours after issuance.

        Args:
            user_id: Stored as the ``sub`` claim.

        Returns:
            The encoded JWT string.
        """
        # Single timestamp so ``iat`` and ``exp`` are exactly 24h apart; an
        # aware UTC datetime replaces the deprecated ``datetime.utcnow()``.
        now = datetime.now(timezone.utc)
        payload = {
            "sub": user_id,
            "iat": now,
            "exp": now + timedelta(hours=24),  # 24 hour local session
            "iss": "cortex-hub-internal"
        }
        return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")

    @classmethod
    def _get_jwks_client(cls, jwks_url: str) -> Any:
        """Return the shared PyJWKClient for ``jwks_url``, creating it on first use."""
        client = cls._jwks_clients.get(jwks_url)
        if client is None:
            client = jwt.PyJWKClient(jwks_url)
            cls._jwks_clients[jwks_url] = client
        return client

    async def verify_id_token(self, id_token: str, db: Session) -> models.User:
        """Verify an OIDC ID token (JWT), sync the user record and return the User.

        The signature is checked against the provider's JWKS (RS256 only); the
        audience is checked against ``settings.OIDC_CLIENT_ID`` when one is set.

        Raises:
            HTTPException(401): the token fails fetching/verification.
            HTTPException(400): the token lacks the ``email`` or ``sub`` claim.
            HTTPException(500): the user cannot be loaded after syncing.
        """
        discovery = await self.get_discovery()

        # 1. Locate the provider's public keys (JWKS) used to verify signatures.
        jwks_url = discovery.get("jwks_uri")
        try:
            # 2. Resolve the signing key. PyJWKClient fetches keys with
            #    blocking I/O, so run it in the default executor rather than
            #    stalling the event loop.
            jwks_client = self._get_jwks_client(jwks_url)
            loop = asyncio.get_running_loop()
            signing_key = await loop.run_in_executor(
                None, jwks_client.get_signing_key_from_jwt, id_token
            )

            decoded = jwt.decode(
                id_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=settings.OIDC_CLIENT_ID,
                options={"verify_aud": bool(settings.OIDC_CLIENT_ID)}
            )
        except Exception as e:
            logger.warning(f"OIDC Token verification failed: {e}")
            raise HTTPException(status_code=401, detail=f"Invalid OIDC token: {str(e)}") from e

        # 3. Sync user with local database
        email = decoded.get("email")
        sub = decoded.get("sub")
        if not email or not sub:
            raise HTTPException(status_code=400, detail="OIDC token missing required claims (email, sub).")

        user_id = self.services.user_service.sync_oidc_user(
            db=db,
            email=email,
            external_id=sub,
            full_name=decoded.get("name"),
            avatar_url=decoded.get("picture")
        )

        user = self.services.user_service.get_user_by_id(db, user_id)
        if not user:
            raise HTTPException(status_code=500, detail="Failed to retrieve user after sync.")
        return user

    async def handle_callback(self, code: str, db: Session) -> Dict[str, Any]:
        """Exchange an authorization ``code`` for tokens and sign the user in.

        Returns:
            ``{"user_id": ..., "linked": False, "id_token": ...}``.

        Raises:
            HTTPException(500): token exchange fails, the provider omits the
                ``id_token``, or an unexpected error occurs.
            HTTPException: anything raised by ``verify_id_token`` is re-raised.
        """
        discovery = await self.get_discovery()
        token_endpoint = discovery.get("token_endpoint")
        token_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "client_id": settings.OIDC_CLIENT_ID,
            "client_secret": settings.OIDC_CLIENT_SECRET,
        }

        try:
            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_endpoint, data=token_data, timeout=30.0)
                token_response.raise_for_status()
                response_json = token_response.json()

            id_token = response_json.get("id_token")
            if not id_token:
                # Without this check a missing token surfaced as a misleading
                # 401 "Invalid OIDC token" from verify_id_token.
                logger.error("OIDC token response did not include an id_token.")
                raise HTTPException(status_code=500, detail="OIDC provider did not return an id_token.")

            user = await self.verify_id_token(id_token, db)
            return {"user_id": user.id, "linked": False, "id_token": id_token}

        except httpx.HTTPStatusError as e:
            logger.error(f"OIDC Token exchange failed with status {e.response.status_code}: {e.response.text}")
            raise HTTPException(status_code=500, detail=f"OIDC Token exchange failed: {e.response.text}") from e
        except httpx.RequestError as e:
            logger.error(f"OIDC Token exchange request error: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}") from e
        except HTTPException:
            raise
        except Exception as e:
            # logger.exception keeps the traceback for unexpected failures.
            logger.exception(f"Unexpected error during callback: {e}")
            raise HTTPException(status_code=500, detail="Internal server error during authentication.") from e