Newer
Older
EnvoyControlPlane / internal / pkg / cert / rotation / rotator.go
package rotation

import (
	"context"
	internallog "envoy-control-plane/internal/log"
	"envoy-control-plane/internal/pkg/api"
	"envoy-control-plane/internal/pkg/cert"
	certapi "envoy-control-plane/internal/pkg/cert/api"
	"envoy-control-plane/internal/pkg/cert/tool"
	"envoy-control-plane/internal/pkg/storage"
	"fmt"
	"time"
)

// defaultRenewBefore is the renewal window used when a stored certificate has a
// zero RenewBefore: renewal becomes due this long (one week) before the
// certificate's NotAfter time.
const defaultRenewBefore = 7 * 24 * time.Hour

// CertRotator manages the background rotation process for certificates stored in the system.
// It periodically loads the certificates that have EnableRotation set, checks
// each one's expiry against its renewal window, and renews and re-saves those
// that are due.
type CertRotator struct {
	// checkInterval is the period between rotation passes; it is handed to
	// time.NewTicker, which rejects non-positive durations.
	checkInterval time.Duration
	// storage is the store certificates are loaded from and the renewed
	// certificates are saved back into.
	storage       *storage.Storage
	// certParser parses the stored PEM data so the expiry (NotAfter) can be read.
	certParser    tool.CertificateParser
}

// NewCertRotator returns a CertRotator that runs a rotation pass over the
// certificates in s once per interval. The interval should be positive, since
// the rotation loop drives its schedule with time.NewTicker.
func NewCertRotator(interval time.Duration, s *storage.Storage) CertRotator {
	var rotator CertRotator
	rotator.checkInterval = interval
	rotator.storage = s
	// A zero-value parser is used for reading certificate expiry.
	rotator.certParser = tool.CertificateParser{}
	return rotator
}

