// Source: EnvoyControlPlane/internal/pkg/storage/storage.go

package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
	listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
	secretv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"

	"google.golang.org/protobuf/encoding/protojson"
)

// =============================================================================
// I. TYPES, CONSTANTS, & FACTORY
// =============================================================================

// Storage abstracts database persistence for xDS resources (clusters,
// listeners, secrets) and certificate data. SQL text and bind placeholders
// that differ between dialects are obtained from strategy; a few methods
// additionally check strategy.DriverName() for Postgres-specific argument
// layouts.
type Storage struct {
	db       *sql.DB
	strategy SQLStrategy // dialect strategy selected by NewStorage from the driver name
}

// CertStorage represents the persistent data needed for certificate renewal.
// One value corresponds to one row of the certificates table, written by
// SaveCertificate and read back by LoadCertificate / LoadAllCertificates.
// Domain is the key LoadCertificate looks rows up by.
type CertStorage struct {
	Domain string
	Email  string
	// CertPEM and KeyPEM are stored in the cert_pem / key_pem columns.
	CertPEM []byte
	KeyPEM  []byte
	// AccountKey / AccountURL are stored in account_key / account_url.
	// NOTE(review): presumably the issuer (ACME) account credentials — confirm.
	AccountKey []byte
	AccountURL string
	// IssuerType is stored in issuer_type; its valid values are not defined in
	// this file.
	IssuerType string
	// SecretName is stored in secret_name.
	// NOTE(review): presumably the xDS secret built from this certificate — confirm.
	SecretName     string
	EnableRotation bool
	// RenewBefore is persisted in renew_before as a nanosecond count.
	RenewBefore time.Duration
}

// CertSecret is a structure to hold the secret data plus its domain link.
// SaveSecret stores Secret (encoded with protojson) in the secrets table
// together with Domain; an invalid Domain is written as NULL, which is also
// what UpdateSecretDomain stores when the link is cleared.
type CertSecret struct {
	Secret *secretv3.Secret
	Domain sql.NullString
}

// SnapshotConfig aggregates xDS resources split by each row's enabled flag.
// RebuildSnapshot fills all six slices; SaveSnapshot persists only the
// Enabled* slices.
type SnapshotConfig struct {
	// Enabled resources (for xDS serving)
	EnabledClusters  []*clusterv3.Cluster
	EnabledListeners []*listenerv3.Listener
	EnabledSecrets   []*secretv3.Secret

	// Disabled resources (for UI display)
	DisabledClusters  []*clusterv3.Cluster
	DisabledListeners []*listenerv3.Listener
	DisabledSecrets   []*secretv3.Secret
}

// RawRow holds one undecoded resource row (its name and raw data bytes) for
// database dump/restore logic. It is not referenced in this file.
// NOTE(review): presumably consumed by dump/restore code elsewhere in the
// package — confirm before changing its fields.
type RawRow struct {
	Name string
	Data []byte
	// Used only by the secrets table
	Domain sql.NullString
}

// DeleteStrategy selects what SaveSnapshot does with stored resources whose
// names are not part of the snapshot being saved.
type DeleteStrategy int

const (
	// DeleteNone leaves resources missing from the snapshot untouched.
	DeleteNone DeleteStrategy = iota
	// DeleteLogical marks missing resources as disabled (enabled = false).
	DeleteLogical
	// DeleteActual removes missing resources from the database.
	DeleteActual
)

// NewStorage returns a Storage backed by db whose SQL dialect is chosen by
// NewSQLStrategy(driver). Any error from NewSQLStrategy is returned as-is and
// no Storage is created.
func NewStorage(db *sql.DB, driver string) (*Storage, error) {
	dialect, dialectErr := NewSQLStrategy(driver)
	if dialectErr != nil {
		return nil, dialectErr
	}

	store := new(Storage)
	store.db = db
	store.strategy = dialect
	return store, nil
}

