Newer
Older
AnthosCertManager / pkg / controller / certificates / utils.go
package certificates

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rsa"
	"crypto/x509/pkix"
	"encoding/asn1"
	"fmt"
	"reflect"
	"time"

	acmapi "gitbucket.jerxie.com/yangyangxie/AnthosCertManager/pkg/apis/anthoscertmanager/v1"
	"gitbucket.jerxie.com/yangyangxie/AnthosCertManager/pkg/util"
	"gitbucket.jerxie.com/yangyangxie/AnthosCertManager/pkg/util/pki"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// RenewalTimeFunc computes the moment at which a certificate should be
// renewed from its validity window (notBefore, notAfter) and an optional
// renewBefore override. RenewalTime has this signature.
type RenewalTimeFunc func(notBefore, notAfter time.Time, renewBefore *metav1.Duration) *metav1.Time

// RenewalTime returns the time at which a certificate valid from notBefore
// until notAfter should be renewed.
//
// By default the renewal point lies one third of the certificate's lifetime
// before notAfter (i.e. two thirds of the way through its validity). When
// renewBeforeOverride is non-nil and strictly shorter than the lifetime, the
// renewal point is renewBeforeOverride before notAfter instead; an override
// that is equal to or longer than the lifetime is ignored. The returned time
// is truncated to whole seconds.
func RenewalTime(notBefore, notAfter time.Time, renewBeforeOverride *metav1.Duration) *metav1.Time {
	lifetime := notAfter.Sub(notBefore)

	// Honour a user-supplied renewBefore only when it fits inside the
	// certificate's lifetime. A short override lets users renew long-lived
	// certificates more frequently than the default would.
	var window time.Duration
	if renewBeforeOverride != nil && renewBeforeOverride.Duration < lifetime {
		window = renewBeforeOverride.Duration
	} else {
		window = lifetime / 3
	}

	// The renewal time is stored on the Certificate's status at one-second
	// granularity, and the controller re-queues the Certificate at that stored
	// time and then calls this function again to decide whether renewal is due
	// now. If this returned sub-second precision, the recomputed value could be
	// later than the stored one, so the re-queued Certificate would be judged
	// "not yet due" and never renewed automatically. See
	// https://github.com/cert-manager/cert-manager/pull/4399.
	renewAt := metav1.NewTime(notAfter.Add(-window).Truncate(time.Second))
	return &renewAt
}

// PrivateKeyMatchesSpec checks pk against the private key settings in spec
// and returns the Certificate field paths that do not match. RSA, Ed25519 and
// ECDSA keys are supported; an empty algorithm is treated as RSA.
// An error is returned only when spec names an unrecognised algorithm, in
// which case no violations are returned.
func PrivateKeyMatchesSpec(pk crypto.PrivateKey, spec acmapi.CertificateSpec) ([]string, error) {
	// Operate on a deep copy so nothing done to spec here or in the
	// per-algorithm helpers can reach data shared with the caller.
	spec = *spec.DeepCopy()
	if spec.PrivateKey == nil {
		spec.PrivateKey = &acmapi.CertificatePrivateKey{}
	}

	algorithm := spec.PrivateKey.Algorithm
	switch algorithm {
	case acmapi.ECDSAKeyAlgorithm:
		return ecdsaPrivateKeyMatchesSpec(pk, spec)
	case acmapi.Ed25519KeyAlgorithm:
		return ed25519PrivateKeyMatchesSpec(pk, spec)
	case acmapi.RSAKeyAlgorithm, "":
		return rsaPrivateKeyMatchesSpec(pk, spec)
	}
	return nil, fmt.Errorf("unrecognised key algorithm type %q", algorithm)
}

