// Source path: EnvoyControlPlane/internal/pkg/rotation/rotator.go

package rotation

import (
	"context"
	internallog "envoy-control-plane/internal/log"
	"envoy-control-plane/internal/pkg/api"
	"envoy-control-plane/internal/pkg/cert"
	certapi "envoy-control-plane/internal/pkg/cert/api"
	"envoy-control-plane/internal/pkg/cert/tool"
	"envoy-control-plane/internal/pkg/storage"
	"fmt"

	"time"
)

// defaultRenewBefore (one week) is intended as the fallback renew-before
// window for stored certificates whose RenewBefore is zero: a certificate is
// due for renewal once its expiry is closer than this.
const defaultRenewBefore = 168 * time.Hour

// NewCertRotor returns a pointer to a zero-valued CertRotator and a nil error.
//
// NOTE(review): this looks like a placeholder. The returned rotator has no
// storage, issuer, or check interval configured, so RotateCertificates cannot
// run with it (a zero check interval is not a valid ticker period). Callers
// most likely want NewCertRotator instead. ctx is currently unused.
func NewCertRotor(ctx context.Context) (*CertRotator, error) {
	_ = ctx // accepted for API compatibility; not used yet
	rotator := new(CertRotator)
	return rotator, nil
}

// CertRotator renews certificates that are enrolled for automatic rotation.
// Build one with NewCertRotator and start it with RotateCertificates.
type CertRotator struct {
	// checkInterval is the period between rotation passes.
	checkInterval time.Duration

	// storage is where certificates are loaded from and renewed ones saved to.
	storage storage.Storage
	// certIssuer performs the actual renewal of a certificate.
	certIssuer certapi.CertIssuer
	// certPaser (sic) decodes stored PEM data so expiry can be inspected.
	// The misspelled name is kept because other code references it.
	certPaser tool.CertificateParser
}

// NewCertRotator builds a CertRotator that runs a rotation pass every
// interval, reading certificates from s and renewing them through ci.
// The rotator is returned by value; its methods use pointer receivers, so
// store it in an addressable variable before calling them.
func NewCertRotator(interval time.Duration, s storage.Storage, ci certapi.CertIssuer) CertRotator {
	var rotator CertRotator
	rotator.checkInterval = interval
	rotator.storage = s
	rotator.certIssuer = ci
	rotator.certPaser = tool.CertificateParser{}
	return rotator
}

// loadCertificatesWithAutoRotationEnrolled returns every stored certificate
// whose EnableRotation flag is set. Nil entries returned by storage are skipped.
//
// A storage failure is wrapped and returned, not logged here, so that the
// caller reports it exactly once.
func (cr *CertRotator) loadCertificatesWithAutoRotationEnrolled(ctx context.Context) ([]*storage.CertStorage, error) {
	certs, err := cr.storage.LoadAllCertificates(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	// Pre-size for the worst case (every certificate enrolled).
	enrolled := make([]*storage.CertStorage, 0, len(certs))
	for _, c := range certs {
		// Guard against nil entries so one bad row cannot panic the watcher.
		if c != nil && c.EnableRotation {
			enrolled = append(enrolled, c)
		}
	}

	internallog.LogFromContext(ctx).Debugf("Loaded %d certificates enrolled for rotation.", len(enrolled))
	return enrolled, nil
}

// RotateCertificates starts a background watcher that runs a rotation pass
// every cr.checkInterval. Each pass renews the enrolled certificates that are
// within their renew-before window of expiry.
//
// The watcher runs in its own goroutine and stops when ctx is cancelled; this
// method returns immediately. It returns an error only when the check interval
// is not positive, because time.NewTicker panics on such a period.
func (cr *CertRotator) RotateCertificates(ctx context.Context) error {
	if cr.checkInterval <= 0 {
		return fmt.Errorf("invalid certificate rotation check interval %v: must be positive", cr.checkInterval)
	}

	log := internallog.LogFromContext(ctx)

	// The goroutine owns the ticker and stops it on exit. Deferring Stop in this
	// function would stop the ticker as soon as RotateCertificates returned, and
	// the watcher would never receive a tick.
	ticker := time.NewTicker(cr.checkInterval)

	go func() {
		defer ticker.Stop()
		log.Infof("[Rotator] Starting certificate rotation watcher. Interval: %v", cr.checkInterval)
		for {
			select {
			case <-ctx.Done():
				// Context was canceled (e.g., application shutdown). Exit gracefully.
				log.Infof("[Rotator] Context cancelled. Shutting down rotation watcher.")
				return
			case <-ticker.C:
				cr.rotateDueCertificates(ctx)
			}
		}
	}()
	return nil
}

// rotateDueCertificates performs one rotation pass. It renews and persists
// every enrolled certificate whose leaf expires within its renew-before window.
// Failures are logged per certificate, so one bad entry does not block the rest.
func (cr *CertRotator) rotateDueCertificates(ctx context.Context) {
	log := internallog.LogFromContext(ctx)

	certs, err := cr.loadCertificatesWithAutoRotationEnrolled(ctx)
	if err != nil {
		// Treated as transient: log it and retry on the next tick.
		log.Errorf("[Rotator] failed to load certificates with auto rotation enabled: %v", err)
		return
	}

	now := time.Now()
	for _, c := range certs {
		// Renewals may be slow; stop promptly if shutdown was requested mid-pass.
		if ctx.Err() != nil {
			return
		}

		certInfo, err := cr.certPaser.Parse(c.CertPEM)
		if err != nil {
			// Log the domain, not the full PEM, to keep logs readable.
			log.Errorf("[Rotator] failed to parse the certificate for domain %v: %v", c.Domain, err)
			continue
		}
		if len(certInfo) == 0 {
			log.Errorf("[Rotator] no certificate found in stored PEM for domain %v", c.Domain)
			continue
		}

		// certInfo[0] is taken to be the leaf certificate, as in the original code.
		if !needsRenewal(now, certInfo[0].NotAfter, c.RenewBefore) {
			continue
		}

		oldCert := &certapi.Certificate{
			Domain:     c.Domain,
			CertPEM:    c.CertPEM,
			KeyPEM:     c.KeyPEM,
			AccountKey: c.AccountKey,
			AccountURL: c.AccountURL,
		}
		newCert, err := cr.certIssuer.RenewCertificate(oldCert, api.ACME_CALLENGE_WEB_PATH, c.Email)
		if err != nil {
			log.Errorf("[Rotator] failed to renew the certificate for domain %v: %v", c.Domain, err)
			continue
		}
		if err := cert.SaveCertificateData(ctx, &cr.storage, newCert, c.Email, c.IssuerType, c.SecretName); err != nil {
			log.Errorf("[Rotator] failed to save the renewed certificate for domain %v into the database: %v", c.Domain, err)
			continue
		}
		log.Infof("[Rotator] renewed certificate for domain %v", c.Domain)
	}
}

// needsRenewal reports whether a certificate that expires at notAfter is inside
// its renewal window at time now. A non-positive renewBefore falls back to
// defaultRenewBefore.
func needsRenewal(now, notAfter time.Time, renewBefore time.Duration) bool {
	if renewBefore <= 0 {
		renewBefore = defaultRenewBefore
	}
	return now.Add(renewBefore).After(notAfter)
}