Newer
Older
EnvoyControlPlane / internal / pkg / storage / storage_dump.go
package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
)

// NOTE: This file relies on the SQLStrategy interface and its concrete
// implementations (PostgresStrategy, SQLiteStrategy), which must be defined elsewhere in package storage.

// DBDump holds the complete state of the database for dumping/restoring.
// It is the JSON document produced by Dump and consumed by Restore. Because
// encoding/json emits struct fields in declaration order, the field order
// below is also the key order of the dumped JSON.
type DBDump struct {
	// Rows of the clusters table; omitted from the JSON when empty.
	Clusters     []*RawRow      `json:"clusters,omitempty"`
	// Rows of the listeners table; omitted from the JSON when empty.
	Listeners    []*RawRow      `json:"listeners,omitempty"`
	// Rows of the secrets table; omitted from the JSON when empty.
	Secrets      []*RawRow      `json:"secrets,omitempty"`
	// All stored certificates; omitted from the JSON when empty.
	Certificates []*CertStorage `json:"certificates,omitempty"`
	// Driver name of the database the dump was taken from. Restore refuses
	// to load a dump whose driver differs from the current one.
	DBDriver     string         `json:"db_driver"`
}

// DBDumpRestoreMode selects how Restore treats data that already exists in
// the database.
type DBDumpRestoreMode int

// Supported restore modes. The numeric values are spelled out explicitly;
// RestoreMerge is the zero value.
const (
	// RestoreMerge upserts incoming rows on top of the existing data.
	RestoreMerge    DBDumpRestoreMode = 0
	// RestoreOverride wipes the existing data, then loads the dump.
	RestoreOverride DBDumpRestoreMode = 1
)

// -----------------------------------------------------------------------------
// DUMP METHODS (REFACTORED)
// -----------------------------------------------------------------------------

// Dump exports the clusters, listeners and secrets tables plus all
// certificates as an indented JSON document (a DBDump).
//
// Column selection and row scanning are delegated to s.strategy
// (DumpSelectFields / ScanRawRow), so the same code serves every driver. The
// dump records s.strategy.DriverName() so that Restore can reject it when it
// is loaded into a different driver. A sql.ErrNoRows from
// LoadAllCertificates, wrapped or not, is treated as "no certificates" rather
// than a failure.
//
// NOTE(review): each table is read with its own query rather than inside a
// single transaction. A dump taken while writers are active is therefore not
// guaranteed to be a consistent cross-table snapshot.
func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
	// load reads every row of one raw-config table. The table name is
	// interpolated into the SQL text, so it is checked against a fixed
	// allowlist before any query is built.
	load := func(ctx context.Context, table string) ([]*RawRow, error) {
		if table != "clusters" && table != "listeners" && table != "secrets" {
			return nil, fmt.Errorf("invalid table for dump: %s", table)
		}

		fields := s.strategy.DumpSelectFields(table)
		query := fmt.Sprintf(`SELECT %s FROM %s`, fields, table)

		rows, err := s.db.QueryContext(ctx, query)
		if err != nil {
			return nil, fmt.Errorf("failed to query table %s: %w", table, err)
		}
		defer rows.Close()

		var out []*RawRow
		for rows.Next() {
			row := &RawRow{}
			if err := s.strategy.ScanRawRow(rows, row, table); err != nil {
				return nil, fmt.Errorf("scan error on %s: %w", table, err)
			}
			out = append(out, row)
		}
		// rows.Next returns false both at the end of the result set and on an
		// iteration error; only rows.Err tells them apart. Never hand back a
		// partially read table alongside an error.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to read table %s: %w", table, err)
		}
		return out, nil
	}

	clusters, err := load(ctx, "clusters")
	if err != nil {
		return nil, err
	}
	listeners, err := load(ctx, "listeners")
	if err != nil {
		return nil, err
	}
	secrets, err := load(ctx, "secrets")
	if err != nil {
		return nil, err
	}

	certs, err := s.LoadAllCertificates(ctx)
	// errors.Is also matches when LoadAllCertificates wraps sql.ErrNoRows.
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	dump := &DBDump{
		Clusters:     clusters,
		Listeners:    listeners,
		Secrets:      secrets,
		Certificates: certs,
		DBDriver:     s.strategy.DriverName(),
	}

	data, err := json.MarshalIndent(dump, "", "  ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal DB dump: %w", err)
	}
	return data, nil
}

