import httpx
import jwt
import urllib.parse
import time
from fastapi import HTTPException
import logging
from app.config import settings
from typing import Optional, Dict, Any
# Module-scoped logger, named after this module so its records can be filtered by source.
logger = logging.getLogger(name=__name__)
class AuthService:
    """OIDC authorization-code login: build the provider login URL, then
    exchange the callback code, verify the ID token and persist the user."""

    # Declared at class level as defaults. NOTE(review): get_discovery assigns
    # through ``self``, which creates per-instance attributes, so the cached
    # discovery document is only reused for the lifetime of one AuthService.
    _discovery_cache = None
    _discovery_expiry = 0

    def __init__(self, services):
        """Keep the service container; ``services.user_service`` persists users."""
        self.services = services

    async def get_discovery(self) -> Dict[str, Any]:
        """Return the provider's OpenID Connect discovery document.

        A successfully fetched document is cached for one hour. If the fetch
        fails (network error, error status, or a body that is not a JSON
        object), the error is logged and a dict of default endpoint URLs built
        from ``settings.OIDC_SERVER_URL`` is returned instead. That fallback is
        not cached, so the next call retries the fetch.
        """
        if self._discovery_cache and time.time() < self._discovery_expiry:
            return self._discovery_cache
        discovery_url = f"{settings.OIDC_SERVER_URL.rstrip('/')}/.well-known/openid-configuration"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(discovery_url, timeout=10.0)
                response.raise_for_status()
                document = response.json()
            if not isinstance(document, dict):
                # A list/string body would otherwise be cached and break every .get() caller.
                raise ValueError(f"expected a JSON object, got {type(document).__name__}")
            self._discovery_cache = document
            self._discovery_expiry = time.time() + 3600  # Cache for 1 hour
            return self._discovery_cache
        except Exception as e:
            logger.error(f"Failed to fetch OIDC discovery from {discovery_url}: {e}")
            # Fallback to defaults if discovery fails, for resiliency
            server_url = settings.OIDC_SERVER_URL.rstrip("/")
            return {
                "authorization_endpoint": f"{server_url}/auth",
                "token_endpoint": f"{server_url}/token",
                "jwks_uri": f"{server_url}/keys",
                "userinfo_endpoint": f"{server_url}/userinfo",
            }

    async def generate_login_url(self, frontend_callback_uri: Optional[str]) -> str:
        """Build the provider authorization URL that starts the login flow.

        Args:
            frontend_callback_uri: Sent to the provider verbatim as the OAuth
                ``state`` parameter (empty string when None).

        Returns:
            The authorization endpoint with the authorization-code request
            parameters appended as a query string.
        """
        discovery = await self.get_discovery()
        params = {
            "response_type": "code",
            "scope": "openid profile email",
            "client_id": settings.OIDC_CLIENT_ID,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            # NOTE(review): ``state`` carries the frontend redirect target rather
            # than an unguessable per-session value, so it provides no CSRF
            # protection, and handle_callback never receives or checks it.
            # Consider a random state (plus nonce) bound to the session.
            "state": frontend_callback_uri or "",
        }
        auth_endpoint = discovery.get("authorization_endpoint")
        # The advertised endpoint may already carry a query string; append with
        # '&' in that case so the existing parameters are preserved.
        separator = "&" if auth_endpoint and "?" in auth_endpoint else "?"
        return f"{auth_endpoint}{separator}{urllib.parse.urlencode(params)}"

    async def handle_callback(self, code: str, db) -> Dict[str, Any]:
        """Complete login: exchange ``code``, verify the ID token, save the user.

        Args:
            code: Authorization code the provider returned to the redirect URI.
            db: Database handle passed through to ``user_service.save_user``.

        Returns:
            ``{"user_id": ..., "linked": ...}`` taken from ``save_user``.

        Raises:
            HTTPException: 500 when the provider is unreachable, rejects the
                code exchange, returns no usable ID token, or its signing keys
                cannot be fetched; 401 when the ID token fails verification;
                400 when the verified token lacks ``sub`` or ``email``.
        """
        discovery = await self.get_discovery()
        id_token = await self._exchange_code(discovery.get("token_endpoint"), code)
        jwks = await self._fetch_jwks(discovery.get("jwks_uri"))
        # The token's ``iss`` must equal the discovery ``issuer`` exactly
        # (trailing slash included); fall back to the configured server URL
        # when discovery did not supply one (e.g. the fallback defaults).
        issuer = discovery.get("issuer") or settings.OIDC_SERVER_URL.rstrip("/")
        claims = self._verify_id_token(id_token, jwks, issuer)

        # We use the 'sub' and 'email' as primary identity
        oidc_id = claims.get("sub")
        email = claims.get("email")
        username = claims.get("name") or claims.get("preferred_username") or email
        if not all([oidc_id, email]):
            raise HTTPException(status_code=400, detail="Essential user data missing from ID token (sub and email required).")
        # NOTE(review): ``email_verified`` is not checked. If save_user links
        # existing accounts by email, confirm the provider only issues verified
        # addresses, or reject tokens where ``email_verified`` is False.
        user_id, linked = self.services.user_service.save_user(
            db=db,
            oidc_id=oidc_id,
            email=email,
            username=username,
        )
        return {"user_id": user_id, "linked": linked}

    async def _exchange_code(self, token_endpoint: Optional[str], code: str) -> str:
        """Trade the authorization code at the token endpoint for the raw ID token string."""
        token_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "client_id": settings.OIDC_CLIENT_ID,
            "client_secret": settings.OIDC_CLIENT_SECRET,
        }
        try:
            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_endpoint, data=token_data, timeout=30.0)
                token_response.raise_for_status()
                response_json = token_response.json()
        except httpx.HTTPStatusError as e:
            logger.error(f"OIDC Token exchange failed with status {e.response.status_code}: {e.response.text}")
            raise HTTPException(status_code=500, detail=f"OIDC Token exchange failed: {e.response.text}") from e
        except httpx.RequestError as e:
            logger.error(f"OIDC Token exchange request error: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}") from e
        except ValueError as e:
            # A 2xx response with a non-JSON body would otherwise escape unhandled.
            logger.error(f"OIDC Token endpoint returned a non-JSON response: {e}")
            raise HTTPException(status_code=500, detail="OIDC Token exchange failed: invalid response from provider.") from e
        id_token = response_json.get("id_token") if isinstance(response_json, dict) else None
        if not isinstance(id_token, str) or not id_token:
            # Without this check a missing token surfaced later as a confusing 401 "Invalid token type".
            logger.error("OIDC Token response did not contain an id_token")
            raise HTTPException(status_code=500, detail="OIDC Token exchange failed: no id_token in provider response.")
        return id_token

    async def _fetch_jwks(self, jwks_url: Optional[str]) -> Dict[str, Any]:
        """Download the provider's JWK Set (public keys used to verify the ID token signature)."""
        try:
            async with httpx.AsyncClient() as client:
                jwks_response = await client.get(jwks_url, timeout=10.0)
                jwks_response.raise_for_status()
                jwks = jwks_response.json()
            if not isinstance(jwks, dict):
                # PyJWKSet.from_dict expects a JSON object with a "keys" member.
                raise ValueError(f"expected a JSON object, got {type(jwks).__name__}")
        except Exception as e:
            logger.error(f"Failed to fetch JWKS from {jwks_url}: {e}")
            raise HTTPException(status_code=500, detail="Failed to verify identity: Identity provider keys unreachable.") from e
        return jwks

    @staticmethod
    def _verify_id_token(id_token: str, jwks: Dict[str, Any], issuer: str) -> Dict[str, Any]:
        """Verify the ID token's RS256 signature, audience and issuer; return its claims."""
        try:
            jwk_set = jwt.PyJWKSet.from_dict(jwks)
            header = jwt.get_unverified_header(id_token)
            kid = header.get("kid")
            if kid is not None:
                try:
                    signing_key = jwk_set[kid]
                except KeyError as e:
                    # e.g. keys rotated at the provider: reject as an invalid token
                    # instead of letting KeyError escape as an unhandled 500.
                    raise jwt.InvalidTokenError(f"No signing key found for kid {kid!r}") from e
            elif len(jwk_set.keys) == 1:
                # Without a kid, a single-key set is the only unambiguous choice.
                signing_key = jwk_set.keys[0]
            else:
                raise jwt.InvalidTokenError("ID token header has no 'kid' and the JWKS contains multiple keys")
            # Enforce signature verification, audience, and issuer checks.
            return jwt.decode(
                id_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=settings.OIDC_CLIENT_ID,
                issuer=issuer,
            )
        except jwt.PyJWTError as e:
            logger.error(f"JWT Verification failed: {e}")
            raise HTTPException(status_code=401, detail=f"Invalid authentication token: {str(e)}") from e