package storage
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
secretv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
"google.golang.org/protobuf/encoding/protojson"
)
// =============================================================================
// I. TYPES, CONSTANTS, & FACTORY
// =============================================================================
// Storage abstracts database persistence for xDS resources (clusters,
// listeners, secrets, extension configs) and certificate records.
//
// SQL dialect differences — bind placeholders, the current-time expression,
// the false literal, the Save*SQL statements, and the schema — are delegated
// to the SQLStrategy chosen by NewStorage. A few methods additionally branch on
// strategy.DriverName() == "postgres" to bind extra arguments.
type Storage struct {
db *sql.DB
strategy SQLStrategy // dialect-specific SQL generation, selected from the driver name in NewStorage
}
// CertStorage is one record of the certificates table: the persistent data
// needed for certificate renewal. Domain is the key used by LoadCertificate
// and UpdateCertRotationSettings to locate the row.
type CertStorage struct {
Domain string
Email string
CertPEM []byte
KeyPEM []byte
AccountKey []byte
AccountURL string
IssuerType string
SecretName string // links the record to a secret; LoadCertificateBySecretName queries by this column
EnableRotation bool
RenewBefore time.Duration // persisted in renew_before as integer nanoseconds
}
// CertSecret holds an Envoy TLS secret together with the certificate domain
// it is linked to. Domain is a sql.NullString so a secret can exist without a
// domain link: SaveSecret binds it directly, and Valid == false becomes SQL NULL.
type CertSecret struct {
Secret *secretv3.Secret
Domain sql.NullString
}
// SnapshotConfig aggregates the xDS resources held in storage, split by each
// row's enabled flag. RebuildSnapshot fills both halves; SaveSnapshot writes
// only the Enabled* slices.
type SnapshotConfig struct {
// Enabled resources (for xDS serving)
EnabledClusters []*clusterv3.Cluster
EnabledListeners []*listenerv3.Listener
EnabledSecrets []*secretv3.Secret
EnabledExtensionConfigs []*corev3.TypedExtensionConfig
// Disabled resources (for UI display; not written by SaveSnapshot)
DisabledClusters []*clusterv3.Cluster
DisabledListeners []*listenerv3.Listener
DisabledSecrets []*secretv3.Secret
DisabledExtensionConfigs []*corev3.TypedExtensionConfig
}
// RawRow is a temporary struct for DB dump/restore logic, pairing a resource
// name with its undecoded JSON data.
// NOTE(review): nothing in this section references RawRow. Name has no json
// tag, so encoding/json uses the key "Name" while Data uses "data" — confirm
// the mixed casing is intended before depending on this encoding.
type RawRow struct {
Name string
Data json.RawMessage `json:"data"`
}
// DeleteStrategy selects what SaveSnapshot does with stored resources whose
// names are absent from the snapshot being saved.
type DeleteStrategy int