// -----------------------------------------------------------------------------
// RESTORE METHODS (REFACTORED)
// -----------------------------------------------------------------------------

// Restore loads a JSON document produced by Dump back into the database.
//
// The dump's DBDriver must equal s.strategy.DriverName(). Otherwise an error
// is returned before any table is touched.
//
// mode selects how existing data is treated:
//   - RestoreMerge: dump rows are upserted on top of the existing data.
//   - RestoreOverride: the clusters, listeners, secrets and certificates
//     tables are cleared first, then the dump rows are upserted.
//
// Raw rows that are nil or have an empty Name are skipped, as are
// certificates that are nil or have an empty Domain. Clusters, listeners and
// secrets are written with the SQL from s.strategy.RestoreRawRowSQL;
// certificates go through SaveCertificate.
//
// NOTE(review): clearing and inserting are separate statements, not one
// transaction. In RestoreOverride mode, a failure part-way through leaves
// the tables cleared but only partially repopulated. Consider running the
// whole restore in a transaction once the strategy and SaveCertificate can
// operate on one.
func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
	var dump DBDump
	if err := json.Unmarshal(data, &dump); err != nil {
		return fmt.Errorf("failed to parse dump JSON: %w", err)
	}

	if dump.DBDriver != s.strategy.DriverName() {
		return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.strategy.DriverName())
	}

	if mode == RestoreOverride {
		for _, tbl := range []string{"clusters", "listeners", "secrets", "certificates"} {
			if err := s.clearTable(ctx, tbl); err != nil {
				return fmt.Errorf("failed to clear %s: %w", tbl, err)
			}
		}
	}

	// saveRaw upserts each row of one raw-config table. The strategy's SQL
	// takes the positional arguments (name, data) for clusters/listeners and
	// (name, data, domain) for secrets. For Postgres, the ON CONFLICT branch
	// is expected to reuse the inserted values via EXCLUDED, so no argument
	// is ever repeated here.
	saveRaw := func(ctx context.Context, table string, rows []*RawRow) error {
		// The statement depends only on the table, so build it once per table.
		query := s.strategy.RestoreRawRowSQL(table)
		for _, r := range rows {
			if r == nil || r.Name == "" {
				continue
			}

			args := []interface{}{r.Name, string(r.Data)}
			if table == "secrets" {
				args = append(args, r.Domain)
			}

			if _, err := s.db.ExecContext(ctx, query, args...); err != nil {
				return fmt.Errorf("failed to upsert %s '%s': %w", table, r.Name, err)
			}
		}
		return nil
	}

	if err := saveRaw(ctx, "clusters", dump.Clusters); err != nil {
		return err
	}
	if err := saveRaw(ctx, "listeners", dump.Listeners); err != nil {
		return err
	}
	if err := saveRaw(ctx, "secrets", dump.Secrets); err != nil {
		return err
	}

	// Certificates reuse the regular SaveCertificate path.
	for _, cert := range dump.Certificates {
		if cert == nil || cert.Domain == "" {
			continue
		}
		if err := s.SaveCertificate(ctx, cert); err != nil {
			return fmt.Errorf("failed to upsert certificate '%s': %w", cert.Domain, err)
		}
	}

	return nil
}

// clearTable removes all rows from one of the tables managed by Restore.
//
// table must be "clusters", "listeners", "secrets" or "certificates". Any
// other name is rejected before any SQL is built. The statement itself
// (a clear or truncate) comes from s.strategy.ClearTableSQL, so its exact
// semantics are driver-specific.
func (s *Storage) clearTable(ctx context.Context, table string) error {
	// A switch checks the fixed allowlist without allocating a map per call.
	switch table {
	case "clusters", "listeners", "secrets", "certificates":
	default:
		return fmt.Errorf("invalid table name: %s", table)
	}

	if _, err := s.db.ExecContext(ctx, s.strategy.ClearTableSQL(table)); err != nil {
		return fmt.Errorf("failed to clear table %s: %w", table, err)
	}
	return nil
}