// loadCertificatesWithAutoRotationEnrolled returns every stored certificate
// whose EnableRotation flag is set.
//
// A storage failure is wrapped and returned without being logged here, so that
// it is reported exactly once by whichever caller handles it.
func (cr *CertRotator) loadCertificatesWithAutoRotationEnrolled(ctx context.Context) ([]*storage.CertStorage, error) {
	log := internallog.LogFromContext(ctx)

	all, err := cr.storage.LoadAllCertificates(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	// Pre-size for the worst case where every certificate is enrolled.
	enrolled := make([]*storage.CertStorage, 0, len(all))
	// The loop variable is named c (not cert) so it does not shadow the cert package.
	for _, c := range all {
		if c.EnableRotation {
			enrolled = append(enrolled, c)
		}
	}
	log.Debugf("Loaded %d certificates enrolled for rotation.", len(enrolled))
	return enrolled, nil
}

// checkAndRotateCertificate parses c's stored certificate and decides whether
// it is inside its renewal window. If it is, the function renews it through the
// issuer selected by c.IssuerType and saves the result back to storage.
//
// The renewal window ends at the NotAfter time of the first certificate parsed
// from c.CertPEM and starts c.RenewBefore earlier. A zero or negative
// RenewBefore falls back to defaultRenewBefore; a negative value would
// otherwise place the deadline after expiry.
//
// Every failure is logged and ends processing of this certificate only.
// Nothing is returned, so a single bad certificate cannot abort a rotation pass.
func (cr *CertRotator) checkAndRotateCertificate(ctx context.Context, c *storage.CertStorage) {
	log := internallog.LogFromContext(ctx)

	// 1. Parse the stored certificate data.
	certInfo, err := cr.certParser.Parse(c.CertPEM)
	if err != nil {
		log.Errorf("Failed to parse certificate data for domain %s: %v", c.Domain, err)
		return
	}
	if len(certInfo) == 0 {
		log.Errorf("Parsed certificate data for domain %s contained no certificates.", c.Domain)
		return
	}

	// 2. Determine the renewal window.
	renewBefore := c.RenewBefore
	switch {
	case renewBefore == 0:
		renewBefore = defaultRenewBefore
		log.Debugf("No renew_before set for domain %s. Using default: %v", c.Domain, renewBefore)
	case renewBefore < 0:
		// A negative window moves the deadline past NotAfter, so the certificate
		// would only be renewed after it had already expired.
		log.Errorf("Invalid negative renew_before %v for domain %s. Using default: %v", renewBefore, c.Domain, defaultRenewBefore)
		renewBefore = defaultRenewBefore
	default:
		log.Debugf("Using custom renew_before for domain %s: %v", c.Domain, renewBefore)
	}

	// NOTE(review): certInfo[0] is assumed to be the leaf certificate of the
	// chain; confirm the parser's ordering.
	certExpiration := certInfo[0].NotAfter
	renewalDeadline := certExpiration.Add(-renewBefore)

	// 3. Until the deadline passes there is nothing to do.
	if !time.Now().After(renewalDeadline) {
		log.Debugf("Certificate for domain %s is valid until %v. No renewal needed.", c.Domain, certExpiration.Format(time.RFC3339))
		return
	}
	log.Infof("Certificate for domain %s needs renewal. Expires: %v, Deadline: %v", c.Domain, certExpiration.Format(time.RFC3339), renewalDeadline.Format(time.RFC3339))

	// The issuer renews from the currently stored certificate and ACME account.
	oldCert := &certapi.Certificate{
		Domain:     c.Domain,
		CertPEM:    c.CertPEM,
		KeyPEM:     c.KeyPEM,
		AccountKey: c.AccountKey,
		AccountURL: c.AccountURL,
	}

	// 4. Renew the certificate.
	certIssuer, err := cert.NewCertIssuer(c.IssuerType)
	if err != nil {
		log.Errorf("Failed to create certificate issuer for domain %s: %v", c.Domain, err)
		return
	}
	newCert, err := certIssuer.RenewCertificate(oldCert, api.ACME_CALLENGE_WEB_PATH, c.Email)
	if err != nil {
		log.Errorf("Failed to renew the certificate for domain %s: %v", c.Domain, err)
		return
	}

	// 5. Persist the renewed certificate.
	if err := cert.SaveCertificateData(ctx, cr.storage, newCert, c.Email, c.IssuerType, c.SecretName); err != nil {
		// A new certificate was issued but not stored. Presumably the old one is
		// still in storage and will trigger another renewal on the next pass.
		log.Errorf("Failed to save the new certificate for domain %s into the database: %v", c.Domain, err)
		return
	}
	log.Infof("Successfully renewed and saved certificate for domain %s.", c.Domain)
}

// RotateCertificates starts a background goroutine that runs a rotation pass
// once immediately and then on every tick of cr.checkInterval. It runs until
// ctx is cancelled.
//
// It returns as soon as the goroutine has started. The only error it returns
// is for a non-positive check interval, which time.NewTicker cannot accept; in
// that case no goroutine is started. Errors from individual passes are logged
// by the goroutine and the loop continues.
func (cr *CertRotator) RotateCertificates(ctx context.Context) error {
	log := internallog.LogFromContext(ctx)

	if cr.checkInterval <= 0 {
		return fmt.Errorf("invalid certificate rotation check interval %v: must be positive", cr.checkInterval)
	}

	ticker := time.NewTicker(cr.checkInterval)

	go func() {
		// The goroutine that reads from the ticker must be the one that stops it.
		// Deferring Stop in the enclosing function would stop the ticker as soon
		// as RotateCertificates returned, so no periodic check would ever fire.
		defer ticker.Stop()

		// Check immediately so certificates that are already due do not wait a
		// full interval.
		log.Debugf("[Rotator] Performing initial certificate rotation check.")
		if err := cr.runRotationCheck(ctx); err != nil {
			log.Errorf("[Rotator] Initial check failed: %v", err)
		}

		for {
			select {
			case <-ctx.Done():
				log.Infof("[Rotator] Context cancelled. Shutting down rotation watcher.")
				return
			case <-ticker.C:
				log.Debugf("[Rotator] Ticker fired. Running periodic check.")
				if err := cr.runRotationCheck(ctx); err != nil {
					log.Errorf("[Rotator] Periodic check failed: %v", err)
				}
			}
		}
	}()

	return nil
}

// runRotationCheck performs one rotation pass. It loads every certificate
// enrolled for auto-rotation and runs checkAndRotateCertificate on each in turn.
//
// It returns an error if the enrolled certificates cannot be loaded, or if ctx
// is cancelled before every certificate has been processed. Per-certificate
// failures are logged inside checkAndRotateCertificate and do not end the pass.
func (cr *CertRotator) runRotationCheck(ctx context.Context) error {
	log := internallog.LogFromContext(ctx)

	certs, err := cr.loadCertificatesWithAutoRotationEnrolled(ctx)
	if err != nil {
		// Returned to the rotation loop, which logs it and retries on the next tick.
		return fmt.Errorf("failed to load certificates with auto rotation enabled: %w", err)
	}

	for i, c := range certs {
		// On shutdown, stop between certificates instead of starting another
		// (potentially slow) renewal.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("rotation check interrupted after %d of %d certificates: %w", i, len(certs), err)
		}
		cr.checkAndRotateCertificate(ctx, c)
	}
	log.Debugf("[Rotator] Rotation check completed for %d certificates.", len(certs))
	return nil
}