const (
	// DeleteNone leaves missing resources untouched.
	DeleteNone DeleteStrategy = iota
	// DeleteLogical disables missing resources (sets their enabled column to false).
	DeleteLogical
	// DeleteActual removes missing resources from their tables.
	DeleteActual
)
// NewStorage wraps db in a Storage whose SQL dialect is chosen by
// NewSQLStrategy(driver). An error from NewSQLStrategy (for example an
// unrecognized driver) is returned unchanged.
func NewStorage(db *sql.DB, driver string) (*Storage, error) {
	sqlStrategy, err := NewSQLStrategy(driver)
	if err != nil {
		return nil, err
	}
	st := &Storage{
		db:       db,
		strategy: sqlStrategy,
	}
	return st, nil
}
// placeholder returns the dialect-specific bind marker for the n-th query
// argument (callers number arguments from 1), as produced by the strategy.
func (s *Storage) placeholder(n int) string {
	marker := s.strategy.Placeholder(n)
	return marker
}
// =============================================================================
// II. CORE METHODS
// =============================================================================
// InitSchema executes the strategy's schema script in a single ExecContext
// call. An error whose text contains "already exists" is treated as success;
// any other error is returned unchanged.
func (s *Storage) InitSchema(ctx context.Context) error {
	if _, err := s.db.ExecContext(ctx, s.strategy.InitSchemaSQL()); err != nil && !strings.Contains(err.Error(), "already exists") {
		return err
	}
	return nil
}
// -----------------------------------------------------------------------------
// CERTIFICATE METHODS
// -----------------------------------------------------------------------------
// SaveCertificate uses the strategy's SQL generation.
func (s *Storage) SaveCertificate(ctx context.Context, cert *CertStorage) error {
renewBeforeNanos := cert.RenewBefore.Nanoseconds()
// 1. Generate placeholders based on strategy (e.g., $1...$10 or ?...?)
ph := make([]string, 10)
for i := 0; i < 10; i++ {
ph[i] = s.placeholder(i + 1)
}
// 2. Get the full query from the strategy
query := s.strategy.SaveCertificateSQL(ph)
_, err := s.db.ExecContext(ctx, query,
cert.Domain,
cert.Email,
cert.CertPEM,
cert.KeyPEM,
cert.AccountKey,
cert.AccountURL,
cert.IssuerType,
cert.SecretName,
cert.EnableRotation,
renewBeforeNanos,
)
return err
}
// LoadCertificate fetches the certificates row whose domain equals domain.
//
// It returns an error "certificate for domain <d> not found" when no row
// matches, and a wrapped scan error for any other failure. renew_before is
// read as nanoseconds and converted to a time.Duration.
func (s *Storage) LoadCertificate(ctx context.Context, domain string) (*CertStorage, error) {
	query := fmt.Sprintf(`SELECT email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates WHERE domain = %s`, s.placeholder(1))

	result := &CertStorage{Domain: domain}
	var renewNanos int64
	scanErr := s.db.QueryRowContext(ctx, query, domain).Scan(
		&result.Email,
		&result.CertPEM,
		&result.KeyPEM,
		&result.AccountKey,
		&result.AccountURL,
		&result.IssuerType,
		&result.SecretName,
		&result.EnableRotation,
		&renewNanos,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, fmt.Errorf("certificate for domain %s not found", domain)
	case scanErr != nil:
		return nil, fmt.Errorf("failed to scan certificate data for %s: %w", domain, scanErr)
	}
	result.RenewBefore = time.Duration(renewNanos)
	return result, nil
}
// LoadCertificateBySecretName fetches the certificates row whose secret_name
// equals secretName (QueryRow, so only the first matching row is read).
//
// It returns an error "certificate with secret name <n> not found" when no
// row matches, and a wrapped scan error for any other failure.
func (s *Storage) LoadCertificateBySecretName(ctx context.Context, secretName string) (*CertStorage, error) {
	query := fmt.Sprintf(`SELECT domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates WHERE secret_name = %s`, s.placeholder(1))

	var (
		result     CertStorage
		renewNanos int64
	)
	scanErr := s.db.QueryRowContext(ctx, query, secretName).Scan(
		&result.Domain,
		&result.Email,
		&result.CertPEM,
		&result.KeyPEM,
		&result.AccountKey,
		&result.AccountURL,
		&result.IssuerType,
		&result.SecretName,
		&result.EnableRotation,
		&renewNanos,
	)
	if errors.Is(scanErr, sql.ErrNoRows) {
		return nil, fmt.Errorf("certificate with secret name %s not found", secretName)
	}
	if scanErr != nil {
		return nil, fmt.Errorf("failed to scan certificate data for secret %s: %w", secretName, scanErr)
	}
	result.RenewBefore = time.Duration(renewNanos)
	return &result, nil
}
// UpdateCertRotationSettings overwrites enable_rotation and renew_before
// (stored as nanoseconds) for the row whose domain is cert.Domain, and sets
// updated_at to the strategy's current-time expression. It returns an error if
// the update fails, if the affected-row count is unavailable, or if no row
// matched the domain.
func (s *Storage) UpdateCertRotationSettings(ctx context.Context, cert *CertStorage) error {
	renewNanos := cert.RenewBefore.Nanoseconds()
	domain := cert.Domain

	query := fmt.Sprintf(`
UPDATE certificates
SET
enable_rotation = %s,
renew_before = %s,
updated_at = %s
WHERE domain = %s`,
		s.placeholder(1),
		s.placeholder(2),
		s.strategy.GetTimeNow(), // inlined SQL expression, not a bound argument
		s.placeholder(3),
	)

	res, err := s.db.ExecContext(ctx, query, cert.EnableRotation, renewNanos, domain)
	if err != nil {
		return fmt.Errorf("failed to update certificate rotation settings for %s: %w", domain, err)
	}

	updated, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to check rows affected after updating cert rotation for %s: %w", domain, err)
	case updated == 0:
		return fmt.Errorf("certificate for domain %s not found to update rotation settings", domain)
	default:
		return nil
	}
}
// LoadAllCertificates returns every row of the certificates table in the order
// the database yields them; the result is nil when the table is empty. A scan
// failure is wrapped, while query and iteration errors are returned unchanged.
func (s *Storage) LoadAllCertificates(ctx context.Context) ([]*CertStorage, error) {
	rows, err := s.db.QueryContext(ctx, `SELECT domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []*CertStorage
	for rows.Next() {
		var (
			rec        CertStorage
			renewNanos int64
		)
		scanErr := rows.Scan(
			&rec.Domain,
			&rec.Email,
			&rec.CertPEM,
			&rec.KeyPEM,
			&rec.AccountKey,
			&rec.AccountURL,
			&rec.IssuerType,
			&rec.SecretName,
			&rec.EnableRotation,
			&renewNanos,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan all certificate data: %w", scanErr)
		}
		rec.RenewBefore = time.Duration(renewNanos)
		out = append(out, &rec)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return out, nil
}
// -----------------------------------------------------------------------------
// SECRET METHODS
// -----------------------------------------------------------------------------
// SaveSecret marshals certSecret.Secret to protobuf JSON and writes it with
// the strategy's SaveSecretSQL. It binds the secret name, the JSON text and
// certSecret.Domain (NULL when Domain.Valid is false). For Postgres, the JSON
// text and domain are bound a second time for the ON CONFLICT clause.
func (s *Storage) SaveSecret(ctx context.Context, certSecret *CertSecret) error {
	secret := certSecret.Secret
	payload, err := protojson.Marshal(secret)
	if err != nil {
		return err
	}

	markers := []string{s.placeholder(1), s.placeholder(2), s.placeholder(3)}
	stmt := s.strategy.SaveSecretSQL(markers)

	jsonText := string(payload)
	args := []interface{}{secret.GetName(), jsonText, certSecret.Domain}
	if s.strategy.DriverName() == "postgres" {
		args = append(args, jsonText, certSecret.Domain)
	}

	_, err = s.db.ExecContext(ctx, stmt, args...)
	return err
}
// UpdateSecretDomain links the secret named secretName to domainName, or clears
// the link (sets domain to NULL) when domainName is empty. updated_at is set to
// the strategy's current-time expression. It returns an error if no secret
// has that name.
func (s *Storage) UpdateSecretDomain(ctx context.Context, secretName string, domainName string) error {
	// A nil interface value is bound as SQL NULL.
	var domainValue interface{}
	if domainName != "" {
		domainValue = domainName
	}

	stmt := fmt.Sprintf(`
UPDATE secrets
SET domain = %s, updated_at = %s
WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))

	res, err := s.db.ExecContext(ctx, stmt, domainValue, secretName)
	if err != nil {
		return fmt.Errorf("failed to update secret domain for %s: %w", secretName, err)
	}
	count, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to check rows affected after updating secret domain for %s: %w", secretName, err)
	case count == 0:
		return fmt.Errorf("secret with name %s not found", secretName)
	}
	return nil
}
// -----------------------------------------------------------------------------
// CLUSTER & LISTENER METHODS
// -----------------------------------------------------------------------------
// SaveCluster marshals cluster to protobuf JSON and writes it with the
// strategy's SaveClusterSQL, binding the cluster name and the JSON text. For
// Postgres, the JSON text is bound a second time for the ON CONFLICT clause.
func (s *Storage) SaveCluster(ctx context.Context, cluster *clusterv3.Cluster) error {
	payload, err := protojson.Marshal(cluster)
	if err != nil {
		return err
	}

	stmt := s.strategy.SaveClusterSQL([]string{s.placeholder(1), s.placeholder(2)})

	jsonText := string(payload)
	args := []interface{}{cluster.GetName(), jsonText}
	if s.strategy.DriverName() == "postgres" {
		args = append(args, jsonText)
	}

	_, err = s.db.ExecContext(ctx, stmt, args...)
	return err
}
// SaveListener marshals listener to protobuf JSON and writes it with the
// strategy's SaveListenerSQL, binding the listener name and the JSON text. For
// Postgres, the JSON text is bound a second time for the ON CONFLICT clause.
func (s *Storage) SaveListener(ctx context.Context, listener *listenerv3.Listener) error {
	payload, err := protojson.Marshal(listener)
	if err != nil {
		return err
	}

	stmt := s.strategy.SaveListenerSQL([]string{s.placeholder(1), s.placeholder(2)})

	jsonText := string(payload)
	args := []interface{}{listener.GetName(), jsonText}
	if s.strategy.DriverName() == "postgres" {
		args = append(args, jsonText)
	}

	_, err = s.db.ExecContext(ctx, stmt, args...)
	return err
}
// SaveExtensionConfig marshals extConfig to protobuf JSON and writes it with
// the strategy's SaveExtensionConfigSQL, binding the config name and the JSON
// text. For Postgres, the JSON text is bound a second time for the ON CONFLICT
// clause.
func (s *Storage) SaveExtensionConfig(ctx context.Context, extConfig *corev3.TypedExtensionConfig) error {
	payload, err := protojson.Marshal(extConfig)
	if err != nil {
		return err
	}

	stmt := s.strategy.SaveExtensionConfigSQL([]string{s.placeholder(1), s.placeholder(2)})

	jsonText := string(payload)
	args := []interface{}{extConfig.GetName(), jsonText}
	if s.strategy.DriverName() == "postgres" {
		args = append(args, jsonText)
	}

	_, err = s.db.ExecContext(ctx, stmt, args...)
	return err
}
// -----------------------------------------------------------------------------
// LOAD ALL METHODS (SIMPLIFIED driver-specific logic into a helper)
// -----------------------------------------------------------------------------
// getEnabledStatus scans the current row of a `SELECT data, enabled ...` query.
// It stores the data column in *raw and returns the row's enabled flag.
//
// Postgres: data is scanned directly into raw and enabled into sql.NullBool,
// so a NULL enabled reads as false.
//
// Other drivers (e.g. SQLite): data is scanned as a string and enabled as a
// dynamically typed value. Integer 1, boolean true, or the text "1"/"true"
// (case-insensitive) mean enabled. A NULL enabled also reads as false, the
// same as the Postgres branch, instead of failing the whole load. Any other
// value type is an error.
func (s *Storage) getEnabledStatus(rows *sql.Rows, raw *json.RawMessage) (bool, error) {
	if s.strategy.DriverName() == "postgres" {
		var enabledBool sql.NullBool
		if err := rows.Scan(raw, &enabledBool); err != nil {
			return false, err
		}
		return enabledBool.Bool, nil
	}

	// SQLite/Generic handling: read TEXT data and a dynamically typed enabled field.
	var dataStr string
	var enabledAny interface{}
	if err := rows.Scan(&dataStr, &enabledAny); err != nil {
		return false, err
	}
	*raw = json.RawMessage(dataStr)

	switch v := enabledAny.(type) {
	case nil:
		// SQL NULL: treat as disabled, consistent with sql.NullBool above.
		return false, nil
	case int64:
		return v == 1, nil
	case int:
		return v == 1, nil
	case bool:
		return v, nil
	case []byte:
		text := string(v)
		return text == "1" || strings.EqualFold(text, "true"), nil
	case string:
		return v == "1" || strings.EqualFold(v, "true"), nil
	default:
		return false, fmt.Errorf("unsupported enabled column type for driver %s: %T", s.strategy.DriverName(), v)
	}
}
// LoadAllClusters reads every row of the clusters table, decodes each data
// column as a clusterv3.Cluster, and splits the results by the row's enabled
// flag (as interpreted by getEnabledStatus).
func (s *Storage) LoadAllClusters(ctx context.Context) (enabled []*clusterv3.Cluster, disabled []*clusterv3.Cluster, err error) {
	rows, queryErr := s.db.QueryContext(ctx, `SELECT data, enabled FROM clusters`)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer rows.Close()

	for rows.Next() {
		var raw json.RawMessage
		isEnabled, scanErr := s.getEnabledStatus(rows, &raw)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan cluster data: %w", scanErr)
		}
		c := new(clusterv3.Cluster)
		if decodeErr := protojson.Unmarshal(raw, c); decodeErr != nil {
			return nil, nil, decodeErr
		}
		if !isEnabled {
			disabled = append(disabled, c)
			continue
		}
		enabled = append(enabled, c)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, nil, iterErr
	}
	return enabled, disabled, nil
}
// LoadAllListeners reads every row of the listeners table, decodes each data
// column as a listenerv3.Listener, and splits the results by the row's enabled
// flag (as interpreted by getEnabledStatus).
func (s *Storage) LoadAllListeners(ctx context.Context) (enabled []*listenerv3.Listener, disabled []*listenerv3.Listener, err error) {
	rows, queryErr := s.db.QueryContext(ctx, `SELECT data, enabled FROM listeners`)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer rows.Close()

	for rows.Next() {
		var raw json.RawMessage
		isEnabled, scanErr := s.getEnabledStatus(rows, &raw)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan listener data: %w", scanErr)
		}
		l := new(listenerv3.Listener)
		if decodeErr := protojson.Unmarshal(raw, l); decodeErr != nil {
			return nil, nil, decodeErr
		}
		if !isEnabled {
			disabled = append(disabled, l)
			continue
		}
		enabled = append(enabled, l)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, nil, iterErr
	}
	return enabled, disabled, nil
}
// LoadAllSecrets reads every row of the secrets table, decodes each data
// column as a secretv3.Secret, and splits the results by the row's enabled
// flag (as interpreted by getEnabledStatus). The domain column is not read.
func (s *Storage) LoadAllSecrets(ctx context.Context) (enabled []*secretv3.Secret, disabled []*secretv3.Secret, err error) {
	rows, queryErr := s.db.QueryContext(ctx, `SELECT data, enabled FROM secrets`)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer rows.Close()

	for rows.Next() {
		var raw json.RawMessage
		isEnabled, scanErr := s.getEnabledStatus(rows, &raw)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan secret data: %w", scanErr)
		}
		sec := new(secretv3.Secret)
		if decodeErr := protojson.Unmarshal(raw, sec); decodeErr != nil {
			return nil, nil, decodeErr
		}
		if !isEnabled {
			disabled = append(disabled, sec)
			continue
		}
		enabled = append(enabled, sec)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, nil, iterErr
	}
	return enabled, disabled, nil
}
// LoadAllExtensionConfigs reads every row of the extension_configs table,
// decodes each data column as a corev3.TypedExtensionConfig, and splits the
// results by the row's enabled flag (as interpreted by getEnabledStatus).
func (s *Storage) LoadAllExtensionConfigs(ctx context.Context) (enabled []*corev3.TypedExtensionConfig, disabled []*corev3.TypedExtensionConfig, err error) {
	rows, queryErr := s.db.QueryContext(ctx, `SELECT data, enabled FROM extension_configs`)
	if queryErr != nil {
		return nil, nil, queryErr
	}
	defer rows.Close()

	for rows.Next() {
		var raw json.RawMessage
		isEnabled, scanErr := s.getEnabledStatus(rows, &raw)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan extension_config data: %w", scanErr)
		}
		ec := new(corev3.TypedExtensionConfig)
		if decodeErr := protojson.Unmarshal(raw, ec); decodeErr != nil {
			return nil, nil, decodeErr
		}
		if !isEnabled {
			disabled = append(disabled, ec)
			continue
		}
		enabled = append(enabled, ec)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, nil, iterErr
	}
	return enabled, disabled, nil
}
// -----------------------------------------------------------------------------
// SNAPSHOT MANAGEMENT
// -----------------------------------------------------------------------------
// RebuildSnapshot (unchanged logic)
func (s *Storage) RebuildSnapshot(ctx context.Context) (*SnapshotConfig, error) {
enabledClusters, disabledClusters, err := s.LoadAllClusters(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load all clusters: %w", err)
}
enabledListeners, disabledListeners, err := s.LoadAllListeners(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load all listeners: %w", err)
}
enabledSecrets, disabledSecrets, err := s.LoadAllSecrets(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load all secrets: %w", err)
}
enabledExtensionConfigs, disabledExtensionConfigs, err := s.LoadAllExtensionConfigs(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load all extension configs: %w", err)
}
return &SnapshotConfig{
EnabledClusters: enabledClusters,
EnabledListeners: enabledListeners,
EnabledSecrets: enabledSecrets,
DisabledClusters: disabledClusters,
DisabledListeners: disabledListeners,
DisabledSecrets: disabledSecrets,
EnabledExtensionConfigs: enabledExtensionConfigs,
DisabledExtensionConfigs: disabledExtensionConfigs,
}, nil
}
// SaveSnapshot (unchanged logic, relying on refactored Save methods)
func (s *Storage) SaveSnapshot(ctx context.Context, cfg *SnapshotConfig, strategy DeleteStrategy) error {
if cfg == nil {
return fmt.Errorf("SnapshotConfig is nil")
}
// Use a transaction for atomicity
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err != nil {
tx.Rollback()
return
}
err = tx.Commit()
}()
// --- 1. Save/Upsert Clusters and Collect Names ---
clusterNames := make([]string, 0, len(cfg.EnabledClusters))
for _, c := range cfg.EnabledClusters {
// NOTE: These SaveXXX methods now use the refactored logic via s.strategy
if err = s.SaveCluster(ctx, c); err != nil {
return fmt.Errorf("failed to save cluster %s: %w", c.GetName(), err)
}
clusterNames = append(clusterNames, c.GetName())
}
// --- 2. Save/Upsert Listeners and Collect Names ---
listenerNames := make([]string, 0, len(cfg.EnabledListeners))
for _, l := range cfg.EnabledListeners {
if err = s.SaveListener(ctx, l); err != nil {
return fmt.Errorf("failed to save listener %s: %w", l.GetName(), err)
}
listenerNames = append(listenerNames, l.GetName())
}
// --- 3. Save/Upsert Secrets and Collect Names ---
secretNames := make([]string, 0, len(cfg.EnabledSecrets))
for _, sec := range cfg.EnabledSecrets {
certSecret := &CertSecret{
Secret: sec,
Domain: sql.NullString{Valid: false},
}
if err = s.SaveSecret(ctx, certSecret); err != nil {
return fmt.Errorf("failed to save secret %s: %w", sec.GetName(), err)
}
secretNames = append(secretNames, sec.GetName())
}
extensionConfigNames := make([]string, 0, len(cfg.EnabledExtensionConfigs))
for _, ec := range cfg.EnabledExtensionConfigs {
if err = s.SaveExtensionConfig(ctx, ec); err != nil {
return fmt.Errorf("failed to save extension config %s: %w", ec.GetName(), err)
}
extensionConfigNames = append(extensionConfigNames, ec.GetName())
}
// --- 4. Apply Deletion Strategy ---
switch strategy {
case DeleteLogical:
if err = s.disableMissingResources(ctx, "clusters", clusterNames); err != nil {
return fmt.Errorf("failed to logically delete missing clusters: %w", err)
}
if err = s.disableMissingResources(ctx, "listeners", listenerNames); err != nil {
return fmt.Errorf("failed to logically delete missing listeners: %w", err)
}
if err = s.disableMissingResources(ctx, "secrets", secretNames); err != nil {
return fmt.Errorf("failed to logically delete missing secrets: %w", err)
}
if err = s.disableMissingResources(ctx, "extension_configs", extensionConfigNames); err != nil {
return fmt.Errorf("failed to logically delete missing extension configs: %w", err)
}
case DeleteActual:
if err = s.deleteMissingResources(ctx, "clusters", clusterNames); err != nil {
return fmt.Errorf("failed to physically delete missing clusters: %w", err)
}
if err = s.deleteMissingResources(ctx, "listeners", listenerNames); err != nil {
return fmt.Errorf("failed to physically delete missing listeners: %w", err)
}
if err = s.deleteMissingResources(ctx, "secrets", secretNames); err != nil {
return fmt.Errorf("failed to physically delete missing secrets: %w", err)
}
if err = s.deleteMissingResources(ctx, "extension_configs", extensionConfigNames); err != nil {
return fmt.Errorf("failed to physically delete missing extension configs: %w", err)
}
case DeleteNone:
return nil
}
return err
}
// -----------------------------------------------------------------------------
// ENABLE/DISABLE & DELETE METHODS
// -----------------------------------------------------------------------------
// EnableCluster sets the enabled flag of the cluster called name and sets
// updated_at to the strategy's current-time expression. It does not report
// an error when no cluster has that name.
func (s *Storage) EnableCluster(ctx context.Context, name string, enabled bool) error {
	stmt := fmt.Sprintf(`UPDATE clusters SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}
// EnableListener sets the enabled flag of the listener called name and sets
// updated_at to the strategy's current-time expression. It does not report
// an error when no listener has that name.
func (s *Storage) EnableListener(ctx context.Context, name string, enabled bool) error {
	stmt := fmt.Sprintf(`UPDATE listeners SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}
// EnableSecret sets the enabled flag of the secret called name and sets
// updated_at to the strategy's current-time expression. It does not report
// an error when no secret has that name.
func (s *Storage) EnableSecret(ctx context.Context, name string, enabled bool) error {
	stmt := fmt.Sprintf(`UPDATE secrets SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}
// EnableExtensionConfig sets the enabled flag of the extension config called
// name and sets updated_at to the strategy's current-time expression. It does
// not report an error when no extension config has that name.
func (s *Storage) EnableExtensionConfig(ctx context.Context, name string, enabled bool) error {
	stmt := fmt.Sprintf(`UPDATE extension_configs SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}
// RemoveListener deletes the listener row called name. Deleting a name that
// does not exist is not an error.
func (s *Storage) RemoveListener(ctx context.Context, name string) error {
	if _, err := s.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM listeners WHERE name = %s`, s.placeholder(1)), name); err != nil {
		return err
	}
	return nil
}
// RemoveCluster deletes the cluster row called name. Deleting a name that
// does not exist is not an error.
func (s *Storage) RemoveCluster(ctx context.Context, name string) error {
	if _, err := s.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM clusters WHERE name = %s`, s.placeholder(1)), name); err != nil {
		return err
	}
	return nil
}
// RemoveSecret deletes the secret row called name. Deleting a name that does
// not exist is not an error.
func (s *Storage) RemoveSecret(ctx context.Context, name string) error {
	if _, err := s.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM secrets WHERE name = %s`, s.placeholder(1)), name); err != nil {
		return err
	}
	return nil
}
// RemoveExtensionConfig deletes the extension config row called name.
// Deleting a name that does not exist is not an error.
func (s *Storage) RemoveExtensionConfig(ctx context.Context, name string) error {
	if _, err := s.db.ExecContext(ctx, fmt.Sprintf(`DELETE FROM extension_configs WHERE name = %s`, s.placeholder(1)), name); err != nil {
		return err
	}
	return nil
}
// disableMissingResources uses the strategy for dialect-specific values.
func (s *Storage) disableMissingResources(ctx context.Context, table string, inputNames []string) error {
if table != "clusters" && table != "listeners" && table != "secrets" && table != "extension_configs" {
return fmt.Errorf("logical delete (disable) is only supported for tables with an 'enabled' column (clusters, listeners, secrets, extension_configs)")
}
// 1. Build placeholders and args
placeholders := make([]string, len(inputNames))
args := make([]interface{}, len(inputNames))
for i, name := range inputNames {
placeholders[i] = s.placeholder(i + 1)
args[i] = name
}
disabledValue := s.strategy.GetFalseValue()
updateTime := s.strategy.GetTimeNow()
// If no names are provided, disable ALL currently enabled resources
whereClause := ""
if len(inputNames) > 0 {
whereClause = fmt.Sprintf("WHERE name NOT IN (%s)", strings.Join(placeholders, ", "))
}
// 2. Construct and execute the UPDATE query
query := fmt.Sprintf(`
UPDATE %s
SET enabled = %s, updated_at = %s
%s`,
table, disabledValue, updateTime, whereClause)
_, err := s.db.ExecContext(ctx, query, args...)
return err
}
// deleteMissingResources is unchanged except for using the generic s.placeholder(i+1)
func (s *Storage) deleteMissingResources(ctx context.Context, table string, inputNames []string) error {
if table != "clusters" && table != "listeners" && table != "secrets" && table != "extension_configs" {
return fmt.Errorf("physical delete is only supported for tables: clusters, listeners, secrets, extension_configs")
}
// 1. Build placeholders and args
placeholders := make([]string, len(inputNames))
args := make([]interface{}, len(inputNames))
for i, name := range inputNames {
placeholders[i] = s.placeholder(i + 1)
args[i] = name
}
// If no names are provided, delete ALL resources
whereClause := ""
if len(inputNames) > 0 {
whereClause = fmt.Sprintf("WHERE name NOT IN (%s)", strings.Join(placeholders, ", "))
}
// 2. Construct and execute the DELETE query
query := fmt.Sprintf(`
DELETE FROM %s
%s`,
table, whereClause)
_, err := s.db.ExecContext(ctx, query, args...)
return err
}