diff --git a/Makefile b/Makefile index 50978c9..c6da0a4 100644 --- a/Makefile +++ b/Makefile @@ -32,7 +32,7 @@ .PHONY: run run: @echo "==> Running $(APP_NAME)..." - @$(GO) run . --port=$${PORT:-18000} --node-id=test-id + @$(GO) run . --port=$${PORT:-18000} --node-id=test-id --enable-cert-issuance --webroot-path=data/acme ## Run tests .PHONY: test diff --git a/data/config.db b/data/config.db old mode 100644 new mode 100755 Binary files differ diff --git a/internal/app/app.go b/internal/app/app.go index 762d596..dfab493 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,8 @@ "os" "path/filepath" "strings" + "sync" + "time" "github.com/envoyproxy/go-control-plane/pkg/cache/types" cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3" @@ -21,7 +23,8 @@ internallog "envoy-control-plane/internal/log" "envoy-control-plane/internal/pkg/api" "envoy-control-plane/internal/pkg/cert" - rtspserver "envoy-control-plane/internal/pkg/server" + "envoy-control-plane/internal/pkg/cert/rotation" + restAPIServer "envoy-control-plane/internal/pkg/server" "envoy-control-plane/internal/pkg/snapshot" internalstorage "envoy-control-plane/internal/pkg/storage" ) @@ -156,10 +159,15 @@ // Run encapsulates the entire application startup logic. func Run(ctx context.Context) error { log := internallog.LogFromContext(ctx) - cfg := config.GetConfig() // Use the globally available config - // 1. Conditional Certificate Issuance (Non-blocking) - cert.RunCertIssuance(ctx) + cfg := config.GetConfig() + + // 1. Certificate Issuance Requisite Check. This will ensure that if cert issuance is enabled, the prerequisites are met before proceeding. + if cfg.EnableCertIssuance { + if err := cert.CertIssuancePrerequisite(ctx, cfg.WebrootPath); err != nil { + return fmt.Errorf("certificate issuance prerequisite check failed: %w", err) + } + } // 2. 
Database Initialization dbConnStr, dbDriver, err := internalstorage.SetupDBConnection(ctx, cfg.DBConnStr) @@ -189,14 +197,68 @@ return fmt.Errorf("failed to load initial configuration: %w", err) } + var wg sync.WaitGroup + // Use a background context for the WaitGroup to avoid deadlocks during shutdown + wgCtx := context.Background() + // 5. Start xDS gRPC Server + wg.Add(1) cb := &test.Callbacks{Debug: true} srv := server.NewServer(ctx, cache, cb) - go internal.RunServer(srv, cfg.Port) // Assuming internal.RunServer is correct + go func() { + defer wg.Done() + log.Infof("Starting xDS gRPC server on port %d...", cfg.Port) + // Updated to pass context for graceful shutdown + internal.RunServer(ctx, srv, cfg.Port) + log.Infof("xDS gRPC server shut down.") + }() // 6. Start REST API Server + wg.Add(1) mux := http.NewServeMux() api.RegisterRoutes(mux, manager, cfg.EnableCertIssuance, cfg.WebrootPath) // Pass needed dependencies - return rtspserver.RunRESTServer(ctx, mux, cfg.RESTPort, cfg.WebrootPath, cfg.EnableCertIssuance) + go func() { + defer wg.Done() + log.Infof("Starting REST API server on port %d...", cfg.RESTPort) + restAPIServer.RunRESTServer(ctx, mux, cfg.RESTPort, cfg.WebrootPath, cfg.EnableCertIssuance) + log.Infof("REST API server shut down.") + }() + + if cfg.EnableCertIssuance { + // 7. Start Certificate Rotator + wg.Add(1) + certRotator := rotation.NewCertRotator(cfg.CertCheckInterval, storage) + go func() { + defer wg.Done() + log.Infof("Starting certificate rotator with check interval: %v", cfg.CertCheckInterval) + if err := certRotator.RotateCertificates(ctx); err != nil { + log.Errorf("Certificate rotator failed: %v", err) + } + }() + } + + log.Infof("All services started. Application is running and blocking until shutdown signal.") + <-ctx.Done() + log.Infof("Shutdown signal received. 
Waiting for all services to stop...") + + // Create a context with a timeout for graceful shutdown + shutdownCtx, cancel := context.WithTimeout(wgCtx, 10*time.Second) + defer cancel() + + // Wait for all background goroutines to finish or the shutdown timeout to expire + go func() { + wg.Wait() + cancel() // Signal completion + }() + + <-shutdownCtx.Done() + + if shutdownCtx.Err() == context.DeadlineExceeded { + log.Errorf("Graceful shutdown timeout exceeded. Exiting forcefully.") + return fmt.Errorf("shutdown failed: services did not terminate within timeout") + } + + log.Infof("Application shut down successfully.") + return nil } diff --git a/internal/config/config.go b/internal/config/config.go index 542a39e..22cb044 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ import ( "flag" + "time" "k8s.io/klog/v2" ) @@ -16,6 +17,7 @@ DBConnStr string EnableCertIssuance bool WebrootPath string + CertCheckInterval time.Duration } // Global configuration variable. @@ -33,6 +35,7 @@ flag.StringVar(&cfg.DBConnStr, "db", "", "Optional database connection string for config persistence") flag.BoolVar(&cfg.EnableCertIssuance, "enable-cert-issuance", false, "Enable Let's Encrypt certificate issuance on startup") flag.StringVar(&cfg.WebrootPath, "webroot-path", "data/acme", "Local path to serve the HTTP-01 challenge file (required if enabled)") + flag.DurationVar(&cfg.CertCheckInterval, "cert-check-interval", 1*time.Hour, "Interval for checking certificate expiration and renewing if necessary") } // GetConfig returns the application configuration. diff --git a/internal/log/log.go b/internal/log/log.go index 31f8307..9e5bf28 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -14,6 +14,7 @@ Infof(format string, args ...interface{}) Warnf(format string, args ...interface{}) Errorf(format string, args ...interface{}) + Fatalf(format string, args ...interface{}) } // Ensure DefaultLogger satisfies the Logger interface. 
@@ -55,6 +56,10 @@ klog.Errorf(format, args...) } +func (l *DefaultLogger) Fatalf(format string, args ...interface{}) { + klog.Fatalf(format, args...) +} + // ----------------------------------------------------------------------------- // Context Key and Functions // ----------------------------------------------------------------------------- diff --git a/internal/pkg/cert/cert.go b/internal/pkg/cert/cert.go index 2490b91..d3a4811 100644 --- a/internal/pkg/cert/cert.go +++ b/internal/pkg/cert/cert.go @@ -2,38 +2,30 @@ import ( "context" + "fmt" "os" - "envoy-control-plane/internal/config" internallog "envoy-control-plane/internal/log" ) const defaultFileMode = 0755 -// RunCertIssuance handles the conditional logic and argument validation for cert issuance. -func RunCertIssuance(ctx context.Context) { +// CertIssuancePrerequisite handles the conditional logic and argument validation for cert issuance. +func CertIssuancePrerequisite(ctx context.Context, webrootpath string) error { log := internallog.LogFromContext(ctx) - cfg := config.GetConfig() - - if !cfg.EnableCertIssuance { - return - } - log.Infof("Certificate issuance enabled. Validating arguments...") - - if cfg.WebrootPath == "" { + if webrootpath == "" { log.Errorf("Webroot path is required for certificate issuance") - return + return fmt.Errorf("webroot path is required for certificate issuance") } // 1. Ensure webroot path exists - if _, err := os.Stat(cfg.WebrootPath); os.IsNotExist(err) { - log.Warnf("Webroot path '%s' does not exist. Creating it.", cfg.WebrootPath) - if err := os.MkdirAll(cfg.WebrootPath, defaultFileMode); err != nil { + if _, err := os.Stat(webrootpath); os.IsNotExist(err) { + log.Warnf("Webroot path '%s' does not exist. 
Creating it.", webrootpath)
+		if err := os.MkdirAll(webrootpath, defaultFileMode); err != nil {
 			log.Errorf("Failed to create webroot path: %v", err)
+			return fmt.Errorf("failed to create webroot path: %w", err)
 		}
 	}