// placeholder returns the dialect's bind-parameter marker for the n-th
// (1-based) query argument, as produced by the strategy (e.g. "$1" or "?").
func (s *Storage) placeholder(n int) string {
	dialect := s.strategy
	marker := dialect.Placeholder(n)
	return marker
}

// =============================================================================
// II. CORE METHODS
// =============================================================================

// InitSchema executes the schema DDL returned by strategy.InitSchemaSQL.
// An error whose text contains "already exists" is treated as success so the
// call can be repeated against an initialised database; any other error is
// returned unchanged.
func (s *Storage) InitSchema(ctx context.Context) error {
	_, execErr := s.db.ExecContext(ctx, s.strategy.InitSchemaSQL())
	if execErr == nil || strings.Contains(execErr.Error(), "already exists") {
		return nil
	}
	return execErr
}

// -----------------------------------------------------------------------------
// CERTIFICATE METHODS
// -----------------------------------------------------------------------------

// SaveCertificate writes cert using the statement returned by
// strategy.SaveCertificateSQL. The ten bind arguments are passed in the order
// Domain, Email, CertPEM, KeyPEM, AccountKey, AccountURL, IssuerType,
// SecretName, EnableRotation, RenewBefore; the strategy's SQL must use the
// same order. RenewBefore is stored as a nanosecond count.
func (s *Storage) SaveCertificate(ctx context.Context, cert *CertStorage) error {
	const argCount = 10

	binds := []interface{}{
		cert.Domain,
		cert.Email,
		cert.CertPEM,
		cert.KeyPEM,
		cert.AccountKey,
		cert.AccountURL,
		cert.IssuerType,
		cert.SecretName,
		cert.EnableRotation,
		cert.RenewBefore.Nanoseconds(),
	}

	// Dialect-specific markers, e.g. $1..$10 or ?..?.
	markers := make([]string, 0, argCount)
	for n := 1; n <= argCount; n++ {
		markers = append(markers, s.placeholder(n))
	}

	_, err := s.db.ExecContext(ctx, s.strategy.SaveCertificateSQL(markers), binds...)
	return err
}

// certNotFoundError is returned by LoadCertificate when no certificate row
// exists for the requested domain. Its message is
// "certificate for domain <domain> not found", and it unwraps to
// sql.ErrNoRows so callers can tell a missing certificate apart from a
// database failure with errors.Is(err, sql.ErrNoRows).
type certNotFoundError struct {
	domain string
}

// Error implements the error interface.
func (e *certNotFoundError) Error() string {
	return fmt.Sprintf("certificate for domain %s not found", e.domain)
}

// Unwrap exposes sql.ErrNoRows to errors.Is / errors.As.
func (e *certNotFoundError) Unwrap() error {
	return sql.ErrNoRows
}

// LoadCertificate reads the certificate row stored for domain.
//
// When no row matches, the returned error satisfies
// errors.Is(err, sql.ErrNoRows). Any other query or scan failure is wrapped
// with the domain for context. RenewBefore is rebuilt from the stored
// nanosecond count.
func (s *Storage) LoadCertificate(ctx context.Context, domain string) (*CertStorage, error) {
	// Only the placeholder syntax is dialect-specific here.
	query := fmt.Sprintf(`SELECT email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates WHERE domain = %s`, s.placeholder(1))

	cert := &CertStorage{Domain: domain}
	var renewBeforeNanos int64

	err := s.db.QueryRowContext(ctx, query, domain).Scan(
		&cert.Email,
		&cert.CertPEM,
		&cert.KeyPEM,
		&cert.AccountKey,
		&cert.AccountURL,
		&cert.IssuerType,
		&cert.SecretName,
		&cert.EnableRotation,
		&renewBeforeNanos,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, &certNotFoundError{domain: domain}
	case err != nil:
		return nil, fmt.Errorf("failed to scan certificate data for %s: %w", domain, err)
	}

	cert.RenewBefore = time.Duration(renewBeforeNanos)
	return cert, nil
}

