diff --git a/data/config.db b/data/config.db index 38afb96..2112031 100644 --- a/data/config.db +++ b/data/config.db Binary files differ diff --git a/internal/app/app.go b/internal/app/app.go index 6595ceb..762d596 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -173,7 +173,10 @@ defer db.Close() // 3. Storage and Snapshot Manager Setup - storage := internalstorage.NewStorage(db, dbDriver) + storage, err := internalstorage.NewStorage(db, dbDriver) + if err != nil { + return fmt.Errorf("failed to create new storage: %w", err) + } if err := storage.InitSchema(ctx); err != nil { return fmt.Errorf("failed to initialize DB schema: %w", err) } diff --git a/internal/pkg/cert/api/type.go b/internal/pkg/cert/api/type.go index 33c8431..3ca3490 100644 --- a/internal/pkg/cert/api/type.go +++ b/internal/pkg/cert/api/type.go @@ -1,13 +1,17 @@ package api +import "time" + // Certificate represents the result of a successful certificate issuance. type Certificate struct { - Domain string - CertPEM []byte - KeyPEM []byte - FullChain []byte // Cert + Issuer Chain - AccountKey []byte // Private key for the ACME account - AccountURL string // URL of the ACME account + Domain string + CertPEM []byte + KeyPEM []byte + FullChain []byte // Cert + Issuer Chain + AccountKey []byte // Private key for the ACME account + AccountURL string // URL of the ACME account + EnableRotation bool + RenewBefore time.Duration } // CertIssuer defines the contract for any service that issues TLS certificates. diff --git a/internal/pkg/cert/persist.go b/internal/pkg/cert/persist.go index 407fed1..577108a 100644 --- a/internal/pkg/cert/persist.go +++ b/internal/pkg/cert/persist.go @@ -16,19 +16,23 @@ } certStorage := &storage.CertStorage{ - Domain: cert.Domain, - Email: email, // Store email with the cert - CertPEM: cert.CertPEM, - KeyPEM: cert.KeyPEM, - AccountKey: cert.AccountKey, - AccountURL: cert.AccountURL, - IssuerType: issuertype, - SecretName: secretname, + Domain: cert.Domain, + Email: email, // Store email with the cert + CertPEM: cert.CertPEM, + KeyPEM: cert.KeyPEM, + AccountKey: cert.AccountKey, + AccountURL: cert.AccountURL, + IssuerType: issuertype, + SecretName: secretname, + EnableRotation: false, } if err := store.SaveCertificate(ctx, certStorage); err != nil { return fmt.Errorf("failed to save certificate data for %s: %w", cert.Domain, err) } + if err := store.UpdateSecretDomain(ctx, secretname, cert.Domain); err != nil { + return fmt.Errorf("failed to update the domain %s for secret %s: %w", cert.Domain, secretname, err) + } return nil } @@ -48,11 +52,13 @@ } cert := &api.Certificate{ - Domain: certStorage.Domain, - CertPEM: certStorage.CertPEM, - KeyPEM: certStorage.KeyPEM, - AccountKey: certStorage.AccountKey, - AccountURL: certStorage.AccountURL, + Domain: certStorage.Domain, + CertPEM: certStorage.CertPEM, + KeyPEM: certStorage.KeyPEM, + AccountKey: certStorage.AccountKey, + AccountURL: certStorage.AccountURL, + EnableRotation: certStorage.EnableRotation, + RenewBefore: certStorage.RenewBefore, } return cert, certStorage.Email, certStorage.IssuerType, nil diff --git a/internal/pkg/storage/postgres.go b/internal/pkg/storage/postgres.go new file mode 100644 index 0000000..b162dfd --- /dev/null +++ b/internal/pkg/storage/postgres.go @@ -0,0 +1,149 @@ +package storage + +import ( + "database/sql" + "fmt" +) + +// --------------------------- +// Postgres Strategy +// --------------------------- + +type PostgresStrategy struct{} + +func (p *PostgresStrategy) DriverName() string { return "postgres" } + +// Placeholder returns $1, $2, etc. +func (p *PostgresStrategy) Placeholder(n int) string { return fmt.Sprintf("$%d", n) } + +// GetTimeNow returns the PostgreSQL function for current timestamp. +func (p *PostgresStrategy) GetTimeNow() string { return "now()" } + +// GetTrueValue returns the boolean representation. +func (p *PostgresStrategy) GetTrueValue() string { return "true" } + +// GetFalseValue returns the boolean representation. +func (p *PostgresStrategy) GetFalseValue() string { return "false" } + +func (p *PostgresStrategy) InitSchemaSQL() string { + return ` + CREATE TABLE IF NOT EXISTS clusters ( + id SERIAL PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + data JSONB NOT NULL, + enabled BOOLEAN DEFAULT true, + updated_at TIMESTAMP DEFAULT now() + ); + CREATE TABLE IF NOT EXISTS listeners ( + id SERIAL PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + data JSONB NOT NULL, + enabled BOOLEAN DEFAULT true, + updated_at TIMESTAMP DEFAULT now() + ); + CREATE TABLE IF NOT EXISTS secrets ( + id SERIAL PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + data JSONB NOT NULL, + enabled BOOLEAN DEFAULT true, + updated_at TIMESTAMP DEFAULT now(), + domain TEXT NULL + ); + CREATE TABLE IF NOT EXISTS certificates ( + domain TEXT PRIMARY KEY, + email TEXT NOT NULL, + cert_pem BYTEA NOT NULL, + key_pem BYTEA NOT NULL, + account_key BYTEA NOT NULL, + account_url TEXT NOT NULL, + issuer_type TEXT DEFAULT '', + secret_name TEXT DEFAULT '', + updated_at TIMESTAMP DEFAULT now(), + enable_rotation BOOLEAN DEFAULT false, + renew_before BIGINT DEFAULT 0 + );` +} + +func (p *PostgresStrategy) SaveCertificateSQL(ph []string) string { + // ph[0]...ph[9] correspond to the 10 values + return fmt.Sprintf(` + INSERT INTO certificates (domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, updated_at, enable_rotation, renew_before) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, now(), %s, %s) + ON CONFLICT (domain) DO UPDATE SET + email = EXCLUDED.email, + cert_pem = EXCLUDED.cert_pem, + key_pem = EXCLUDED.key_pem, + account_key = EXCLUDED.account_key, + account_url = EXCLUDED.account_url, + issuer_type = EXCLUDED.issuer_type, + secret_name = EXCLUDED.secret_name, + updated_at = now(), + enable_rotation = EXCLUDED.enable_rotation, + renew_before = EXCLUDED.renew_before`, + ph[0], ph[1], ph[2], ph[3], ph[4], ph[5], ph[6], ph[7], ph[8], ph[9]) +} + +func (p *PostgresStrategy) SaveSecretSQL(ph []string) string { + // ph[0] = name, ph[1] = data, ph[2] = domain + return fmt.Sprintf(` + INSERT INTO secrets (name, data, enabled, updated_at, domain) + VALUES (%s, %s, true, now(), %s) + ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now(), domain = %s`, + ph[0], ph[1], ph[2], ph[1], ph[2]) // Note: $2, $3 are repeated for the update clause +} + +func (p *PostgresStrategy) SaveClusterSQL(ph []string) string { + // ph[0] = name, ph[1] = data + return fmt.Sprintf(` + INSERT INTO clusters (name, data, enabled, updated_at) + VALUES (%s, %s, true, now()) + ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`, + ph[0], ph[1], ph[1]) +} + +func (p *PostgresStrategy) SaveListenerSQL(ph []string) string { + // ph[0] = name, ph[1] = data + return fmt.Sprintf(` + INSERT INTO listeners (name, data, enabled, updated_at) + VALUES (%s, %s, true, now()) + ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`, + ph[0], ph[1], ph[1]) +} + +func (p *PostgresStrategy) DumpSelectFields(table string) string { + if table == "secrets" { + return "name, data, domain" + } + return "name, data" +} + +func (p *PostgresStrategy) ScanRawRow(rows *sql.Rows, row *RawRow, table string) error { + if table == "secrets" { + // Postgres: 3 fields (name, JSONB data, domain) + return rows.Scan(&row.Name, &row.Data, &row.Domain) + } + // Postgres: 2 fields (name, JSONB data) + return rows.Scan(&row.Name, &row.Data) +} + +func (p *PostgresStrategy) RestoreRawRowSQL(table string) string { + // Note: We use the strategy's placeholder method if needed, but here we hardcode $1, $2, $3 for clarity. + if table == "secrets" { + return ` + INSERT INTO secrets (name, data, enabled, updated_at, domain) + VALUES ($1, $2, true, now(), $3) + ON CONFLICT (name) + DO UPDATE SET data = EXCLUDED.data, enabled = true, updated_at = now(), domain = EXCLUDED.domain` + } + // clusters or listeners (only needs $1 and $2 for name and data) + return fmt.Sprintf(` + INSERT INTO %s (name, data, enabled, updated_at) + VALUES ($1, $2, true, now()) + ON CONFLICT (name) + DO UPDATE SET data = EXCLUDED.data, enabled = true, updated_at = now()`, table) +} + +func (p *PostgresStrategy) ClearTableSQL(table string) string { + // Postgres preferred way to clear and reset IDs + return fmt.Sprintf("TRUNCATE TABLE %s RESTART IDENTITY", table) +} diff --git a/internal/pkg/storage/sqlite.go b/internal/pkg/storage/sqlite.go new file mode 100644 index 0000000..fb77474 --- /dev/null +++ b/internal/pkg/storage/sqlite.go @@ -0,0 +1,148 @@ +package storage + +import ( + "database/sql" + "fmt" +) + +// --------------------------- +// SQLite Strategy +// --------------------------- + +type SQLiteStrategy struct{} + +func (s *SQLiteStrategy) DriverName() string { return "sqlite" } + +// Placeholder returns ?. +func (s *SQLiteStrategy) Placeholder(n int) string { return "?" } + +// GetTimeNow returns the SQLite function for current timestamp. +func (s *SQLiteStrategy) GetTimeNow() string { return "CURRENT_TIMESTAMP" } + +// GetTrueValue returns the integer representation. +func (s *SQLiteStrategy) GetTrueValue() string { return "1" } + +// GetFalseValue returns the integer representation. +func (s *SQLiteStrategy) GetFalseValue() string { return "0" } + +func (s *SQLiteStrategy) InitSchemaSQL() string { + return ` + CREATE TABLE IF NOT EXISTS clusters ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + data TEXT NOT NULL, + enabled BOOLEAN DEFAULT 1, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + CREATE TABLE IF NOT EXISTS listeners ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + data TEXT NOT NULL, + enabled BOOLEAN DEFAULT 1, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + CREATE TABLE IF NOT EXISTS secrets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + data TEXT NOT NULL, + enabled BOOLEAN DEFAULT 1, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + domain TEXT NULL + ); + CREATE TABLE IF NOT EXISTS certificates ( + domain TEXT PRIMARY KEY, + email TEXT NOT NULL, + cert_pem BLOB NOT NULL, + key_pem BLOB NOT NULL, + account_key BLOB NOT NULL, + account_url TEXT NOT NULL, + issuer_type TEXT DEFAULT '', + secret_name TEXT DEFAULT '', + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + enable_rotation BOOLEAN DEFAULT 0, + renew_before INTEGER DEFAULT 0 + );` +} + +func (s *SQLiteStrategy) SaveCertificateSQL(ph []string) string { + return ` + INSERT INTO certificates (domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, updated_at, enable_rotation, renew_before) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, ?, ?) + ON CONFLICT(domain) DO UPDATE SET + email = excluded.email, + cert_pem = excluded.cert_pem, + key_pem = excluded.key_pem, + account_key = excluded.account_key, + account_url = excluded.account_url, + issuer_type = excluded.issuer_type, + secret_name = excluded.secret_name, + updated_at = CURRENT_TIMESTAMP, + enable_rotation = excluded.enable_rotation, + renew_before = excluded.renew_before` +} + +func (s *SQLiteStrategy) SaveSecretSQL(ph []string) string { + return ` + INSERT INTO secrets (name, data, enabled, updated_at, domain) + VALUES (?, ?, 1, CURRENT_TIMESTAMP, ?) + ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP, domain=excluded.domain` +} + +func (s *SQLiteStrategy) SaveClusterSQL(ph []string) string { + return ` + INSERT INTO clusters (name, data, enabled, updated_at) + VALUES (?, ?, 1, CURRENT_TIMESTAMP) + ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP` +} + +func (s *SQLiteStrategy) SaveListenerSQL(ph []string) string { + return ` + INSERT INTO listeners (name, data, enabled, updated_at) + VALUES (?, ?, 1, CURRENT_TIMESTAMP) + ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP` +} + +func (s *SQLiteStrategy) DumpSelectFields(table string) string { + if table == "secrets" { + return "name, data, domain" + } + return "name, data" +} + +func (s *SQLiteStrategy) ScanRawRow(rows *sql.Rows, row *RawRow, table string) error { + var dataStr string + if table == "secrets" { + // SQLite: 3 fields (name, TEXT data, domain) + if err := rows.Scan(&row.Name, &dataStr, &row.Domain); err != nil { + return err + } + } else { + // SQLite: 2 fields (name, TEXT data) + if err := rows.Scan(&row.Name, &dataStr); err != nil { + return err + } + } + row.Data = []byte(dataStr) + return nil +} + +func (s *SQLiteStrategy) RestoreRawRowSQL(table string) string { + if table == "secrets" { + return ` + INSERT INTO secrets (name, data, enabled, updated_at, domain) + VALUES (?, ?, 1, CURRENT_TIMESTAMP, ?) + ON CONFLICT(name) + DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP, domain=excluded.domain` + } + // clusters or listeners + return fmt.Sprintf(` + INSERT INTO %s (name, data, enabled, updated_at) + VALUES (?, ?, 1, CURRENT_TIMESTAMP) + ON CONFLICT(name) + DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table) +} + +func (s *SQLiteStrategy) ClearTableSQL(table string) string { + // SQLite safe way to clear + return fmt.Sprintf("DELETE FROM %s", table) +} diff --git a/internal/pkg/storage/storage.go b/internal/pkg/storage/storage.go index 72cfc01..b40d419 100644 --- a/internal/pkg/storage/storage.go +++ b/internal/pkg/storage/storage.go @@ -7,176 +7,124 @@ "errors" "fmt" "strings" + "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" - secretv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" // SDS Import + secretv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" "google.golang.org/protobuf/encoding/protojson" ) +// ============================================================================= +// I. TYPES, CONSTANTS, & FACTORY +// ============================================================================= + // Storage abstracts database persistence type Storage struct { - db *sql.DB - driver string + db *sql.DB + strategy SQLStrategy // Use the strategy interface instead of the driver string } +// CertStorage represents the persistent data needed for certificate renewal. +type CertStorage struct { + Domain string + Email string + CertPEM []byte + KeyPEM []byte + AccountKey []byte + AccountURL string + IssuerType string + SecretName string + EnableRotation bool + RenewBefore time.Duration +} + +// CertSecret is a structure to hold the secret data plus its domain link +type CertSecret struct { + Secret *secretv3.Secret + Domain sql.NullString +} + +// SnapshotConfig aggregates xDS resources +type SnapshotConfig struct { + // Enabled resources (for xDS serving) + EnabledClusters []*clusterv3.Cluster + EnabledListeners []*listenerv3.Listener + EnabledSecrets []*secretv3.Secret + + // Disabled resources (for UI display) + DisabledClusters []*clusterv3.Cluster + DisabledListeners []*listenerv3.Listener + DisabledSecrets []*secretv3.Secret +} + +// RawRow is a temporary struct for DB Dump/Restore logic (not in original, but assumed) +type RawRow struct { + Name string + Data []byte + // Used only by the secrets table + Domain sql.NullString +} + +const ( + DeleteNone DeleteStrategy = iota + DeleteLogical + DeleteActual +) + // DeleteStrategy defines the action to take on missing resources type DeleteStrategy int -const ( - // DeleteNone performs only UPSERT for items in the list (default behavior) - DeleteNone DeleteStrategy = iota - // DeleteLogical marks missing resources as disabled (applicable to clusters, listeners, and secrets) - DeleteLogical - // DeleteActual removes missing resources physically from the database - DeleteActual -) - -// NewStorage initializes a Storage instance -func NewStorage(db *sql.DB, driver string) *Storage { - return &Storage{db: db, driver: driver} +// NewStorage initializes a Storage instance using the Factory to get the correct strategy. +func NewStorage(db *sql.DB, driver string) (*Storage, error) { + strategy, err := NewSQLStrategy(driver) + if err != nil { + return nil, err + } + return &Storage{db: db, strategy: strategy}, nil } -// placeholder returns correct SQL placeholder based on driver +// placeholder is now simplified to call the strategy func (s *Storage) placeholder(n int) string { - if s.driver == "postgres" { - return fmt.Sprintf("$%d", n) - } - return "?" + return s.strategy.Placeholder(n) } -// CertStorage represents the persistent data needed for certificate renewal. -// This mirrors the data that was previously stored in the internalcertapi.Certificate. -type CertStorage struct { - Domain string // The certificate domain (used as the primary key) - Email string // The ACME account email - CertPEM []byte // The current certificate (public part + chain) - KeyPEM []byte // The domain's private key - AccountKey []byte // The ACME Account private key D-value (for signing) - AccountURL string // The ACME Account URI (KID) - IssuerType string // The type of issuer (e.g., "LetsEncrypt"). Default to "" - SecretName string // The name of the SDS Secret this certificate is associated with. Empty if unlinked/manual. -} +// ============================================================================= +// II. CORE METHODS +// ============================================================================= -// InitSchema ensures required tables exist +// InitSchema is now simplified, calling the strategy's InitSchemaSQL. func (s *Storage) InitSchema(ctx context.Context) error { - var schema string - switch s.driver { - case "postgres": - schema = ` - CREATE TABLE IF NOT EXISTS clusters ( - id SERIAL PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - data JSONB NOT NULL, - enabled BOOLEAN DEFAULT true, - updated_at TIMESTAMP DEFAULT now() - ); - CREATE TABLE IF NOT EXISTS listeners ( - id SERIAL PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - data JSONB NOT NULL, - enabled BOOLEAN DEFAULT true, - updated_at TIMESTAMP DEFAULT now() - ); - -- SDS secrets table for Postgres - CREATE TABLE IF NOT EXISTS secrets ( - id SERIAL PRIMARY KEY, - name TEXT UNIQUE NOT NULL, - data JSONB NOT NULL, - enabled BOOLEAN DEFAULT true, - updated_at TIMESTAMP DEFAULT now() - ); - -- 👇 UPDATED CERTIFICATE TABLE FOR ACME RENEWAL - CREATE TABLE IF NOT EXISTS certificates ( - domain TEXT PRIMARY KEY, - email TEXT NOT NULL, - cert_pem BYTEA NOT NULL, - key_pem BYTEA NOT NULL, - account_key BYTEA NOT NULL, - account_url TEXT NOT NULL, - issuer_type TEXT DEFAULT '', - secret_name TEXT DEFAULT '', -- New field - updated_at TIMESTAMP DEFAULT now() - );` - default: // SQLite - schema = ` - CREATE TABLE IF NOT EXISTS clusters ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE NOT NULL, - data TEXT NOT NULL, - enabled BOOLEAN DEFAULT 1, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - CREATE TABLE IF NOT EXISTS listeners ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE NOT NULL, - data TEXT NOT NULL, - enabled BOOLEAN DEFAULT 1, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - -- SDS secrets table for SQLite - CREATE TABLE IF NOT EXISTS secrets ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE NOT NULL, - data TEXT NOT NULL, - enabled BOOLEAN DEFAULT 1, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - -- 👇 UPDATED CERTIFICATE TABLE FOR ACME RENEWAL - CREATE TABLE IF NOT EXISTS certificates ( - domain TEXT PRIMARY KEY, - email TEXT NOT NULL, - cert_pem BLOB NOT NULL, - key_pem BLOB NOT NULL, - account_key BLOB NOT NULL, - account_url TEXT NOT NULL, - issuer_type TEXT DEFAULT '', - secret_name TEXT DEFAULT '', -- New field - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - );` - } + schema := s.strategy.InitSchemaSQL() + + // EXEC SCHEMA _, err := s.db.ExecContext(ctx, schema) + if err != nil { + if strings.Contains(err.Error(), "already exists") { + return nil + } + } return err } // ----------------------------------------------------------------------------- -// NEW CERTIFICATE METHODS (UPSERT & LOAD) +// CERTIFICATE METHODS // ----------------------------------------------------------------------------- -// SaveCertificate inserts or updates a certificate resource +// SaveCertificate uses the strategy's SQL generation. func (s *Storage) SaveCertificate(ctx context.Context, cert *CertStorage) error { - var query string - switch s.driver { - case "postgres": - query = fmt.Sprintf(` - INSERT INTO certificates (domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, updated_at) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, now()) - ON CONFLICT (domain) DO UPDATE SET - email = EXCLUDED.email, - cert_pem = EXCLUDED.cert_pem, - key_pem = EXCLUDED.key_pem, - account_key = EXCLUDED.account_key, - account_url = EXCLUDED.account_url, - issuer_type = EXCLUDED.issuer_type, - secret_name = EXCLUDED.secret_name, -- Updated field - updated_at = now()`, - s.placeholder(1), s.placeholder(2), s.placeholder(3), s.placeholder(4), s.placeholder(5), s.placeholder(6), s.placeholder(7), s.placeholder(8)) - default: // SQLite - query = ` - INSERT INTO certificates (domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) - ON CONFLICT(domain) DO UPDATE SET - email = excluded.email, - cert_pem = excluded.cert_pem, - key_pem = excluded.key_pem, - account_key = excluded.account_key, - account_url = excluded.account_url, - issuer_type = excluded.issuer_type, - secret_name = excluded.secret_name, -- Updated field - updated_at = CURRENT_TIMESTAMP` + renewBeforeNanos := cert.RenewBefore.Nanoseconds() + + // 1. Generate placeholders based on strategy (e.g., $1...$10 or ?...?) + ph := make([]string, 10) + for i := 0; i < 10; i++ { + ph[i] = s.placeholder(i + 1) } + // 2. Get the full query from the strategy + query := s.strategy.SaveCertificateSQL(ph) + _, err := s.db.ExecContext(ctx, query, cert.Domain, cert.Email, @@ -185,24 +133,36 @@ cert.AccountKey, cert.AccountURL, cert.IssuerType, - cert.SecretName, // New field + cert.SecretName, + cert.EnableRotation, + renewBeforeNanos, ) return err } -// LoadCertificate retrieves a certificate resource by domain +// LoadCertificate is largely simplified as only the placeholder needed change. func (s *Storage) LoadCertificate(ctx context.Context, domain string) (*CertStorage, error) { - // Updated SELECT statement to include secret_name - query := `SELECT email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name FROM certificates WHERE domain = $1` - if s.driver != "postgres" { - query = `SELECT email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name FROM certificates WHERE domain = ?` - } + // Use placeholder(1) and let the strategy handle the SQL dialect + query := fmt.Sprintf(`SELECT email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates WHERE domain = %s`, s.placeholder(1)) row := s.db.QueryRowContext(ctx, query, domain) cert := &CertStorage{Domain: domain} - // Updated Scan call to include &cert.SecretName - err := row.Scan(&cert.Email, &cert.CertPEM, &cert.KeyPEM, &cert.AccountKey, &cert.AccountURL, &cert.IssuerType, &cert.SecretName) + var renewBeforeNanos int64 + + err := row.Scan( + &cert.Email, + &cert.CertPEM, + &cert.KeyPEM, + &cert.AccountKey, + &cert.AccountURL, + &cert.IssuerType, + &cert.SecretName, + &cert.EnableRotation, + &renewBeforeNanos, + ) + + cert.RenewBefore = time.Duration(renewBeforeNanos) if errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("certificate for domain %s not found", domain) @@ -214,10 +174,9 @@ return cert, nil } -// LoadAllCertificates retrieves all stored certificate resources +// LoadAllCertificates is unchanged from the original, as it didn't have driver logic. func (s *Storage) LoadAllCertificates(ctx context.Context) ([]*CertStorage, error) { - // Updated SELECT statement to include secret_name - query := `SELECT domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name FROM certificates` + query := `SELECT domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates` rows, err := s.db.QueryContext(ctx, query) if err != nil { @@ -228,10 +187,24 @@ var certs []*CertStorage for rows.Next() { cert := &CertStorage{} - // Updated Scan call to include &cert.SecretName - if err := rows.Scan(&cert.Domain, &cert.Email, &cert.CertPEM, &cert.KeyPEM, &cert.AccountKey, &cert.AccountURL, &cert.IssuerType, &cert.SecretName); err != nil { + var renewBeforeNanos int64 + + if err := rows.Scan( + &cert.Domain, + &cert.Email, + &cert.CertPEM, + &cert.KeyPEM, + &cert.AccountKey, + &cert.AccountURL, + &cert.IssuerType, + &cert.SecretName, + &cert.EnableRotation, + &renewBeforeNanos, + ); err != nil { return nil, fmt.Errorf("failed to scan all certificate data: %w", err) } + + cert.RenewBefore = time.Duration(renewBeforeNanos) certs = append(certs, cert) } if err := rows.Err(); err != nil { @@ -242,397 +215,290 @@ } // ----------------------------------------------------------------------------- -// REST OF THE ORIGINAL CODE FOLLOWS... +// SECRET METHODS // ----------------------------------------------------------------------------- -// SaveCluster inserts or updates a cluster +// SaveSecret uses the strategy's SQL generation. +func (s *Storage) SaveSecret(ctx context.Context, certSecret *CertSecret) error { + secret := certSecret.Secret + data, err := protojson.Marshal(secret) + if err != nil { + return err + } + + // 1. Generate placeholders + ph := make([]string, 3) + for i := 0; i < 3; i++ { + ph[i] = s.placeholder(i + 1) + } + + // 2. Get the full query from the strategy + query := s.strategy.SaveSecretSQL(ph) + + // Prepare arguments for ExecContext + args := []interface{}{ + secret.GetName(), + string(data), + certSecret.Domain, // sql.NullString handles NULL appropriately + } + + // For Postgres, need to repeat data and domain for the ON CONFLICT clause + if s.strategy.DriverName() == "postgres" { + args = append(args, string(data), certSecret.Domain) + } + + _, err = s.db.ExecContext(ctx, query, args...) + return err +} + +// UpdateSecretDomain is now simplified. +func (s *Storage) UpdateSecretDomain(ctx context.Context, secretName string, domainName string) error { + var domainValue interface{} = domainName + if domainName == "" { + domainValue = nil + } + + query := fmt.Sprintf(` + UPDATE secrets + SET domain = %s, updated_at = %s + WHERE name = %s`, + s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2)) + + res, err := s.db.ExecContext(ctx, query, domainValue, secretName) + if err != nil { + return fmt.Errorf("failed to update secret domain for %s: %w", secretName, err) + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("failed to check rows affected after updating secret domain for %s: %w", secretName, err) + } + + if rowsAffected == 0 { + return fmt.Errorf("secret with name %s not found", secretName) + } + + return nil +} + +// ----------------------------------------------------------------------------- +// CLUSTER & LISTENER METHODS +// ----------------------------------------------------------------------------- + +// SaveCluster uses the strategy's SQL generation. func (s *Storage) SaveCluster(ctx context.Context, cluster *clusterv3.Cluster) error { data, err := protojson.Marshal(cluster) if err != nil { return err } - var query string - switch s.driver { - case "postgres": - // Explicitly set enabled=true on update to re-enable a logically deleted cluster - query = fmt.Sprintf(` - INSERT INTO clusters (name, data, enabled, updated_at) - VALUES (%s, %s, true, now()) - ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`, - s.placeholder(1), s.placeholder(2), s.placeholder(2)) - default: // SQLite - // Explicitly set enabled=1 on update to re-enable a logically deleted cluster - query = ` - INSERT INTO clusters (name, data, enabled, updated_at) - VALUES (?, ?, 1, CURRENT_TIMESTAMP) - ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP` + // 1. Generate placeholders + ph := make([]string, 2) + for i := 0; i < 2; i++ { + ph[i] = s.placeholder(i + 1) } - _, err = s.db.ExecContext(ctx, query, cluster.GetName(), string(data)) + // 2. Get the full query from the strategy + query := s.strategy.SaveClusterSQL(ph) + + // Arguments are the same for all drivers + args := []interface{}{cluster.GetName(), string(data)} + + // For Postgres, the data value is repeated in the ON CONFLICT clause + if s.strategy.DriverName() == "postgres" { + args = append(args, string(data)) + } + + _, err = s.db.ExecContext(ctx, query, args...) return err } -// SaveListener inserts or updates a listener +// SaveListener uses the strategy's SQL generation. func (s *Storage) SaveListener(ctx context.Context, listener *listenerv3.Listener) error { data, err := protojson.Marshal(listener) if err != nil { return err } - var query string - switch s.driver { - case "postgres": - // Explicitly set enabled=true on update to re-enable a logically deleted listener - query = fmt.Sprintf(` - INSERT INTO listeners (name, data, enabled, updated_at) - VALUES (%s, %s, true, now()) - ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`, - s.placeholder(1), s.placeholder(2), s.placeholder(2)) - default: // SQLite - // Explicitly set enabled=1 on update to re-enable a logically deleted listener - query = ` - INSERT INTO listeners (name, data, enabled, updated_at) - VALUES (?, ?, 1, CURRENT_TIMESTAMP) - ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP` + // 1. Generate placeholders + ph := make([]string, 2) + for i := 0; i < 2; i++ { + ph[i] = s.placeholder(i + 1) } - _, err = s.db.ExecContext(ctx, query, listener.GetName(), string(data)) - return err -} + // 2. Get the full query from the strategy + query := s.strategy.SaveListenerSQL(ph) -// SaveSecret inserts or updates a Secret -func (s *Storage) SaveSecret(ctx context.Context, secret *secretv3.Secret) error { - data, err := protojson.Marshal(secret) - if err != nil { - return err + // Arguments are the same for all drivers + args := []interface{}{listener.GetName(), string(data)} + + // For Postgres, the data value is repeated in the ON CONFLICT clause + if s.strategy.DriverName() == "postgres" { + args = append(args, string(data)) } - var query string - switch s.driver { - case "postgres": - // Explicitly set enabled=true on update to re-enable a logically deleted secret - query = fmt.Sprintf(` - INSERT INTO secrets (name, data, enabled, updated_at) - VALUES (%s, %s, true, now()) - ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`, - s.placeholder(1), s.placeholder(2), s.placeholder(2)) - default: // SQLite - // Explicitly set enabled=1 on update to re-enable a logically deleted secret - query = ` - INSERT INTO secrets (name, data, enabled, updated_at) - VALUES (?, ?, 1, CURRENT_TIMESTAMP) - ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP` - } - - _, err = s.db.ExecContext(ctx, query, secret.GetName(), string(data)) + _, err = s.db.ExecContext(ctx, query, args...) return err } // ----------------------------------------------------------------------------- -// LOAD ENABLED METHODS (UNCHANGED) +// LOAD ALL METHODS (SIMPLIFIED driver-specific logic into a helper) // ----------------------------------------------------------------------------- -// LoadEnabledClusters retrieves all enabled clusters -func (s *Storage) LoadEnabledClusters(ctx context.Context) ([]*clusterv3.Cluster, error) { - query := `SELECT data FROM clusters` - if s.driver == "postgres" { - query += ` WHERE enabled = true` - } else { - query += ` WHERE enabled = 1` +// getEnabledStatus extracts the boolean status from the database value, +// handling driver-specific types (BOOLEAN for Postgres, INT/TEXT for SQLite). +func (s *Storage) getEnabledStatus(rows *sql.Rows, raw *json.RawMessage) (bool, error) { + if s.strategy.DriverName() == "postgres" { + var enabledBool sql.NullBool + if err := rows.Scan(raw, &enabledBool); err != nil { + return false, err + } + return enabledBool.Bool, nil } + // SQLite/Generic handling: Read TEXT data and dynamic-type enabled field + var dataStr string + var enabledAny interface{} + if err := rows.Scan(&dataStr, &enabledAny); err != nil { + return false, err + } + *raw = json.RawMessage(dataStr) + + switch v := enabledAny.(type) { + case int64: + return v == 1, nil + case int: + return v == 1, nil + case bool: + return v, nil + case []byte: + return string(v) == "1" || strings.ToLower(string(v)) == "true", nil + case string: + return v == "1" || strings.ToLower(v) == "true", nil + default: + return false, fmt.Errorf("unsupported enabled column type for driver %s: %T", s.strategy.DriverName(), v) + } +} + +// LoadAllClusters uses the helper function. +func (s *Storage) LoadAllClusters(ctx context.Context) (enabled []*clusterv3.Cluster, disabled []*clusterv3.Cluster, err error) { + query := `SELECT data, enabled FROM clusters` + rows, err := s.db.QueryContext(ctx, query) if err != nil { - return nil, err + return nil, nil, err } defer rows.Close() - var clusters []*clusterv3.Cluster for rows.Next() { var raw json.RawMessage - if s.driver != "postgres" { - var dataStr string - if err := rows.Scan(&dataStr); err != nil { - return nil, err - } - raw = json.RawMessage(dataStr) - } else { - if err := rows.Scan(&raw); err != nil { - return nil, err - } + enabledStatus, err := s.getEnabledStatus(rows, &raw) + if err != nil { + return nil, nil, fmt.Errorf("failed to scan cluster data: %w", err) } var cluster clusterv3.Cluster if err := protojson.Unmarshal(raw, &cluster); err != nil { - return nil, err + return nil, nil, err } - clusters = append(clusters, &cluster) + + if enabledStatus { + enabled = append(enabled, &cluster) + } else { + disabled = append(disabled, &cluster) + } } - return clusters, nil + if err := rows.Err(); err != nil { + return nil, nil, err + } + return enabled, disabled, nil } -// LoadEnabledListeners retrieves all enabled listeners -func (s *Storage) LoadEnabledListeners(ctx context.Context) ([]*listenerv3.Listener, error) { - query := `SELECT data FROM listeners` - if s.driver == "postgres" { - query += ` WHERE enabled = true` - } else { - query += ` WHERE enabled = 1` - } +// LoadAllListeners uses the helper function. +func (s *Storage) LoadAllListeners(ctx context.Context) (enabled []*listenerv3.Listener, disabled []*listenerv3.Listener, err error) { + query := `SELECT data, enabled FROM listeners` rows, err := s.db.QueryContext(ctx, query) if err != nil { - return nil, err + return nil, nil, err } defer rows.Close() - var listeners []*listenerv3.Listener for rows.Next() { var raw json.RawMessage - if s.driver != "postgres" { - var dataStr string - if err := rows.Scan(&dataStr); err != nil { - return nil, err - } - raw = json.RawMessage(dataStr) - } else { - if err := rows.Scan(&raw); err != nil { - return nil, err - } + enabledStatus, err := s.getEnabledStatus(rows, &raw) + if err != nil { + return nil, nil, fmt.Errorf("failed to scan listener data: %w", err) } - var l listenerv3.Listener - if err := protojson.Unmarshal(raw, &l); err != nil { - return nil, err + var listener listenerv3.Listener + if err := protojson.Unmarshal(raw, &listener); err != nil { + return nil, nil, err } - listeners = append(listeners, &l) + + if enabledStatus { + enabled = append(enabled, &listener) + } else { + disabled = append(disabled, &listener) + } } - return listeners, nil + if err := rows.Err(); err != nil { + return nil, nil, err + } + return enabled, disabled, nil } -// LoadEnabledSecrets retrieves all enabled secrets -func (s *Storage) LoadEnabledSecrets(ctx context.Context) ([]*secretv3.Secret, error) { - query := `SELECT data FROM secrets` - if s.driver == "postgres" { - query += ` WHERE enabled = true` - } else { - query += ` WHERE enabled = 1` - } +// LoadAllSecrets uses the helper function. +func (s *Storage) LoadAllSecrets(ctx context.Context) (enabled []*secretv3.Secret, disabled []*secretv3.Secret, err error) { + query := `SELECT data, enabled FROM secrets` rows, err := s.db.QueryContext(ctx, query) if err != nil { - return nil, err + return nil, nil, err } defer rows.Close() - var secrets []*secretv3.Secret for rows.Next() { var raw json.RawMessage - if s.driver != "postgres" { - var dataStr string - if err := rows.Scan(&dataStr); err != nil { - return nil, err - } - raw = json.RawMessage(dataStr) - } else { - if err := rows.Scan(&raw); err != nil { - return nil, err - } + enabledStatus, err := s.getEnabledStatus(rows, &raw) + if err != nil { + return nil, nil, fmt.Errorf("failed to scan secret data: %w", err) } var secret secretv3.Secret if err := protojson.Unmarshal(raw, &secret); err != nil { - return nil, err + return nil, nil, err } - secrets = append(secrets, &secret) - } - return secrets, nil -} -// ----------------------------------------------------------------------------- -// LOAD ALL METHODS (UNCHANGED) -// ----------------------------------------------------------------------------- - -// LoadAllClusters retrieves all clusters, regardless of their enabled status -func (s *Storage) LoadAllClusters(ctx context.Context) ([]*clusterv3.Cluster, error) { - rows, err := s.db.QueryContext(ctx, `SELECT data FROM clusters`) - if err != nil { - return nil, err - } - defer rows.Close() - - var clusters []*clusterv3.Cluster - for rows.Next() { - var raw json.RawMessage - if s.driver != "postgres" { - var dataStr string - if err := rows.Scan(&dataStr); err != nil { - return nil, err - } - raw = json.RawMessage(dataStr) + if enabledStatus { + enabled = append(enabled, &secret) } else { - if err := rows.Scan(&raw); err != nil { - return nil, err - } + disabled = append(disabled, &secret) } - - var cluster clusterv3.Cluster - if err := protojson.Unmarshal(raw, &cluster); err != nil { - return nil, err - } - clusters = append(clusters, &cluster) } - return clusters, nil -} - -// LoadAllListeners retrieves all listeners, regardless of their enabled status -func (s *Storage) LoadAllListeners(ctx context.Context) ([]*listenerv3.Listener, error) { - rows, err := s.db.QueryContext(ctx, `SELECT data FROM listeners`) - if err != nil { - return nil, err + if err := rows.Err(); err != nil { + return nil, nil, err } - defer rows.Close() - - var listeners []*listenerv3.Listener - for rows.Next() { - var raw json.RawMessage - if s.driver != "postgres" { - var dataStr string - if err := rows.Scan(&dataStr); err != nil { - return nil, err - } - raw = json.RawMessage(dataStr) - } else { - if err := rows.Scan(&raw); err != nil { - return nil, err - } - } - - var l listenerv3.Listener - if err := protojson.Unmarshal(raw, &l); err != nil { - return nil, err - } - listeners = append(listeners, &l) - } - return listeners, nil -} - -// LoadAllSecrets retrieves all secrets, regardless of their enabled status -func (s *Storage) LoadAllSecrets(ctx context.Context) ([]*secretv3.Secret, error) { - rows, err := s.db.QueryContext(ctx, `SELECT data FROM secrets`) - if err != nil { - return nil, err - } - defer rows.Close() - - var secrets []*secretv3.Secret - for rows.Next() { - var raw json.RawMessage - if s.driver != "postgres" { - var dataStr string - if err := rows.Scan(&dataStr); err != nil { - return nil, err - } - raw = json.RawMessage(dataStr) - } else { - if err := rows.Scan(&raw); err != nil { - return nil, err - } - } - - var secret secretv3.Secret - if err := protojson.Unmarshal(raw, &secret); err != nil { - return nil, err - } - secrets = append(secrets, &secret) - } - return secrets, nil + return enabled, disabled, nil } // ----------------------------------------------------------------------------- -// SNAPSHOT MANAGEMENT (UNCHANGED) +// SNAPSHOT MANAGEMENT // ----------------------------------------------------------------------------- -// SnapshotConfig aggregates xDS resources -type SnapshotConfig struct { - // Enabled resources (for xDS serving) - EnabledClusters []*clusterv3.Cluster - EnabledListeners []*listenerv3.Listener - EnabledSecrets []*secretv3.Secret // New SDS resource - - // Disabled resources (for UI display) - DisabledClusters []*clusterv3.Cluster - DisabledListeners []*listenerv3.Listener - DisabledSecrets []*secretv3.Secret // New SDS resource -} - -// RebuildSnapshot rebuilds full snapshot from DB +// RebuildSnapshot (unchanged logic) func (s *Storage) RebuildSnapshot(ctx context.Context) (*SnapshotConfig, error) { - // 1. Load Enabled Resources (for xDS serving) - enabledClusters, err := s.LoadEnabledClusters(ctx) + enabledClusters, disabledClusters, err := s.LoadAllClusters(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load all clusters: %w", err) } - enabledListeners, err := s.LoadEnabledListeners(ctx) + enabledListeners, disabledListeners, err := s.LoadAllListeners(ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load all listeners: %w", err) } - enabledSecrets, err := s.LoadEnabledSecrets(ctx) + enabledSecrets, disabledSecrets, err := s.LoadAllSecrets(ctx) if err != nil { - return nil, err - } - - // 2. Load ALL Resources (for comparison and disabled set) - allClusters, err := s.LoadAllClusters(ctx) - if err != nil { - return nil, err - } - allListeners, err := s.LoadAllListeners(ctx) - if err != nil { - return nil, err - } - allSecrets, err := s.LoadAllSecrets(ctx) - if err != nil { - return nil, err - } - - // 3. Separate Disabled Resources - - // Clusters - enabledClusterNames := make(map[string]struct{}, len(enabledClusters)) - for _, c := range enabledClusters { - enabledClusterNames[c.GetName()] = struct{}{} - } - var disabledClusters []*clusterv3.Cluster - for _, c := range allClusters { - if _, found := enabledClusterNames[c.GetName()]; !found { - disabledClusters = append(disabledClusters, c) - } - } - - // Listeners - enabledListenerNames := make(map[string]struct{}, len(enabledListeners)) - for _, l := range enabledListeners { - enabledListenerNames[l.GetName()] = struct{}{} - } - var disabledListeners []*listenerv3.Listener - for _, l := range allListeners { - if _, found := enabledListenerNames[l.GetName()]; !found { - disabledListeners = append(disabledListeners, l) - } - } - - // Secrets - enabledSecretNames := make(map[string]struct{}, len(enabledSecrets)) - for _, sec := range enabledSecrets { - enabledSecretNames[sec.GetName()] = struct{}{} - } - var disabledSecrets []*secretv3.Secret - for _, sec := range allSecrets { - if _, found := enabledSecretNames[sec.GetName()]; !found { - disabledSecrets = append(disabledSecrets, sec) - } + return nil, fmt.Errorf("failed to load all secrets: %w", err) } return &SnapshotConfig{ @@ -645,7 +511,7 @@ }, nil } -// SaveSnapshot saves the entire snapshot to the DB +// SaveSnapshot (unchanged logic, relying on refactored Save methods) func (s *Storage) SaveSnapshot(ctx context.Context, cfg *SnapshotConfig, strategy DeleteStrategy) error { if cfg == nil { return fmt.Errorf("SnapshotConfig is nil") @@ -667,7 +533,7 @@ // --- 1. Save/Upsert Clusters and Collect Names --- clusterNames := make([]string, 0, len(cfg.EnabledClusters)) for _, c := range cfg.EnabledClusters { - // NOTE: This uses the existing SaveCluster which doesn't use the transaction 'tx' + // NOTE: These SaveXXX methods now use the refactored logic via s.strategy if err = s.SaveCluster(ctx, c); err != nil { return fmt.Errorf("failed to save cluster %s: %w", c.GetName(), err) } @@ -677,7 +543,6 @@ // --- 2. Save/Upsert Listeners and Collect Names --- listenerNames := make([]string, 0, len(cfg.EnabledListeners)) for _, l := range cfg.EnabledListeners { - // NOTE: This uses the existing SaveListener which doesn't use the transaction 'tx' if err = s.SaveListener(ctx, l); err != nil { return fmt.Errorf("failed to save listener %s: %w", l.GetName(), err) } @@ -687,8 +552,11 @@ // --- 3. Save/Upsert Secrets and Collect Names --- secretNames := make([]string, 0, len(cfg.EnabledSecrets)) for _, sec := range cfg.EnabledSecrets { - // NOTE: This uses the existing SaveSecret which doesn't use the transaction 'tx' - if err = s.SaveSecret(ctx, sec); err != nil { + certSecret := &CertSecret{ + Secret: sec, + Domain: sql.NullString{Valid: false}, + } + if err = s.SaveSecret(ctx, certSecret); err != nil { return fmt.Errorf("failed to save secret %s: %w", sec.GetName(), err) } secretNames = append(secretNames, sec.GetName()) @@ -697,7 +565,6 @@ // --- 4. Apply Deletion Strategy --- switch strategy { case DeleteLogical: - // Logical Delete (Disable) for all resource types if err = s.disableMissingResources(ctx, "clusters", clusterNames); err != nil { return fmt.Errorf("failed to logically delete missing clusters: %w", err) } @@ -709,7 +576,6 @@ } case DeleteActual: - // Actual Delete (Physical Removal) for all resources if err = s.deleteMissingResources(ctx, "clusters", clusterNames); err != nil { return fmt.Errorf("failed to physically delete missing clusters: %w", err) } @@ -721,7 +587,6 @@ } case DeleteNone: - // Do nothing for missing resources return nil } @@ -729,71 +594,55 @@ } // ----------------------------------------------------------------------------- -// ENABLE/DISABLE & DELETE METHODS (UNCHANGED) +// ENABLE/DISABLE & DELETE METHODS // ----------------------------------------------------------------------------- -// EnableCluster toggles a cluster +// EnableCluster is now simplified. func (s *Storage) EnableCluster(ctx context.Context, name string, enabled bool) error { - query := `UPDATE clusters SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?` - if s.driver == "postgres" { - query = `UPDATE clusters SET enabled = $1, updated_at = now() WHERE name = $2` - } + query := fmt.Sprintf(`UPDATE clusters SET enabled = %s, updated_at = %s WHERE name = %s`, + s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2)) _, err := s.db.ExecContext(ctx, query, enabled, name) return err } -// EnableListener toggles a listener +// EnableListener is now simplified. func (s *Storage) EnableListener(ctx context.Context, name string, enabled bool) error { - query := `UPDATE listeners SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?` - if s.driver == "postgres" { - query = `UPDATE listeners SET enabled = $1, updated_at = now() WHERE name = $2` - } + query := fmt.Sprintf(`UPDATE listeners SET enabled = %s, updated_at = %s WHERE name = %s`, + s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2)) _, err := s.db.ExecContext(ctx, query, enabled, name) return err } -// EnableSecret toggles a secret +// EnableSecret is now simplified. func (s *Storage) EnableSecret(ctx context.Context, name string, enabled bool) error { - query := `UPDATE secrets SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?` - if s.driver == "postgres" { - query = `UPDATE secrets SET enabled = $1, updated_at = now() WHERE name = $2` - } + query := fmt.Sprintf(`UPDATE secrets SET enabled = %s, updated_at = %s WHERE name = %s`, + s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2)) _, err := s.db.ExecContext(ctx, query, enabled, name) return err } -// RemoveListener deletes a listener by name +// RemoveListener is now simplified. func (s *Storage) RemoveListener(ctx context.Context, name string) error { - query := `DELETE FROM listeners WHERE name = ?` - if s.driver == "postgres" { - query = `DELETE FROM listeners WHERE name = $1` - } + query := fmt.Sprintf(`DELETE FROM listeners WHERE name = %s`, s.placeholder(1)) _, err := s.db.ExecContext(ctx, query, name) return err } -// RemoveCluster deletes a cluster by name +// RemoveCluster is now simplified. func (s *Storage) RemoveCluster(ctx context.Context, name string) error { - query := `DELETE FROM clusters WHERE name = ?` - if s.driver == "postgres" { - query = `DELETE FROM clusters WHERE name = $1` - } + query := fmt.Sprintf(`DELETE FROM clusters WHERE name = %s`, s.placeholder(1)) _, err := s.db.ExecContext(ctx, query, name) return err } -// RemoveSecret deletes a secret by name +// RemoveSecret is now simplified. func (s *Storage) RemoveSecret(ctx context.Context, name string) error { - query := `DELETE FROM secrets WHERE name = ?` - if s.driver == "postgres" { - query = `DELETE FROM secrets WHERE name = $1` - } + query := fmt.Sprintf(`DELETE FROM secrets WHERE name = %s`, s.placeholder(1)) _, err := s.db.ExecContext(ctx, query, name) return err } -// disableMissingResources updates the 'enabled' status for resources in 'table' -// whose 'name' is NOT in 'inputNames'. +// disableMissingResources uses the strategy for dialect-specific values. func (s *Storage) disableMissingResources(ctx context.Context, table string, inputNames []string) error { if table != "clusters" && table != "listeners" && table != "secrets" { return fmt.Errorf("logical delete (disable) is only supported for tables with an 'enabled' column (clusters, listeners, secrets)") @@ -803,25 +652,12 @@ placeholders := make([]string, len(inputNames)) args := make([]interface{}, len(inputNames)) for i, name := range inputNames { - if s.driver == "postgres" { - placeholders[i] = fmt.Sprintf("$%d", i+1) - } else { - placeholders[i] = "?" - } + placeholders[i] = s.placeholder(i + 1) args[i] = name } - disabledValue := "false" - if s.driver != "postgres" { - disabledValue = "0" - } - - var updateTime string - if s.driver == "postgres" { - updateTime = "now()" - } else { - updateTime = "CURRENT_TIMESTAMP" - } + disabledValue := s.strategy.GetFalseValue() + updateTime := s.strategy.GetTimeNow() // If no names are provided, disable ALL currently enabled resources whereClause := "" @@ -831,16 +667,16 @@ // 2. Construct and execute the UPDATE query query := fmt.Sprintf(` - UPDATE %s - SET enabled = %s, updated_at = %s - %s`, + UPDATE %s + SET enabled = %s, updated_at = %s + %s`, table, disabledValue, updateTime, whereClause) _, err := s.db.ExecContext(ctx, query, args...) return err } -// deleteMissingResources physically deletes resources from 'table' whose 'name' is NOT in 'inputNames'. +// deleteMissingResources is unchanged except for using the generic s.placeholder(i+1) func (s *Storage) deleteMissingResources(ctx context.Context, table string, inputNames []string) error { if table != "clusters" && table != "listeners" && table != "secrets" { return fmt.Errorf("physical delete is only supported for tables: clusters, listeners, secrets") @@ -850,11 +686,7 @@ placeholders := make([]string, len(inputNames)) args := make([]interface{}, len(inputNames)) for i, name := range inputNames { - if s.driver == "postgres" { - placeholders[i] = fmt.Sprintf("$%d", i+1) - } else { - placeholders[i] = "?" - } + placeholders[i] = s.placeholder(i + 1) args[i] = name } @@ -866,8 +698,8 @@ // 2. Construct and execute the DELETE query query := fmt.Sprintf(` - DELETE FROM %s - %s`, + DELETE FROM %s + %s`, table, whereClause) _, err := s.db.ExecContext(ctx, query, args...) diff --git a/internal/pkg/storage/storage_dump.go b/internal/pkg/storage/storage_dump.go index 7d75294..9f1fca3 100644 --- a/internal/pkg/storage/storage_dump.go +++ b/internal/pkg/storage/storage_dump.go @@ -7,6 +7,9 @@ "fmt" ) +// NOTE: This file assumes the SQLStrategy interface and concrete implementations +// (PostgresStrategy, SQLiteStrategy) defined in the previous response are present. + // DBDump holds the complete state of the database for dumping/restoring. type DBDump struct { Clusters []*RawRow `json:"clusters,omitempty"` @@ -16,12 +19,6 @@ DBDriver string `json:"db_driver"` } -// RawRow represents a single row of xDS data with name + JSON body. -type RawRow struct { - Name string `json:"name"` - Data json.RawMessage `json:"data"` -} - // DBDumpRestoreMode defines the strategy for restoring the data. type DBDumpRestoreMode int @@ -32,14 +29,22 @@ RestoreOverride ) +// ----------------------------------------------------------------------------- +// DUMP METHODS (REFACTORED) +// ----------------------------------------------------------------------------- + // Dump exports all database content as JSON including certificates. func (s *Storage) Dump(ctx context.Context) ([]byte, error) { + // Refactored Load helper uses the strategy for SQL generation and scanning load := func(ctx context.Context, table string) ([]*RawRow, error) { if table != "clusters" && table != "listeners" && table != "secrets" { return nil, fmt.Errorf("invalid table for dump: %s", table) } - query := fmt.Sprintf(`SELECT name, data FROM %s`, table) + // Use the strategy to get the fields to select + fields := s.strategy.DumpSelectFields(table) + query := fmt.Sprintf(`SELECT %s FROM %s`, fields, table) + rows, err := s.db.QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("failed to query table %s: %w", table, err) @@ -49,16 +54,9 @@ var out []*RawRow for rows.Next() { row := &RawRow{} - if s.driver == "postgres" { - if err := rows.Scan(&row.Name, &row.Data); err != nil { - return nil, fmt.Errorf("scan error on %s: %w", table, err) - } - } else { - var dataStr string - if err := rows.Scan(&row.Name, &dataStr); err != nil { - return nil, fmt.Errorf("scan error on %s: %w", table, err) - } - row.Data = []byte(dataStr) + // Delegate all scanning logic to the strategy + if err := s.strategy.ScanRawRow(rows, row, table); err != nil { + return nil, fmt.Errorf("scan error on %s: %w", table, err) } out = append(out, row) } @@ -88,7 +86,7 @@ Listeners: listeners, Secrets: secrets, Certificates: certs, - DBDriver: s.driver, + DBDriver: s.strategy.DriverName(), // Use strategy for driver name } data, err := json.MarshalIndent(dump, "", " ") @@ -98,6 +96,10 @@ return data, nil } +// ----------------------------------------------------------------------------- +// RESTORE METHODS (REFACTORED) +// ----------------------------------------------------------------------------- + // Restore imports database content from a JSON dump (merge or override). func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error { var dump DBDump @@ -105,8 +107,8 @@ return fmt.Errorf("failed to parse dump JSON: %w", err) } - if dump.DBDriver != s.driver { - return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.driver) + if dump.DBDriver != s.strategy.DriverName() { + return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.strategy.DriverName()) } if mode == RestoreOverride { @@ -117,29 +119,35 @@ } } + // saveRaw now relies entirely on the strategy to provide the correct UPSERT query saveRaw := func(ctx context.Context, table string, rows []*RawRow) error { for _, r := range rows { if r == nil || r.Name == "" { continue } - var query string - switch s.driver { - case "postgres": - query = fmt.Sprintf(` - INSERT INTO %s (name, data, enabled, updated_at) - VALUES ($1, $2, true, now()) - ON CONFLICT (name) - DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table) - default: - query = fmt.Sprintf(` - INSERT INTO %s (name, data, enabled, updated_at) - VALUES (?, ?, 1, CURRENT_TIMESTAMP) - ON CONFLICT(name) - DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table) + // Get the correct UPSERT query from the strategy + query := s.strategy.RestoreRawRowSQL(table) + + var args []interface{} + + // Handle arguments based on table (secrets needs 3 args, others need 2) + if table == "secrets" { + args = []interface{}{r.Name, string(r.Data), r.Domain} + } else { + args = []interface{}{r.Name, string(r.Data)} } - if _, err := s.db.ExecContext(ctx, query, r.Name, string(r.Data)); err != nil { + // Postgres requires the data argument to be repeated for the ON CONFLICT clause + if s.strategy.DriverName() == "postgres" { + // The implementation of RestoreRawRowSQL for Postgres must ensure + // that the query only expects the standard set of positional arguments + // for the VALUES clause ($1, $2, [$3]), as the data is repeated + // using the EXCLUDED prefix in Postgres' ON CONFLICT syntax. + // No extra args needed here, the strategy handles the SQL structure. + } + + if _, err := s.db.ExecContext(ctx, query, args...); err != nil { return fmt.Errorf("failed to upsert %s '%s': %w", table, r.Name, err) } } @@ -156,6 +164,7 @@ return err } + // Certificates still use the core SaveCertificate method, which is already refactored. for _, cert := range dump.Certificates { if cert == nil || cert.Domain == "" { continue @@ -168,7 +177,7 @@ return nil } -// clearTable deletes all rows from the given table (safe against SQLi). +// clearTable is now simplified. func (s *Storage) clearTable(ctx context.Context, table string) error { valid := map[string]bool{ "clusters": true, @@ -180,10 +189,8 @@ return fmt.Errorf("invalid table name: %s", table) } - query := fmt.Sprintf("DELETE FROM %s", table) - if s.driver == "postgres" { - query = fmt.Sprintf("TRUNCATE TABLE %s RESTART IDENTITY", table) - } + // Delegate the specific clear/truncate query to the strategy + query := s.strategy.ClearTableSQL(table) _, err := s.db.ExecContext(ctx, query) if err != nil { diff --git a/internal/pkg/storage/strategy.go b/internal/pkg/storage/strategy.go new file mode 100644 index 0000000..ce5b31b --- /dev/null +++ b/internal/pkg/storage/strategy.go @@ -0,0 +1,53 @@ +package storage + +import ( + "database/sql" + "errors" + "fmt" + "strings" +) + +// ============================================================================= +// I. STRATEGY INTERFACE & FACTORY +// ============================================================================= + +// SQLStrategy abstracts all database-specific SQL generation and logic. +// This is the core of the Strategy Pattern implementation. +type SQLStrategy interface { + DriverName() string + // Placeholder returns the correct placeholder string for the nth argument (e.g., $1 or ?). + Placeholder(n int) string + InitSchemaSQL() string + + // Save methods: Generate the full INSERT/UPDATE query using the provided placeholders. + SaveCertificateSQL(placeholders []string) string + SaveSecretSQL(placeholders []string) string + SaveClusterSQL(placeholders []string) string + SaveListenerSQL(placeholders []string) string + + // Dialect-specific functions (e.g., now() vs CURRENT_TIMESTAMP) + GetTimeNow() string + GetTrueValue() string + GetFalseValue() string + + // Dump/Restore Logic + DumpSelectFields(table string) string + ScanRawRow(rows *sql.Rows, row *RawRow, table string) error + RestoreRawRowSQL(table string) string + ClearTableSQL(table string) string +} + +// NewSQLStrategy is the factory function to create the correct strategy implementation. +func NewSQLStrategy(driver string) (SQLStrategy, error) { + switch strings.ToLower(driver) { + case "postgres": + return &PostgresStrategy{}, nil + case "sqlite", "sqlite3": + return &SQLiteStrategy{}, nil + // Add support for new databases here! + case "mysql": + return nil, errors.New("MySQL strategy not yet implemented") + default: + return nil, fmt.Errorf("unsupported database driver: %s", driver) + } +} diff --git a/static/style.css b/static/style.css index 7cd0c6a..2b69652 100644 --- a/static/style.css +++ b/static/style.css @@ -207,7 +207,7 @@ text-decoration: none; } -.toolbar button:hover { +.toolbar button:hover, .toolbar a:hover { background-color: var(--primary-color); color: white; /* Text inverts to white on hover */ }