package rotation
import (
"context"
internallog "envoy-control-plane/internal/log"
"envoy-control-plane/internal/pkg/api"
"envoy-control-plane/internal/pkg/cert"
certapi "envoy-control-plane/internal/pkg/cert/api"
"envoy-control-plane/internal/pkg/cert/tool"
"envoy-control-plane/internal/pkg/storage"
"fmt"
"time"
)
// defaultRenewBefore is the renewal window used for certificates that do not
// set their own RenewBefore value: renewal is attempted once a certificate is
// within one week of its NotAfter time.
const defaultRenewBefore = 7 * 24 * time.Hour
// CertRotator periodically scans the certificates held in storage and renews
// the ones enrolled for automatic rotation whose renewal deadline has passed.
type CertRotator struct {
	// checkInterval is the period between successive rotation checks.
	checkInterval time.Duration

	// storage is the store certificates are loaded from and renewed
	// certificates are saved to.
	storage *storage.Storage

	// certParser decodes stored PEM data so the expiry time can be read.
	certParser tool.CertificateParser
}
// NewCertRotator returns a CertRotator that checks the certificates kept in s
// once every interval. The interval is stored exactly as given.
func NewCertRotator(interval time.Duration, s *storage.Storage) CertRotator {
	var rotator CertRotator
	rotator.checkInterval = interval
	rotator.storage = s
	// tool.CertificateParser is used as a concrete value type.
	rotator.certParser = tool.CertificateParser{}
	return rotator
}
// loadCertificatesWithAutoRotationEnrolled returns the stored certificates
// whose EnableRotation flag is set.
//
// A storage failure is wrapped and returned without being logged here; the
// caller decides how to report it, so the same failure is not logged twice.
// Nil entries returned by storage are skipped instead of being dereferenced.
func (cr *CertRotator) loadCertificatesWithAutoRotationEnrolled(ctx context.Context) ([]*storage.CertStorage, error) {
	log := internallog.LogFromContext(ctx)

	certs, err := cr.storage.LoadAllCertificates(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	enrolled := make([]*storage.CertStorage, 0, len(certs))
	// The loop variable is named c (not cert) so the imported cert package is
	// not shadowed inside this function.
	for _, c := range certs {
		if c != nil && c.EnableRotation {
			enrolled = append(enrolled, c)
		}
	}

	log.Debugf("Loaded %d certificates enrolled for rotation.", len(enrolled))
	return enrolled, nil
}
// checkAndRotateCertificate parses the stored certificate c and decides
// whether its renewal deadline has passed. If it has, the function renews the
// certificate through the issuer configured for c and saves the result to
// storage.
//
// Every failure is logged and ends processing of c; nothing is returned to
// the caller, so one bad certificate does not prevent others from being
// checked.
//
// The renewal window is c.RenewBefore. defaultRenewBefore is used when that
// value is zero or negative. A negative window would move the deadline past
// the expiration date, so the certificate would only be renewed after it had
// already expired.
func (cr *CertRotator) checkAndRotateCertificate(ctx context.Context, c *storage.CertStorage) {
	log := internallog.LogFromContext(ctx)

	// 1. Parse the certificate.
	certInfo, err := cr.certParser.Parse(c.CertPEM)
	if err != nil {
		log.Errorf("Failed to parse certificate data for domain %s: %v", c.Domain, err)
		return
	}
	if len(certInfo) == 0 {
		log.Errorf("Parsed certificate data for domain %s contained no certificates.", c.Domain)
		return
	}

	// 2. Determine the renewal window.
	renewBefore := c.RenewBefore
	switch {
	case renewBefore == 0:
		renewBefore = defaultRenewBefore
		log.Debugf("No renew_before set for domain %s. Using default: %v", c.Domain, renewBefore)
	case renewBefore < 0:
		log.Errorf("Invalid negative renew_before %v for domain %s. Using default: %v", renewBefore, c.Domain, defaultRenewBefore)
		renewBefore = defaultRenewBefore
	default:
		log.Debugf("Using custom renew_before for domain %s: %v", c.Domain, renewBefore)
	}

	// Deadline = expiration - renewal window. The first parsed certificate is
	// assumed to be the leaf (TODO confirm the parser's ordering).
	certExpiration := certInfo[0].NotAfter
	renewalDeadline := certExpiration.Add(-renewBefore)

	// 3. Nothing to do until the renewal deadline has passed.
	if !time.Now().After(renewalDeadline) {
		log.Debugf("Certificate for domain %s is valid until %v. No renewal needed.", c.Domain, certExpiration.Format(time.RFC3339))
		return
	}
	log.Infof("Certificate for domain %s needs renewal. Expires: %v, Deadline: %v", c.Domain, certExpiration.Format(time.RFC3339), renewalDeadline.Format(time.RFC3339))

	// Build the old certificate structure for renewal.
	oldCert := &certapi.Certificate{
		Domain:     c.Domain,
		CertPEM:    c.CertPEM,
		KeyPEM:     c.KeyPEM,
		AccountKey: c.AccountKey,
		AccountURL: c.AccountURL,
	}

	// 4. Renew the certificate.
	issuer, err := cert.NewCertIssuer(c.IssuerType)
	if err != nil {
		log.Errorf("Failed to create certificate issuer for domain %s: %v", c.Domain, err)
		return
	}
	newCert, err := issuer.RenewCertificate(oldCert, api.ACME_CALLENGE_WEB_PATH, c.Email)
	if err != nil {
		log.Errorf("Failed to renew the certificate for domain %s: %v", c.Domain, err)
		return
	}

	// 5. Save the new certificate.
	if err := cert.SaveCertificateData(ctx, cr.storage, newCert, c.Email, c.IssuerType, c.SecretName); err != nil {
		// Critical: a new certificate was issued but not persisted. The stored
		// certificate is presumably unchanged, so the next check would renew again.
		log.Errorf("Failed to save the new certificate for domain %s into the database: %v", c.Domain, err)
		return
	}
	log.Infof("Successfully renewed and saved certificate for domain %s.", c.Domain)
}
// RotateCertificates starts a background goroutine that runs a rotation check
// immediately. After that, it runs a check once every checkInterval until ctx
// is cancelled. RotateCertificates returns without waiting for any check to
// run.
//
// It returns an error, and starts nothing, when checkInterval is not positive.
// time.NewTicker panics on such a value.
func (cr *CertRotator) RotateCertificates(ctx context.Context) error {
	if cr.checkInterval <= 0 {
		return fmt.Errorf("invalid certificate rotation check interval %v: must be positive", cr.checkInterval)
	}

	log := internallog.LogFromContext(ctx)
	ticker := time.NewTicker(cr.checkInterval)

	go func() {
		// The goroutine owns the ticker and stops it when it exits. A defer in
		// RotateCertificates itself would run as soon as that function returns,
		// stopping the ticker before it ever fired.
		defer ticker.Stop()

		// Check immediately on startup so certificates that are already
		// expiring are handled without waiting a full interval.
		log.Debugf("[Rotator] Performing initial certificate rotation check.")
		if err := cr.runRotationCheck(ctx); err != nil {
			log.Errorf("[Rotator] Initial check failed: %v", err)
		}

		for {
			select {
			case <-ctx.Done():
				log.Infof("[Rotator] Context cancelled. Shutting down rotation watcher.")
				return
			case <-ticker.C:
				log.Debugf("[Rotator] Ticker fired. Running periodic check.")
				if err := cr.runRotationCheck(ctx); err != nil {
					log.Errorf("[Rotator] Periodic check failed: %v", err)
				}
			}
		}
	}()

	return nil
}
// runRotationCheck loads every certificate enrolled for rotation and runs
// checkAndRotateCertificate on each one in turn.
//
// It returns an error if the certificates cannot be loaded. It also returns an
// error if ctx is cancelled before every certificate has been checked.
// Per-certificate failures are logged by checkAndRotateCertificate and do not
// stop the loop.
func (cr *CertRotator) runRotationCheck(ctx context.Context) error {
	log := internallog.LogFromContext(ctx)

	certs, err := cr.loadCertificatesWithAutoRotationEnrolled(ctx)
	if err != nil {
		// Not logged here; the caller reports it and a later check retries.
		return fmt.Errorf("failed to load certificates with auto rotation enabled: %w", err)
	}

	for i, c := range certs {
		// Stop between certificates on shutdown instead of starting further
		// renewals.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("rotation check interrupted after %d of %d certificates: %w", i, len(certs), err)
		}
		cr.checkAndRotateCertificate(ctx, c)
	}

	log.Debugf("[Rotator] Rotation check completed for %d certificates.", len(certs))
	return nil
}