-
-	// NOTE: The commented-out code for starting the HTTP-01 server on :80 should
-	// be placed here if you implement it fully.
+	return nil
 }
diff --git a/internal/pkg/cert/rotation/rotator.go b/internal/pkg/cert/rotation/rotator.go
new file mode 100644
index 0000000..100c530
--- /dev/null
+++ b/internal/pkg/cert/rotation/rotator.go
@@ -0,0 +1,181 @@
+package rotation
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	internallog "envoy-control-plane/internal/log"
+	"envoy-control-plane/internal/pkg/api"
+	"envoy-control-plane/internal/pkg/cert"
+	certapi "envoy-control-plane/internal/pkg/cert/api"
+	"envoy-control-plane/internal/pkg/cert/tool"
+	"envoy-control-plane/internal/pkg/storage"
+)
+
+const (
+	// defaultRenewBefore is the renewal window applied when a certificate does not set its own RenewBefore.
+	defaultRenewBefore = 7 * 24 * time.Hour
+)
+
+// CertRotator periodically checks stored certificates and renews those close to expiry.
+type CertRotator struct {
+	checkInterval time.Duration
+	storage       *storage.Storage
+	certParser    tool.CertificateParser
+}
+
+// NewCertRotator creates a CertRotator that runs a rotation check every interval
+// and loads and saves certificates through s.
+func NewCertRotator(interval time.Duration, s *storage.Storage) CertRotator {
+	return CertRotator{
+		checkInterval: interval,
+		storage:       s,
+		certParser:    tool.CertificateParser{},
+	}
+}
+
+// loadCertificatesWithAutoRotationEnrolled returns the stored certificates whose
+// EnableRotation flag is set.
+func (cr *CertRotator) loadCertificatesWithAutoRotationEnrolled(ctx context.Context) ([]*storage.CertStorage, error) {
+	log := internallog.LogFromContext(ctx)
+
+	certs, err := cr.storage.LoadAllCertificates(ctx)
+	if err != nil {
+		// Not logged here: the error is wrapped and reported once by the rotation loop.
+		return nil, fmt.Errorf("failed to load certificates: %w", err)
+	}
+
+	enrolled := make([]*storage.CertStorage, 0, len(certs))
+	for _, c := range certs {
+		if c.EnableRotation {
+			enrolled = append(enrolled, c)
+		}
+	}
+	log.Debugf("Loaded %d certificates enrolled for rotation.", len(enrolled))
+	return enrolled, nil
+}
+
+// checkAndRotateCertificate parses one stored certificate and, when it is inside
+// its renewal window, renews it and saves the result. Failures are logged and
+// the certificate is left untouched until the next check.
+func (cr *CertRotator) checkAndRotateCertificate(ctx context.Context, c *storage.CertStorage) {
+	log := internallog.LogFromContext(ctx)
+
+	// 1. Parse the certificate.
+	certInfo, err := cr.certParser.Parse(c.CertPEM)
+	if err != nil {
+		log.Errorf("Failed to parse certificate data for domain %s: %v", c.Domain, err)
+		return
+	}
+	// Guard the certInfo[0] access below.
+	if len(certInfo) == 0 {
+		log.Errorf("Parsed certificate data for domain %s contained no certificates.", c.Domain)
+		return
+	}
+
+	// 2. Determine the renewal window.
+	renewBefore := c.RenewBefore
+	if renewBefore == 0 {
+		renewBefore = defaultRenewBefore
+		log.Debugf("No renew_before set for domain %s. Using default: %v", c.Domain, renewBefore)
+	} else {
+		log.Debugf("Using custom renew_before for domain %s: %v", c.Domain, renewBefore)
+	}
+
+	// The renewal deadline is the expiration date minus the renewal window.
+	certExpiration := certInfo[0].NotAfter
+	renewalDeadline := certExpiration.Add(-renewBefore)
+
+	// 3. Nothing to do until the renewal deadline has passed.
+	if !time.Now().After(renewalDeadline) {
+		log.Debugf("Certificate for domain %s is valid until %v. No renewal needed.", c.Domain, certExpiration.Format(time.RFC3339))
+		return
+	}
+	log.Infof("Certificate for domain %s needs renewal. Expires: %v, Deadline: %v", c.Domain, certExpiration.Format(time.RFC3339), renewalDeadline.Format(time.RFC3339))
+
+	oldCert := &certapi.Certificate{
+		Domain:     c.Domain,
+		CertPEM:    c.CertPEM,
+		KeyPEM:     c.KeyPEM,
+		AccountKey: c.AccountKey,
+		AccountURL: c.AccountURL,
+	}
+
+	// 4. Renew the certificate.
+	issuer, err := cert.NewCertIssuer(c.IssuerType)
+	if err != nil {
+		log.Errorf("Failed to create certificate issuer for domain %s: %v", c.Domain, err)
+		return
+	}
+	newCert, err := issuer.RenewCertificate(oldCert, api.ACME_CALLENGE_WEB_PATH, c.Email)
+	if err != nil {
+		log.Errorf("Failed to renew the certificate for domain %s: %v", c.Domain, err)
+		return
+	}
+
+	// 5. Save the new certificate. A failure here means a certificate was issued but not persisted.
+	if err := cert.SaveCertificateData(ctx, cr.storage, newCert, c.Email, c.IssuerType, c.SecretName); err != nil {
+		log.Errorf("Failed to save the new certificate for domain %s into the database: %v", c.Domain, err)
+		return
+	}
+	log.Infof("Successfully renewed and saved certificate for domain %s.", c.Domain)
+}
+
+// RotateCertificates starts the background rotation watcher and returns immediately.
+//
+// The watcher runs one check right away and then one per checkInterval until ctx
+// is cancelled. An error is returned, and no watcher is started, when checkInterval
+// is not positive, because time.NewTicker panics on such values.
+func (cr *CertRotator) RotateCertificates(ctx context.Context) error {
+	if cr.checkInterval <= 0 {
+		return fmt.Errorf("invalid certificate check interval %v: must be positive", cr.checkInterval)
+	}
+	log := internallog.LogFromContext(ctx)
+
+	go func() {
+		// The goroutine owns the ticker. Creating it in the enclosing function with a
+		// deferred Stop would stop it as soon as RotateCertificates returns, and the
+		// periodic branch below would never fire.
+		ticker := time.NewTicker(cr.checkInterval)
+		defer ticker.Stop()
+
+		// Check once on startup to catch certificates that are already expiring.
+		log.Debugf("[Rotator] Performing initial certificate rotation check.")
+		if err := cr.runRotationCheck(ctx); err != nil {
+			log.Errorf("[Rotator] Initial check failed: %v", err)
+		}
+
+		for {
+			select {
+			case <-ctx.Done():
+				log.Infof("[Rotator] Context cancelled. Shutting down rotation watcher.")
+				return
+			case <-ticker.C:
+				log.Debugf("[Rotator] Ticker fired. Running periodic check.")
+				if err := cr.runRotationCheck(ctx); err != nil {
+					log.Errorf("[Rotator] Periodic check failed: %v", err)
+				}
+			}
+		}
+	}()
+
+	return nil
+}
+
+// runRotationCheck loads the certificates enrolled for rotation and checks each one.
+// It returns an error only when the certificates cannot be loaded.
+func (cr *CertRotator) runRotationCheck(ctx context.Context) error {
+	log := internallog.LogFromContext(ctx)
+
+	certs, err := cr.loadCertificatesWithAutoRotationEnrolled(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to load certificates with auto rotation enabled: %w", err)
+	}
+
+	for _, c := range certs {
+		cr.checkAndRotateCertificate(ctx, c)
+	}
+	log.Debugf("[Rotator] Rotation check completed for %d certificates.", len(certs))
+	return nil
+}
diff --git a/internal/pkg/rotation/rotator.go b/internal/pkg/rotation/rotator.go
deleted file mode 100644
index 5ce1592..0000000
--- a/internal/pkg/rotation/rotator.go
+++ /dev/null
@@ -1,122 +0,0 @@
-package rotation
-
-import (
-	"context"
-	internallog "envoy-control-plane/internal/log"
-	"envoy-control-plane/internal/pkg/api"
-	"envoy-control-plane/internal/pkg/cert"
-	certapi "envoy-control-plane/internal/pkg/cert/api"
-	"envoy-control-plane/internal/pkg/cert/tool"
-	"envoy-control-plane/internal/pkg/storage"
-	"fmt"
-
-	"time"
-)
-
-const (
-	defaultRenewBefore = 7 * 24 * time.Hour
-)
-
-func NewCertRotor(ctx context.Context) (*CertRotator, error) {
-	// Implementation for creating a new CertRotator
-	return &CertRotator{}, nil
-}
-
-type CertRotator struct {
-	checkInterval time.Duration
-	storage       storage.Storage
-	certIssuer    certapi.CertIssuer
-	certPaser     tool.CertificateParser
-}
-
-func NewCertRotator(interval time.Duration, s storage.Storage, ci certapi.CertIssuer) CertRotator {
-	return CertRotator{
-		checkInterval: interval,
-		storage:       s,
-		certIssuer:    ci,
-		certPaser:     tool.CertificateParser{},
-	}
-}
-
-// loadCertificatesWithAutoRotationEnrolled retrieves all certificates marked for rotation.
-// The method name has been corrected for clarity and idiomatic Go naming conventions (shorter).
-func (cr *CertRotator) loadCertificatesWithAutoRotationEnrolled(ctx context.Context) ([]*storage.CertStorage, error) { - // Use the logger from context for better tracing - log := internallog.LogFromContext(ctx) - - certs, err := cr.storage.LoadAllCertificates(ctx) - if err != nil { - log.Errorf("Failed to load all certificates from storage: %w", err) - return nil, fmt.Errorf("failed to load certificates: %w", err) - } - - res := make([]*storage.CertStorage, 0, len(certs)) // Pre-allocate capacity - for _, cert := range certs { - if cert.EnableRotation { - res = append(res, cert) - } - } - log.Debugf("Loaded %d certificates enrolled for rotation.", len(res)) - return res, nil -} - -func (cr *CertRotator) RotateCertificates(ctx context.Context) error { - log := internallog.LogFromContext(ctx) - - // Implementation for rotating certificates - ticker := time.NewTicker(cr.checkInterval) - defer ticker.Stop() - - // 3. Start the goroutine loop. - go func() { - log.Infof("[Rotator] Starting certificate rotation watcher. Interval: %v\n", cr.checkInterval) - for { - select { - case <-ctx.Done(): - // Context was canceled (e.g., application shutdown). Exit the loop gracefully. - log.Infof("[Rotator] Context cancelled. Shutting down rotation watcher.") - return - case <-ticker.C: - // The ticker fired. Time to check and potentially rotate the certificate. - certs, err := cr.loadCertificatesWithAutoRotationEnrolled(ctx) - if err != nil { - // Log the error but continue the loop, as this is a transient failure. 
- log.Errorf("[Rotator] failed to load certificates with auto ratation enabled: %v\n", err) - } - - for _, c := range certs { - certInfo, err := cr.certPaser.Parse(c.CertPEM) - if err != nil { - log.Errorf("failed to parse the cert %s: %w", string(c.CertPEM), err) - continue - } - renewBefore := c.RenewBefore - if renewBefore == 0 { - renewBefore = defaultRenewBefore - } - - if time.Now().Add(c.RenewBefore).After(certInfo[0].NotBefore) { - oldCert := &certapi.Certificate{ - Domain: c.Domain, - CertPEM: c.CertPEM, - KeyPEM: c.KeyPEM, - AccountKey: c.AccountKey, - AccountURL: c.AccountURL, - } - newCert, err := cr.certIssuer.RenewCertificate(oldCert, api.ACME_CALLENGE_WEB_PATH, c.Email) - if err != nil { - log.Errorf("failed to renew the certificate: %w", err) - continue - } - if err := cert.SaveCertificateData(ctx, &cr.storage, newCert, c.Email, c.IssuerType, c.SecretName); err != nil { - log.Errorf("failed to save the new certificate into the database") - continue - } - - } - } - } - } - }() - return nil -} diff --git a/internal/pkg/server/server.go b/internal/pkg/server/server.go index fca8175..8f52898 100644 --- a/internal/pkg/server/server.go +++ b/internal/pkg/server/server.go @@ -11,7 +11,7 @@ ) // RunRESTServer starts the REST API server with appropriate middleware. -func RunRESTServer(ctx context.Context, mux *http.ServeMux, restPort uint, webrootPath string, enableCertIssuance bool) error { +func RunRESTServer(ctx context.Context, mux *http.ServeMux, restPort uint, webrootPath string, enableCertIssuance bool) { log := internallog.LogFromContext(ctx) corsHandler := api.CORS(mux) @@ -24,9 +24,8 @@ } if err := http.ListenAndServe(restAddr, corsHandler); err != nil { - return fmt.Errorf("REST server error: %w", err) + log.Fatalf("REST server error: %w", err) } - return nil } // NOTE: The function to start the gRPC xDS server (internal.RunServer) remains in your existing 'internal' package. 
diff --git a/internal/run_server.go b/internal/run_server.go index a4f5bfc..6fdf795 100644 --- a/internal/run_server.go +++ b/internal/run_server.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "log" "net" @@ -28,7 +29,7 @@ ) // RunServer starts the xDS management server on the given port. -func RunServer(srv server.Server, port uint) { +func RunServer(ctx context.Context, srv server.Server, port uint) { var grpcOptions []grpc.ServerOption grpcOptions = append(grpcOptions, grpc.MaxConcurrentStreams(grpcMaxConcurrent), @@ -60,7 +61,21 @@ } log.Printf("management server listening on %s\n", addr) - if err := grpcServer.Serve(lis); err != nil { - log.Fatalf("gRPC server failed: %v", err) - } + + go func() { + log.Printf("management server listening on %s\n", addr) + + if err := grpcServer.Serve(lis); err != nil && err != grpc.ErrServerStopped { + log.Fatalf("gRPC server failed: %v", err) + } + }() + + // Block until the context is done (cancellation received) + <-ctx.Done() + + // Gracefully stop the server + log.Printf("management server stopping gracefully...") + grpcServer.GracefulStop() + log.Printf("management server stopped.") + }