// LoadAllCertificates returns every row of the certificates table. When the
// table is empty it returns a nil slice and a nil error. A query, scan or
// iteration error aborts the load and is returned with a nil slice.
func (s *Storage) LoadAllCertificates(ctx context.Context) ([]*CertStorage, error) {
	const selectAll = `SELECT domain, email, cert_pem, key_pem, account_key, account_url, issuer_type, secret_name, enable_rotation, renew_before FROM certificates`

	rows, err := s.db.QueryContext(ctx, selectAll)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var loaded []*CertStorage
	for rows.Next() {
		var (
			c     CertStorage
			nanos int64
		)
		scanErr := rows.Scan(
			&c.Domain,
			&c.Email,
			&c.CertPEM,
			&c.KeyPEM,
			&c.AccountKey,
			&c.AccountURL,
			&c.IssuerType,
			&c.SecretName,
			&c.EnableRotation,
			&nanos,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("failed to scan all certificate data: %w", scanErr)
		}
		c.RenewBefore = time.Duration(nanos)
		loaded = append(loaded, &c)
	}

	if err = rows.Err(); err != nil {
		return nil, err
	}
	return loaded, nil
}

// -----------------------------------------------------------------------------
// SECRET METHODS
// -----------------------------------------------------------------------------

// SaveSecret stores certSecret.Secret, encoded with protojson, under the
// secret's name together with certSecret.Domain (written as NULL when the
// NullString is invalid), using the statement from strategy.SaveSecretSQL.
// certSecret itself must be non-nil.
func (s *Storage) SaveSecret(ctx context.Context, certSecret *CertSecret) error {
	sec := certSecret.Secret
	encoded, marshalErr := protojson.Marshal(sec)
	if marshalErr != nil {
		return marshalErr
	}
	body := string(encoded)

	stmt := s.strategy.SaveSecretSQL([]string{
		s.placeholder(1),
		s.placeholder(2),
		s.placeholder(3),
	})

	bind := []interface{}{sec.GetName(), body, certSecret.Domain}
	if s.strategy.DriverName() == "postgres" {
		// The Postgres statement binds data and domain again for ON CONFLICT.
		bind = append(bind, body, certSecret.Domain)
	}

	_, execErr := s.db.ExecContext(ctx, stmt, bind...)
	return execErr
}