// rsaPrivateKeyMatchesSpec reports which spec fields pk violates when an RSA
// key is expected. It returns "spec.privateKey.algorithm" if pk is not an
// *rsa.PrivateKey. Otherwise it returns "spec.privateKey.size" if the modulus
// bit length differs from spec.PrivateKey.Size (pki.MinRSAKeySize when Size is
// unset or non-positive), and nil if it matches. The error result is always
// nil. spec.PrivateKey is dereferenced without a nil check, so it must be set.
func rsaPrivateKeyMatchesSpec(pk crypto.PrivateKey, spec acmapi.CertificateSpec) ([]string, error) {
	key, isRSA := pk.(*rsa.PrivateKey)
	if !isRSA {
		return []string{"spec.privateKey.algorithm"}, nil
	}

	// TODO: we should not use implicit defaulting here, and instead rely on
	//  defaulting performed within the Kubernetes apiserver here.
	//  This requires careful handling in order to not interrupt users upgrading
	//  from older versions.
	wantBits := spec.PrivateKey.Size
	if wantBits <= 0 {
		// Fall back to the minimum (default) RSA key size.
		wantBits = pki.MinRSAKeySize
	}

	if key.N.BitLen() == wantBits {
		return nil, nil
	}
	return []string{"spec.privateKey.size"}, nil
}

// ecdsaPrivateKeyMatchesSpec reports which spec fields pk violates when an
// ECDSA key is expected. It returns "spec.privateKey.algorithm" if pk is not
// an *ecdsa.PrivateKey. Otherwise it returns "spec.privateKey.size" if the
// curve's bit size differs from spec.PrivateKey.Size (pki.ECCurve256 when Size
// is unset or non-positive), and nil if it matches. The error result is always
// nil. spec.PrivateKey is dereferenced without a nil check, so it must be set.
func ecdsaPrivateKeyMatchesSpec(pk crypto.PrivateKey, spec acmapi.CertificateSpec) ([]string, error) {
	key, isECDSA := pk.(*ecdsa.PrivateKey)
	if !isECDSA {
		return []string{"spec.privateKey.algorithm"}, nil
	}

	// TODO: we should not use implicit defaulting here, and instead rely on
	//  defaulting performed within the Kubernetes apiserver here.
	//  This requires careful handling in order to not interrupt users upgrading
	//  from older versions.
	wantCurveBits := spec.PrivateKey.Size
	if wantCurveBits <= 0 {
		// Fall back to the default 256-bit curve size.
		wantCurveBits = pki.ECCurve256
	}

	if key.Curve.Params().BitSize == wantCurveBits {
		return nil, nil
	}
	return []string{"spec.privateKey.size"}, nil
}

// ed25519PrivateKeyMatchesSpec returns "spec.privateKey.algorithm" when pk is
// not an ed25519.PrivateKey value and nil otherwise. No size check is
// performed for Ed25519, so spec is not consulted. The error result is always
// nil.
func ed25519PrivateKeyMatchesSpec(pk crypto.PrivateKey, spec acmapi.CertificateSpec) ([]string, error) {
	if _, isEd25519 := pk.(ed25519.PrivateKey); isEd25519 {
		return nil, nil
	}
	return []string{"spec.privateKey.algorithm"}, nil
}

