import httpx
import jwt
import urllib.parse
import time
from fastapi import HTTPException
import logging
from app.config import settings
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)

class AuthService:
    """OpenID Connect (OIDC) authorization-code login against ``settings.OIDC_SERVER_URL``.

    Responsibilities:
    * fetch (and cache) the provider's discovery document,
    * build the provider login URL,
    * exchange an authorization code for an ID token, verify that token against
      the provider's JWKS, and persist the user via ``services.user_service``.
    """

    # Discovery cache shared by *all* instances: it is read and written on the
    # class, so it keeps working even if a new AuthService is built per request.
    _discovery_cache: Optional[Dict[str, Any]] = None
    _discovery_expiry: float = 0
    _DISCOVERY_TTL_SECONDS = 3600  # cache discovery for 1 hour

    # Endpoint paths (relative to OIDC_SERVER_URL) used when discovery fails or
    # the discovery document omits an endpoint.
    _DEFAULT_ENDPOINT_PATHS = {
        "authorization_endpoint": "/auth",
        "token_endpoint": "/token",
        "jwks_uri": "/keys",
        "userinfo_endpoint": "/userinfo",
    }

    def __init__(self, services):
        """Store the service container; ``services.user_service`` persists users."""
        self.services = services

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _server_url() -> str:
        """Return the configured OIDC server URL without a trailing slash."""
        return settings.OIDC_SERVER_URL.rstrip("/")

    @staticmethod
    def _with_query(endpoint: str, params: Dict[str, str]) -> str:
        """Append url-encoded *params* to *endpoint*, preserving any existing query."""
        separator = "&" if "?" in endpoint else "?"
        return f"{endpoint}{separator}{urllib.parse.urlencode(params)}"

    def _endpoint(self, discovery: Dict[str, Any], name: str) -> str:
        """Return endpoint *name* from *discovery*, or its default URL on the server."""
        return discovery.get(name) or f"{self._server_url()}{self._DEFAULT_ENDPOINT_PATHS[name]}"

    # --------------------------------------------------------------- discovery

    async def get_discovery(self) -> Dict[str, Any]:
        """Return the provider's OpenID discovery document.

        A successfully fetched document is cached on the class for
        ``_DISCOVERY_TTL_SECONDS``. If it cannot be fetched or is not a JSON
        object, a dict of default endpoint URLs is returned instead. That
        fallback is not cached, so the next call retries discovery.
        """
        cls = type(self)
        if cls._discovery_cache and time.time() < cls._discovery_expiry:
            return cls._discovery_cache

        server_url = self._server_url()
        discovery_url = f"{server_url}/.well-known/openid-configuration"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(discovery_url, timeout=10.0)
                response.raise_for_status()
                discovery = response.json()
            if not isinstance(discovery, dict):
                raise ValueError("discovery document is not a JSON object")
        except Exception as e:
            # Best-effort: fall back to conventional endpoint paths for resiliency.
            logger.error("Failed to fetch OIDC discovery from %s: %s", discovery_url, e)
            return {
                name: f"{server_url}{path}"
                for name, path in cls._DEFAULT_ENDPOINT_PATHS.items()
            }

        # Store on the class (not the instance) so every AuthService shares it.
        cls._discovery_cache = discovery
        cls._discovery_expiry = time.time() + cls._DISCOVERY_TTL_SECONDS
        return discovery

    # ------------------------------------------------------------------- login

    async def generate_login_url(self, frontend_callback_uri: Optional[str]) -> str:
        """Build the provider authorization URL that starts the login flow.

        Args:
            frontend_callback_uri: Sent through the provider in the ``state``
                parameter (an empty string when ``None``).

        Returns:
            The authorization endpoint URL with the OIDC query parameters appended.
        """
        discovery = await self.get_discovery()
        params = {
            "response_type": "code",
            "scope": "openid profile email",
            "client_id": settings.OIDC_CLIENT_ID,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            # NOTE(review): ``state`` carries the frontend callback URI instead of
            # an unguessable per-login value. It therefore gives no CSRF
            # protection, and whoever consumes it must validate the URI (e.g.
            # against an allow-list) before redirecting to it. Consider a random
            # state bound to the session, with the callback URI stored server-side.
            "state": frontend_callback_uri or "",
        }
        auth_endpoint = self._endpoint(discovery, "authorization_endpoint")
        return self._with_query(auth_endpoint, params)

    # ---------------------------------------------------------------- callback

    async def handle_callback(self, code: str, db) -> Dict[str, Any]:
        """Complete the login: exchange *code*, verify the ID token, save the user.

        Args:
            code: Authorization code returned by the provider to the redirect URI.
            db: Database handle, passed through to ``user_service.save_user``.

        Returns:
            ``{"user_id": ..., "linked": ...}`` with the values ``save_user`` returns.

        Raises:
            HTTPException:
                * 500 if the provider cannot be reached, rejects the code
                  exchange, or its keys cannot be fetched.
                * 401 if no ID token is returned or it fails verification.
                * 400 if the ID token lacks ``sub`` or ``email``.
        """
        discovery = await self.get_discovery()
        id_token = await self._exchange_code_for_id_token(code, discovery)
        claims = await self._verify_id_token(id_token, discovery)

        # ``sub`` and ``email`` are the primary identity.
        oidc_id = claims.get("sub")
        email = claims.get("email")
        username = claims.get("name") or claims.get("preferred_username") or email

        if not oidc_id or not email:
            raise HTTPException(status_code=400, detail="Essential user data missing from ID token (sub and email required).")

        # NOTE(review): ``email_verified`` is not checked. If save_user links
        # accounts by email, an unverified provider email could be used to take
        # over an existing account; confirm against save_user and the provider.
        user_id, linked = self.services.user_service.save_user(
            db=db,
            oidc_id=oidc_id,
            email=email,
            username=username,
        )
        return {"user_id": user_id, "linked": linked}

    async def _exchange_code_for_id_token(self, code: str, discovery: Dict[str, Any]) -> str:
        """POST *code* to the token endpoint and return the ``id_token`` it issues."""
        token_endpoint = self._endpoint(discovery, "token_endpoint")
        token_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.OIDC_REDIRECT_URI,
            "client_id": settings.OIDC_CLIENT_ID,
            "client_secret": settings.OIDC_CLIENT_SECRET,
        }

        try:
            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_endpoint, data=token_data, timeout=30.0)
                token_response.raise_for_status()
                response_json = token_response.json()
        except httpx.HTTPStatusError as e:
            # The provider's error body is logged, but not echoed to the client.
            logger.error(
                "OIDC Token exchange failed with status %s: %s",
                e.response.status_code,
                e.response.text,
            )
            raise HTTPException(status_code=500, detail="OIDC Token exchange failed.") from e
        except httpx.RequestError as e:
            logger.error("OIDC Token exchange request error: %s", e)
            raise HTTPException(status_code=500, detail="Failed to communicate with OIDC provider.") from e
        except ValueError as e:
            # The token response body was not valid JSON.
            logger.error("OIDC Token exchange returned a non-JSON response: %s", e)
            raise HTTPException(status_code=500, detail="OIDC Token exchange failed.") from e

        id_token = response_json.get("id_token") if isinstance(response_json, dict) else None
        if not id_token:
            logger.error("OIDC token response did not include an id_token")
            raise HTTPException(status_code=401, detail="Invalid authentication token: no id_token in token response.")
        return id_token

    async def _verify_id_token(self, id_token: str, discovery: Dict[str, Any]) -> Dict[str, Any]:
        """Verify *id_token* (RS256 signature, audience, issuer) and return its claims."""
        # 1. Fetch the provider's public keys (JWKS).
        jwks_url = self._endpoint(discovery, "jwks_uri")
        try:
            async with httpx.AsyncClient() as client:
                jwks_response = await client.get(jwks_url, timeout=10.0)
                jwks_response.raise_for_status()
                jwks = jwks_response.json()
        except Exception as e:
            logger.error("Failed to fetch JWKS from %s: %s", jwks_url, e)
            raise HTTPException(status_code=500, detail="Failed to verify identity: Identity provider keys unreachable.") from e

        # Tokens carry the issuer advertised by discovery, which may differ from
        # OIDC_SERVER_URL (e.g. a trailing slash). Fall back to the server URL
        # when discovery failed or omitted it.
        issuer = discovery.get("issuer") or self._server_url()

        # 2. Select the signing key and verify signature, audience and issuer.
        try:
            # PyJWKClient could automate this; the manual lookup keeps to the
            # lower-level PyJWT API.
            jwk_set = jwt.PyJWKSet.from_dict(jwks)
            header = jwt.get_unverified_header(id_token)
            # PyJWKSet raises KeyError for an unknown/missing kid; that is an
            # invalid token, not a server error.
            signing_key = jwk_set[header.get("kid")]
            return jwt.decode(
                id_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=settings.OIDC_CLIENT_ID,
                issuer=issuer,
            )
        except (jwt.PyJWTError, KeyError) as e:
            logger.error("JWT Verification failed: %s", e)
            raise HTTPException(status_code=401, detail=f"Invalid authentication token: {e}") from e