// UpdateSecretDomain links the secret named secretName to domainName and
// refreshes its updated_at. An empty domainName clears the link by storing
// NULL. It returns an error if the update fails, if the affected-row count
// cannot be read, or if no secret has that name.
func (s *Storage) UpdateSecretDomain(ctx context.Context, secretName string, domainName string) error {
	// A nil argument is bound as SQL NULL.
	var domainArg interface{}
	if domainName != "" {
		domainArg = domainName
	}

	query := fmt.Sprintf(`
		UPDATE secrets
		SET domain = %s, updated_at = %s
		WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))

	result, execErr := s.db.ExecContext(ctx, query, domainArg, secretName)
	if execErr != nil {
		return fmt.Errorf("failed to update secret domain for %s: %w", secretName, execErr)
	}

	switch affected, countErr := result.RowsAffected(); {
	case countErr != nil:
		return fmt.Errorf("failed to check rows affected after updating secret domain for %s: %w", secretName, countErr)
	case affected == 0:
		return fmt.Errorf("secret with name %s not found", secretName)
	default:
		return nil
	}
}

// -----------------------------------------------------------------------------
// CLUSTER & LISTENER METHODS
// -----------------------------------------------------------------------------

// SaveCluster stores cluster, encoded with protojson, under its name using the
// statement from strategy.SaveClusterSQL. For Postgres the encoded data is
// bound a second time for the statement's ON CONFLICT clause.
func (s *Storage) SaveCluster(ctx context.Context, cluster *clusterv3.Cluster) error {
	encoded, err := protojson.Marshal(cluster)
	if err != nil {
		return err
	}

	query := s.strategy.SaveClusterSQL([]string{s.placeholder(1), s.placeholder(2)})

	args := make([]interface{}, 0, 3)
	args = append(args, cluster.GetName(), string(encoded))
	if s.strategy.DriverName() == "postgres" {
		args = append(args, string(encoded))
	}

	if _, err := s.db.ExecContext(ctx, query, args...); err != nil {
		return err
	}
	return nil
}

// SaveListener stores listener, encoded with protojson, under its name using
// the statement from strategy.SaveListenerSQL. For Postgres the encoded data
// is bound a second time for the statement's ON CONFLICT clause.
func (s *Storage) SaveListener(ctx context.Context, listener *listenerv3.Listener) error {
	jsonBytes, err := protojson.Marshal(listener)
	if err != nil {
		return err
	}
	payload := string(jsonBytes)

	marks := make([]string, 2)
	for idx := range marks {
		marks[idx] = s.placeholder(idx + 1)
	}
	stmt := s.strategy.SaveListenerSQL(marks)

	bind := []interface{}{listener.GetName(), payload}
	if s.strategy.DriverName() == "postgres" {
		bind = append(bind, payload)
	}

	_, err = s.db.ExecContext(ctx, stmt, bind...)
	return err
}

// -----------------------------------------------------------------------------
// LOAD ALL METHODS (SIMPLIFIED driver-specific logic into a helper)
// -----------------------------------------------------------------------------

// getEnabledStatus scans the current row of rows, which must select exactly
// (data, enabled) in that order. It stores data into raw and returns the
// enabled flag.
//
// For Postgres the flag is scanned into sql.NullBool. Other drivers (SQLite)
// may surface the column as an integer, bool or text, so it is scanned into
// an interface{} and normalised: 1, true, "1" and case-insensitive "true"
// mean enabled. A NULL enabled value is reported as false for every driver.
// This matches the Postgres branch, where an invalid NullBool has Bool ==
// false. Any other dynamic type is an error.
func (s *Storage) getEnabledStatus(rows *sql.Rows, raw *json.RawMessage) (bool, error) {
	if s.strategy.DriverName() == "postgres" {
		var enabledBool sql.NullBool
		if err := rows.Scan(raw, &enabledBool); err != nil {
			return false, err
		}
		return enabledBool.Bool, nil
	}

	// SQLite/Generic handling: read TEXT data and a dynamically typed enabled field.
	var dataStr string
	var enabledAny interface{}
	if err := rows.Scan(&dataStr, &enabledAny); err != nil {
		return false, err
	}
	*raw = json.RawMessage(dataStr)

	switch v := enabledAny.(type) {
	case nil:
		// NULL column: treat as disabled instead of failing the whole load.
		return false, nil
	case int64:
		return v == 1, nil
	case int:
		return v == 1, nil
	case bool:
		return v, nil
	case []byte:
		text := string(v)
		return text == "1" || strings.EqualFold(text, "true"), nil
	case string:
		return v == "1" || strings.EqualFold(v, "true"), nil
	default:
		return false, fmt.Errorf("unsupported enabled column type for driver %s: %T", s.strategy.DriverName(), v)
	}
}

// LoadAllClusters decodes every row of the clusters table and splits the
// clusters by their enabled flag (see getEnabledStatus). The first query,
// scan, decode or iteration error aborts the load and is returned with nil
// slices.
func (s *Storage) LoadAllClusters(ctx context.Context) (enabled []*clusterv3.Cluster, disabled []*clusterv3.Cluster, err error) {
	rows, qerr := s.db.QueryContext(ctx, `SELECT data, enabled FROM clusters`)
	if qerr != nil {
		return nil, nil, qerr
	}
	defer rows.Close()

	for rows.Next() {
		var payload json.RawMessage
		on, scanErr := s.getEnabledStatus(rows, &payload)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan cluster data: %w", scanErr)
		}

		c := new(clusterv3.Cluster)
		if uerr := protojson.Unmarshal(payload, c); uerr != nil {
			return nil, nil, uerr
		}

		if !on {
			disabled = append(disabled, c)
			continue
		}
		enabled = append(enabled, c)
	}

	if rerr := rows.Err(); rerr != nil {
		return nil, nil, rerr
	}
	return enabled, disabled, nil
}

// LoadAllListeners decodes every row of the listeners table and splits the
// listeners by their enabled flag (see getEnabledStatus). The first query,
// scan, decode or iteration error aborts the load and is returned with nil
// slices.
func (s *Storage) LoadAllListeners(ctx context.Context) (enabled []*listenerv3.Listener, disabled []*listenerv3.Listener, err error) {
	rows, err := s.db.QueryContext(ctx, `SELECT data, enabled FROM listeners`)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var payload json.RawMessage
		isEnabled, scanErr := s.getEnabledStatus(rows, &payload)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan listener data: %w", scanErr)
		}

		lis := &listenerv3.Listener{}
		if decodeErr := protojson.Unmarshal(payload, lis); decodeErr != nil {
			return nil, nil, decodeErr
		}

		switch {
		case isEnabled:
			enabled = append(enabled, lis)
		default:
			disabled = append(disabled, lis)
		}
	}

	if err = rows.Err(); err != nil {
		return nil, nil, err
	}
	return enabled, disabled, nil
}

// LoadAllSecrets decodes every row of the secrets table and splits the
// secrets by their enabled flag (see getEnabledStatus). The first query,
// scan, decode or iteration error aborts the load and is returned with nil
// slices.
func (s *Storage) LoadAllSecrets(ctx context.Context) (enabled []*secretv3.Secret, disabled []*secretv3.Secret, err error) {
	rows, err := s.db.QueryContext(ctx, `SELECT data, enabled FROM secrets`)
	if err != nil {
		return nil, nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var raw json.RawMessage
		isOn, scanErr := s.getEnabledStatus(rows, &raw)
		if scanErr != nil {
			return nil, nil, fmt.Errorf("failed to scan secret data: %w", scanErr)
		}

		secret := new(secretv3.Secret)
		if decodeErr := protojson.Unmarshal(raw, secret); decodeErr != nil {
			return nil, nil, decodeErr
		}

		// Pick the destination slice, then append once.
		bucket := &disabled
		if isOn {
			bucket = &enabled
		}
		*bucket = append(*bucket, secret)
	}

	if err = rows.Err(); err != nil {
		return nil, nil, err
	}
	return enabled, disabled, nil
}

// -----------------------------------------------------------------------------
// SNAPSHOT MANAGEMENT
// -----------------------------------------------------------------------------

// RebuildSnapshot loads all clusters, listeners and secrets from the database,
// each split into enabled and disabled sets. The first load error is returned,
// wrapped with the resource kind, and no snapshot is produced.
func (s *Storage) RebuildSnapshot(ctx context.Context) (*SnapshotConfig, error) {
	snap := &SnapshotConfig{}
	var loadErr error

	if snap.EnabledClusters, snap.DisabledClusters, loadErr = s.LoadAllClusters(ctx); loadErr != nil {
		return nil, fmt.Errorf("failed to load all clusters: %w", loadErr)
	}
	if snap.EnabledListeners, snap.DisabledListeners, loadErr = s.LoadAllListeners(ctx); loadErr != nil {
		return nil, fmt.Errorf("failed to load all listeners: %w", loadErr)
	}
	if snap.EnabledSecrets, snap.DisabledSecrets, loadErr = s.LoadAllSecrets(ctx); loadErr != nil {
		return nil, fmt.Errorf("failed to load all secrets: %w", loadErr)
	}

	return snap, nil
}

// SaveSnapshot persists cfg atomically. Inside one database transaction it
// upserts every cluster, listener and secret in cfg's Enabled* slices. It then
// applies strategy to stored resources whose names are not in cfg:
//
//   - DeleteNone leaves them untouched;
//   - DeleteLogical sets enabled to false (see disableMissingResourcesWith);
//   - DeleteActual deletes the rows (see deleteMissingResourcesWith).
//
// With DeleteLogical or DeleteActual, an empty Enabled* slice affects every
// row of the corresponding table. If any step fails, the transaction is
// rolled back and the error returned. A failed commit is also reported.
//
// cfg's Disabled* slices are not read. Secrets are upserted with a NULL domain.
// NOTE(review): SaveSecret re-binds the domain in the Postgres ON CONFLICT
// clause, so re-saving an existing secret here presumably clears its domain
// link — confirm this is intended.
func (s *Storage) SaveSnapshot(ctx context.Context, cfg *SnapshotConfig, strategy DeleteStrategy) error {
	if cfg == nil {
		return fmt.Errorf("SnapshotConfig is nil")
	}

	// Every write below goes through tx rather than s.db. They therefore
	// commit or roll back together, and they never need a second pooled
	// connection while the transaction holds one (which would deadlock a
	// single-connection pool).
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op returning sql.ErrTxDone,
	// so this only discards work on the error paths.
	defer func() { _ = tx.Rollback() }()

	if err = s.writeSnapshot(ctx, tx, cfg, strategy); err != nil {
		return err
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit snapshot transaction: %w", err)
	}
	return nil
}

// sqlExecer is satisfied by both *sql.DB and *sql.Tx. It lets the same write
// helpers run either directly on the pool or inside a transaction.
type sqlExecer interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// writeSnapshot performs every write of SaveSnapshot through ex.
func (s *Storage) writeSnapshot(ctx context.Context, ex sqlExecer, cfg *SnapshotConfig, strategy DeleteStrategy) error {
	// --- 1. Save/Upsert Clusters and Collect Names ---
	clusterNames := make([]string, 0, len(cfg.EnabledClusters))
	for _, c := range cfg.EnabledClusters {
		if err := s.saveClusterWith(ctx, ex, c); err != nil {
			return fmt.Errorf("failed to save cluster %s: %w", c.GetName(), err)
		}
		clusterNames = append(clusterNames, c.GetName())
	}

	// --- 2. Save/Upsert Listeners and Collect Names ---
	listenerNames := make([]string, 0, len(cfg.EnabledListeners))
	for _, l := range cfg.EnabledListeners {
		if err := s.saveListenerWith(ctx, ex, l); err != nil {
			return fmt.Errorf("failed to save listener %s: %w", l.GetName(), err)
		}
		listenerNames = append(listenerNames, l.GetName())
	}

	// --- 3. Save/Upsert Secrets and Collect Names ---
	secretNames := make([]string, 0, len(cfg.EnabledSecrets))
	for _, sec := range cfg.EnabledSecrets {
		certSecret := &CertSecret{
			Secret: sec,
			Domain: sql.NullString{Valid: false},
		}
		if err := s.saveSecretWith(ctx, ex, certSecret); err != nil {
			return fmt.Errorf("failed to save secret %s: %w", sec.GetName(), err)
		}
		secretNames = append(secretNames, sec.GetName())
	}

	// --- 4. Apply Deletion Strategy ---
	missing := []struct {
		table string
		names []string
	}{
		{"clusters", clusterNames},
		{"listeners", listenerNames},
		{"secrets", secretNames},
	}

	switch strategy {
	case DeleteLogical:
		for _, m := range missing {
			if err := s.disableMissingResourcesWith(ctx, ex, m.table, m.names); err != nil {
				return fmt.Errorf("failed to logically delete missing %s: %w", m.table, err)
			}
		}
	case DeleteActual:
		for _, m := range missing {
			if err := s.deleteMissingResourcesWith(ctx, ex, m.table, m.names); err != nil {
				return fmt.Errorf("failed to physically delete missing %s: %w", m.table, err)
			}
		}
	}
	// DeleteNone (and any unrecognised value) leaves missing resources alone.
	return nil
}

// saveClusterWith upserts cluster through ex. It uses the same SQL and
// argument layout as SaveCluster.
func (s *Storage) saveClusterWith(ctx context.Context, ex sqlExecer, cluster *clusterv3.Cluster) error {
	data, err := protojson.Marshal(cluster)
	if err != nil {
		return err
	}
	query := s.strategy.SaveClusterSQL(s.placeholderList(2))
	return s.execUpsert(ctx, ex, query,
		[]interface{}{cluster.GetName(), string(data)},
		string(data))
}

// saveListenerWith upserts listener through ex. It uses the same SQL and
// argument layout as SaveListener.
func (s *Storage) saveListenerWith(ctx context.Context, ex sqlExecer, listener *listenerv3.Listener) error {
	data, err := protojson.Marshal(listener)
	if err != nil {
		return err
	}
	query := s.strategy.SaveListenerSQL(s.placeholderList(2))
	return s.execUpsert(ctx, ex, query,
		[]interface{}{listener.GetName(), string(data)},
		string(data))
}

// saveSecretWith upserts certSecret through ex. It uses the same SQL and
// argument layout as SaveSecret.
func (s *Storage) saveSecretWith(ctx context.Context, ex sqlExecer, certSecret *CertSecret) error {
	secret := certSecret.Secret
	data, err := protojson.Marshal(secret)
	if err != nil {
		return err
	}
	query := s.strategy.SaveSecretSQL(s.placeholderList(3))
	return s.execUpsert(ctx, ex, query,
		[]interface{}{secret.GetName(), string(data), certSecret.Domain},
		string(data), certSecret.Domain)
}

// execUpsert executes a strategy-generated upsert through ex. insertArgs are
// always bound. For Postgres, conflictArgs are appended as well, because that
// dialect's ON CONFLICT clause binds them a second time.
func (s *Storage) execUpsert(ctx context.Context, ex sqlExecer, query string, insertArgs []interface{}, conflictArgs ...interface{}) error {
	args := insertArgs
	if s.strategy.DriverName() == "postgres" {
		args = append(args, conflictArgs...)
	}
	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// placeholderList returns the dialect's bind markers for arguments 1..n.
func (s *Storage) placeholderList(n int) []string {
	ph := make([]string, n)
	for i := range ph {
		ph[i] = s.placeholder(i + 1)
	}
	return ph
}

// -----------------------------------------------------------------------------
// ENABLE/DISABLE & DELETE METHODS
// -----------------------------------------------------------------------------

// EnableCluster sets the enabled flag of the cluster named name and refreshes
// its updated_at. No error is returned when no cluster has that name.
func (s *Storage) EnableCluster(ctx context.Context, name string, enabled bool) error {
	query := fmt.Sprintf(`UPDATE clusters SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	_, err := s.db.ExecContext(ctx, query, enabled, name)
	return err
}

// EnableListener sets the enabled flag of the listener named name and
// refreshes its updated_at. No error is returned when no listener has that
// name.
func (s *Storage) EnableListener(ctx context.Context, name string, enabled bool) error {
	query := fmt.Sprintf(`UPDATE listeners SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	_, err := s.db.ExecContext(ctx, query, enabled, name)
	return err
}

// EnableSecret sets the enabled flag of the secret named name and refreshes
// its updated_at. No error is returned when no secret has that name.
func (s *Storage) EnableSecret(ctx context.Context, name string, enabled bool) error {
	query := fmt.Sprintf(`UPDATE secrets SET enabled = %s, updated_at = %s WHERE name = %s`,
		s.placeholder(1), s.strategy.GetTimeNow(), s.placeholder(2))
	_, err := s.db.ExecContext(ctx, query, enabled, name)
	return err
}

// RemoveListener deletes the listener row named name. Deleting a name that
// does not exist is not an error.
func (s *Storage) RemoveListener(ctx context.Context, name string) error {
	query := fmt.Sprintf(`DELETE FROM listeners WHERE name = %s`, s.placeholder(1))
	_, err := s.db.ExecContext(ctx, query, name)
	return err
}

// RemoveCluster deletes the cluster row named name. Deleting a name that does
// not exist is not an error.
func (s *Storage) RemoveCluster(ctx context.Context, name string) error {
	query := fmt.Sprintf(`DELETE FROM clusters WHERE name = %s`, s.placeholder(1))
	_, err := s.db.ExecContext(ctx, query, name)
	return err
}

// RemoveSecret deletes the secret row named name. Deleting a name that does
// not exist is not an error.
func (s *Storage) RemoveSecret(ctx context.Context, name string) error {
	query := fmt.Sprintf(`DELETE FROM secrets WHERE name = %s`, s.placeholder(1))
	_, err := s.db.ExecContext(ctx, query, name)
	return err
}

// disableMissingResources runs disableMissingResourcesWith directly on the
// connection pool, outside any transaction.
func (s *Storage) disableMissingResources(ctx context.Context, table string, inputNames []string) error {
	return s.disableMissingResourcesWith(ctx, s.db, table, inputNames)
}

// disableMissingResourcesWith executes through ex. It sets enabled to the
// dialect's false value and refreshes updated_at on every row of table whose
// name is not in inputNames. With an empty inputNames, every row of the table
// is updated. Rows that are already disabled also match, so their updated_at
// is refreshed too. Only clusters, listeners and secrets are accepted. table
// is interpolated into the SQL, so that allow-list is what keeps the
// statement safe.
func (s *Storage) disableMissingResourcesWith(ctx context.Context, ex sqlExecer, table string, inputNames []string) error {
	if table != "clusters" && table != "listeners" && table != "secrets" {
		return fmt.Errorf("logical delete (disable) is only supported for tables with an 'enabled' column (clusters, listeners, secrets)")
	}

	whereClause, args := s.notInClause(inputNames)

	query := fmt.Sprintf(`
		UPDATE %s
		SET enabled = %s, updated_at = %s
		%s`,
		table, s.strategy.GetFalseValue(), s.strategy.GetTimeNow(), whereClause)

	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// deleteMissingResources runs deleteMissingResourcesWith directly on the
// connection pool, outside any transaction.
func (s *Storage) deleteMissingResources(ctx context.Context, table string, inputNames []string) error {
	return s.deleteMissingResourcesWith(ctx, s.db, table, inputNames)
}

// deleteMissingResourcesWith executes through ex. It deletes every row of
// table whose name is not in inputNames. With an empty inputNames, every row
// of the table is deleted. Only clusters, listeners and secrets are accepted.
// table is interpolated into the SQL, so that allow-list is what keeps the
// statement safe.
func (s *Storage) deleteMissingResourcesWith(ctx context.Context, ex sqlExecer, table string, inputNames []string) error {
	if table != "clusters" && table != "listeners" && table != "secrets" {
		return fmt.Errorf("physical delete is only supported for tables: clusters, listeners, secrets")
	}

	whereClause, args := s.notInClause(inputNames)

	query := fmt.Sprintf(`
		DELETE FROM %s
		%s`,
		table, whereClause)

	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// notInClause builds "WHERE name NOT IN (...)" with one bind marker per name,
// plus the matching arguments. For empty names it returns "" and no
// arguments, so the surrounding statement applies to the whole table.
func (s *Storage) notInClause(names []string) (string, []interface{}) {
	if len(names) == 0 {
		return "", nil
	}
	args := make([]interface{}, len(names))
	for i, name := range names {
		args[i] = name
	}
	clause := fmt.Sprintf("WHERE name NOT IN (%s)", strings.Join(s.placeholderList(len(names)), ", "))
	return clause, args
}