// RequestMatchesSpec compares a CertificateRequest with a CertificateSpec and
// returns the field paths on the Certificate whose values do not match their
// counterparts on the CertificateRequest or in the CSR it carries.
//
// The subject is compared in one of two ways:
//   - without spec.LiteralSubject, field by field (spec.commonName and the
//     spec.subject.* fields);
//   - with spec.LiteralSubject, as a whole: the CSR's raw subject and the
//     literal subject string are both parsed into RDN sequences and compared.
//
// SANs (dnsNames, ipAddresses, uris, emailAddresses), isCA, usages, duration
// and issuerRef are compared in BOTH cases, because LiteralSubject only
// replaces the subject; previously they were skipped when LiteralSubject was
// set, so changes to them went undetected.
//
// An error is returned if req.Spec.Request cannot be decoded as an x509
// certificate request, or (with LiteralSubject) if either subject cannot be
// parsed into an RDN sequence.
func RequestMatchesSpec(req *acmapi.CertificateRequest, spec acmapi.CertificateSpec) ([]string, error) {
	x509req, err := pki.DecodeX509CertificateRequestBytes(req.Spec.Request)
	if err != nil {
		return nil, err
	}

	// It is safe to mutate top-level fields in `spec`: it is passed by value,
	// so reassigning spec.Subject will not affect the caller.
	if spec.Subject == nil {
		spec.Subject = &acmapi.X509Subject{}
	}

	usesLiteralSubject := spec.LiteralSubject != ""
	var violations []string

	// The common name is part of the subject; with a literal subject it is
	// covered by the RDN sequence comparison below instead.
	if !usesLiteralSubject && x509req.Subject.CommonName != spec.CommonName {
		violations = append(violations, "spec.commonName")
	}

	// Subject Alternative Names are independent of how the subject is given.
	if !util.EqualUnsorted(x509req.DNSNames, spec.DNSNames) {
		violations = append(violations, "spec.dnsNames")
	}
	if !util.EqualUnsorted(pki.IPAddressesToString(x509req.IPAddresses), spec.IPAddresses) {
		violations = append(violations, "spec.ipAddresses")
	}
	if !util.EqualUnsorted(pki.URLsToString(x509req.URIs), spec.URIs) {
		violations = append(violations, "spec.uris")
	}
	if !util.EqualUnsorted(x509req.EmailAddresses, spec.EmailAddresses) {
		violations = append(violations, "spec.emailAddresses")
	}

	if !usesLiteralSubject {
		if x509req.Subject.SerialNumber != spec.Subject.SerialNumber {
			violations = append(violations, "spec.subject.serialNumber")
		}
		if !util.EqualUnsorted(x509req.Subject.Organization, spec.Subject.Organizations) {
			violations = append(violations, "spec.subject.organizations")
		}
		if !util.EqualUnsorted(x509req.Subject.Country, spec.Subject.Countries) {
			violations = append(violations, "spec.subject.countries")
		}
		if !util.EqualUnsorted(x509req.Subject.Locality, spec.Subject.Localities) {
			violations = append(violations, "spec.subject.localities")
		}
		if !util.EqualUnsorted(x509req.Subject.OrganizationalUnit, spec.Subject.OrganizationalUnits) {
			violations = append(violations, "spec.subject.organizationalUnits")
		}
		// NOTE(review): the Go field is PostalCodes; this reported path is kept
		// unchanged for compatibility — confirm whether it should be
		// "spec.subject.postalCodes".
		if !util.EqualUnsorted(x509req.Subject.PostalCode, spec.Subject.PostalCodes) {
			violations = append(violations, "spec.subject.postCodes")
		}
		if !util.EqualUnsorted(x509req.Subject.Province, spec.Subject.Provinces) {
			violations = append(violations, "spec.subject.provinces")
		}
		if !util.EqualUnsorted(x509req.Subject.StreetAddress, spec.Subject.StreetAddresses) {
			violations = append(violations, "spec.subject.streetAddresses")
		}
	} else {
		// Parse the CSR's subject the same way LiteralSubject is parsed and
		// check whether the two RDN sequences match.
		var rdnSequenceFromCertificateRequest pkix.RDNSequence
		if _, err := asn1.Unmarshal(x509req.RawSubject, &rdnSequenceFromCertificateRequest); err != nil {
			return nil, err
		}

		rdnSequenceFromCertificate, err := pki.ParseSubjectStringToRdnSequence(spec.LiteralSubject)
		if err != nil {
			return nil, err
		}

		if !reflect.DeepEqual(rdnSequenceFromCertificate, rdnSequenceFromCertificateRequest) {
			violations = append(violations, "spec.literalSubject")
		}
	}

	// Fields carried on the CertificateRequest itself rather than in the CSR.
	if req.Spec.IsCA != spec.IsCA {
		violations = append(violations, "spec.isCA")
	}
	if !util.EqualKeyUsagesUnsorted(req.Spec.Usages, spec.Usages) {
		violations = append(violations, "spec.usages")
	}
	// Duration is only compared when both sides specify one.
	if spec.Duration != nil && req.Spec.Duration != nil &&
		spec.Duration.Duration != req.Spec.Duration.Duration {
		violations = append(violations, "spec.duration")
	}
	if !reflect.DeepEqual(spec.IssuerRef, req.Spec.IssuerRef) {
		violations = append(violations, "spec.issuerRef")
	}

	return violations, nil
}