import httpx
import jwt
import urllib.parse
import time
from fastapi import HTTPException
import logging
from sqlalchemy.orm import Session
from app.config import settings
from app.db import models
from typing import Optional, Dict, Any
# Module-level logger, named after this module, used for auth diagnostics.
logger: logging.Logger = logging.getLogger(name=__name__)
class AuthService:
    """OpenID Connect login flow plus local session-token issuing.

    Discovers the provider's endpoints, builds the authorization URL,
    exchanges an authorization code for tokens, verifies the returned ID
    token, and syncs the user into the local database through
    ``services.user_service``.
    """

    # Discovery document cache. Read and written on the class (not ``self``)
    # so every AuthService instance shares one cached copy.
    _discovery_cache: Optional[Dict[str, Any]] = None
    # time.time() value after which the cached discovery document is refetched.
    _discovery_expiry: float = 0
    # How long (seconds) a successfully fetched discovery document is reused.
    _DISCOVERY_TTL = 3600
    # One PyJWKClient per JWKS URL, shared by all instances. Reusing the client
    # lets its internal key cache avoid a JWKS download on every verification.
    _jwks_clients: Dict[str, Any] = {}

    def __init__(self, services):
        # Service container; this class uses ``services.user_service``.
        self.services = services

    @staticmethod
    def _default_discovery() -> Dict[str, Any]:
        """Return the endpoint locations assumed when discovery fails or omits an entry."""
        server_url = settings.OIDC_SERVER_URL.rstrip("/")
        return {
            "authorization_endpoint": f"{server_url}/auth",
            "token_endpoint": f"{server_url}/token",
            "jwks_uri": f"{server_url}/keys",
            "userinfo_endpoint": f"{server_url}/userinfo",
        }

    def _endpoint(self, discovery: Dict[str, Any], name: str) -> str:
        """Return ``discovery[name]``, falling back to the default location when absent."""
        return discovery.get(name) or self._default_discovery()[name]

    @classmethod
    def _get_jwks_client(cls, jwks_url: str) -> "jwt.PyJWKClient":
        """Return the shared PyJWKClient for ``jwks_url``, creating it on first use."""
        client = cls._jwks_clients.get(jwks_url)
        if client is None:
            client = jwt.PyJWKClient(jwks_url)
            cls._jwks_clients[jwks_url] = client
        return client

    async def get_discovery(self) -> Dict[str, Any]:
        """Return the provider's OpenID Connect discovery document.

        A successfully fetched document is cached on the class for
        ``_DISCOVERY_TTL`` seconds. If fetching or parsing fails, the error is
        logged and a fallback document with fixed endpoint paths under
        ``settings.OIDC_SERVER_URL`` is returned. The fallback is not cached,
        so the next call retries discovery.
        """
        cls = type(self)
        if cls._discovery_cache and time.time() < cls._discovery_expiry:
            return cls._discovery_cache
        discovery_url = f"{settings.OIDC_SERVER_URL.rstrip('/')}/.well-known/openid-configuration"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(discovery_url, timeout=10.0)
                response.raise_for_status()
                document = response.json()
            if not isinstance(document, dict):
                raise ValueError("discovery document is not a JSON object")
        except Exception as e:
            logger.error(f"Failed to fetch OIDC discovery from {discovery_url}: {e}")
            # Fallback to defaults if discovery fails, for resiliency
            return self._default_discovery()
        cls._discovery_cache = document
        cls._discovery_expiry = time.time() + cls._DISCOVERY_TTL
        return document

    async def generate_login_url(self, frontend_callback_uri: Optional[str]) -> str:
        """Build the provider authorization URL for the authorization-code flow.

        ``frontend_callback_uri`` travels in the OAuth ``state`` parameter
        (an empty string when not given).
        """
        discovery = await self.get_discovery()
        auth_endpoint = self._endpoint(discovery, "authorization_endpoint")
        # NOTE(review): ``state`` carries the frontend redirect target rather
        # than an unguessable per-login value, so it gives no CSRF protection to
        # the callback. Wherever it is later used as a redirect, it should be
        # checked against an allow-list.
        params = {
            "response_type": "code",
            "scope": "openid profile email",
            "client_id": settings.OIDC_CLIENT_ID,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "state": frontend_callback_uri or "",
        }
        # Preserve any query string already present on the endpoint.
        separator = "&" if "?" in auth_endpoint else "?"
        return f"{auth_endpoint}{separator}{urllib.parse.urlencode(params)}"

    def create_session_token(self, user_id: str) -> str:
        """
        Generates a short-lived session token (JWT) for local/password-based users.
        Used to maintain security parity with OIDC ID tokens.

        The token is HS256-signed with ``settings.SECRET_KEY`` and carries the
        claims ``sub`` (= ``user_id``), ``iat``, ``exp`` (24 hours after
        ``iat``) and ``iss`` = ``"cortex-hub-internal"``.
        """
        from datetime import datetime, timedelta, timezone

        # One timestamp for both claims. It is timezone-aware because
        # datetime.utcnow() is deprecated; PyJWT encodes either form to the
        # same epoch seconds.
        issued_at = datetime.now(timezone.utc)
        payload = {
            "sub": user_id,
            "iat": issued_at,
            "exp": issued_at + timedelta(hours=24),  # 24 hour local session
            "iss": "cortex-hub-internal",
        }
        return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")

    async def verify_id_token(self, id_token: str, db: Session) -> models.User:
        """
        Verifies an OIDC ID token (JWT), syncs the user record, and returns the User object.

        The signature is checked as RS256 against the provider's JWKS. The
        audience is checked only when ``settings.OIDC_CLIENT_ID`` is set, and
        the issuer only when the discovery document provides ``issuer``.

        Raises:
            HTTPException: 401 if verification fails; 400 if the ``email`` or
                ``sub`` claim is missing; 500 if the user cannot be loaded
                after syncing.
        """
        import asyncio

        discovery = await self.get_discovery()
        jwks_url = self._endpoint(discovery, "jwks_uri")
        try:
            jwks_client = self._get_jwks_client(jwks_url)
            # PyJWKClient fetches keys with blocking I/O; keep it off the event loop.
            loop = asyncio.get_running_loop()
            signing_key = await loop.run_in_executor(
                None, jwks_client.get_signing_key_from_jwt, id_token
            )
            decoded = jwt.decode(
                id_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=settings.OIDC_CLIENT_ID,
                # None when discovery fell back to defaults; PyJWT then skips
                # the issuer check.
                issuer=discovery.get("issuer"),
                options={"verify_aud": bool(settings.OIDC_CLIENT_ID)},
            )
        except Exception as e:
            logger.warning(f"OIDC Token verification failed: {e}")
            raise HTTPException(status_code=401, detail=f"Invalid OIDC token: {str(e)}") from e

        email = decoded.get("email")
        sub = decoded.get("sub")
        if not email or not sub:
            raise HTTPException(status_code=400, detail="OIDC token missing required claims (email, sub).")
        # NOTE(review): the ``email_verified`` claim is not checked. If
        # sync_oidc_user matches existing users by email, confirm the provider
        # only issues verified addresses.
        user_id = self.services.user_service.sync_oidc_user(
            db=db,
            email=email,
            external_id=sub,
            full_name=decoded.get("name"),
            avatar_url=decoded.get("picture"),
        )
        user = self.services.user_service.get_user_by_id(db, user_id)
        if not user:
            raise HTTPException(status_code=500, detail="Failed to retrieve user after sync.")
        return user

    async def handle_callback(self, code: str, db: Session) -> Dict[str, Any]:
        """Exchange an authorization code for tokens and sign the user in.

        Returns:
            ``{"user_id": ..., "linked": False, "id_token": ...}``.

        Raises:
            HTTPException: 500 when the token exchange fails or the provider's
                reply has no ``id_token``. Any HTTPException raised by
                ``verify_id_token`` is propagated unchanged.
        """
        discovery = await self.get_discovery()
        token_endpoint = self._endpoint(discovery, "token_endpoint")
        token_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "client_id": settings.OIDC_CLIENT_ID,
            "client_secret": settings.OIDC_CLIENT_SECRET,
        }
        try:
            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_endpoint, data=token_data, timeout=30.0)
            # The body is already read, so the client can be closed before
            # the (possibly slow) verification below.
            token_response.raise_for_status()
            id_token = token_response.json().get("id_token")
            if not id_token:
                raise HTTPException(status_code=500, detail="OIDC provider did not return an id_token.")
            user = await self.verify_id_token(id_token, db)
            return {"user_id": user.id, "linked": False, "id_token": id_token}
        except httpx.HTTPStatusError as e:
            logger.error(f"OIDC Token exchange failed with status {e.response.status_code}: {e.response.text}")
            raise HTTPException(status_code=500, detail=f"OIDC Token exchange failed: {e.response.text}") from e
        except httpx.RequestError as e:
            logger.error(f"OIDC Token exchange request error: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to communicate with OIDC provider: {e}") from e
        except HTTPException:
            raise
        except Exception as e:
            # logger.exception keeps the traceback for this unexpected path.
            logger.exception(f"Unexpected error during callback: {e}")
            raise HTTPException(status_code=500, detail="Internal server error during authentication.") from e