// Source: EnvoyControlPlane/internal/pkg/storage/storage_dump.go
// Last change: "improve the storage dump" (@jerxie, 12 Nov, 5 KB).
package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
)

// DBDump holds the complete state of the database for dumping/restoring.
//
// It is the top-level JSON document produced by Dump and parsed by Restore:
//
//   - Clusters, Listeners and Secrets carry the name/data rows of the
//     "clusters", "listeners" and "secrets" tables respectively.
//   - Certificates carries the records returned by LoadAllCertificates.
//   - DBDriver is the driver name of the Storage that produced the dump.
//     Restore rejects a dump whose DBDriver differs from its own driver.
//
// All fields except DBDriver are tagged omitempty, so empty collections are
// left out of the JSON while "db_driver" is always emitted.
type DBDump struct {
	Clusters     []*RawRow      `json:"clusters,omitempty"`
	Listeners    []*RawRow      `json:"listeners,omitempty"`
	Secrets      []*RawRow      `json:"secrets,omitempty"`
	Certificates []*CertStorage `json:"certificates,omitempty"`
	DBDriver     string         `json:"db_driver"`
}

// RawRow represents a single row of xDS data with name + JSON body.
//
// Name is the value of the row's "name" column. Data is the "data" column
// kept as json.RawMessage, so it is embedded verbatim in the dump when
// marshalled and left undecoded when a dump is parsed.
type RawRow struct {
	Name string          `json:"name"`
	Data json.RawMessage `json:"data"`
}

// DBDumpRestoreMode selects how Restore reconciles the contents of a dump
// with the rows that are already stored.
type DBDumpRestoreMode int

// Restore strategies. The numeric values are spelled out explicitly;
// RestoreMerge is the zero value of DBDumpRestoreMode.
const (
	// RestoreMerge upserts the incoming rows on top of the existing data.
	RestoreMerge DBDumpRestoreMode = 0
	// RestoreOverride wipes the existing data first, then inserts the
	// incoming rows.
	RestoreOverride DBDumpRestoreMode = 1
)

// Dump exports all database content as JSON including certificates.
//
// It reads every (name, data) row of the clusters, listeners and secrets
// tables, loads all certificates through LoadAllCertificates, and returns the
// result as an indented JSON encoding of DBDump, tagged with the current
// driver name.
//
// Each table is read with its own query on s.db.
// NOTE(review): no transaction is opened here, so the sections are presumably
// not a single consistent snapshot if writers are active — confirm.
//
// Errors: any query, scan or row-iteration failure is returned wrapped with
// the table name. A certificate load error is returned wrapped unless it is
// (or wraps) sql.ErrNoRows, which is treated as "no certificates".
func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
	// load returns every row of one allow-listed table.
	load := func(ctx context.Context, table string) ([]*RawRow, error) {
		// The table name is interpolated into the SQL text below, so only
		// known identifiers are accepted.
		switch table {
		case "clusters", "listeners", "secrets":
		default:
			return nil, fmt.Errorf("invalid table for dump: %s", table)
		}

		query := fmt.Sprintf(`SELECT name, data FROM %s`, table)
		rows, err := s.db.QueryContext(ctx, query)
		if err != nil {
			return nil, fmt.Errorf("failed to query table %s: %w", table, err)
		}
		defer rows.Close()

		var out []*RawRow
		for rows.Next() {
			row := &RawRow{}
			if s.driver == "postgres" {
				// Scan the column straight into the raw JSON bytes.
				if err := rows.Scan(&row.Name, &row.Data); err != nil {
					return nil, fmt.Errorf("scan error on %s: %w", table, err)
				}
			} else {
				// Other drivers are scanned via a string and converted.
				// NOTE(review): presumably because they hand the column back
				// as text — confirm against the drivers in use.
				var dataStr string
				if err := rows.Scan(&row.Name, &dataStr); err != nil {
					return nil, fmt.Errorf("scan error on %s: %w", table, err)
				}
				row.Data = []byte(dataStr)
			}
			out = append(out, row)
		}
		// Report iteration failures with the same table context as scan
		// failures, and never hand back a partial result next to an error.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("row iteration error on %s: %w", table, err)
		}
		return out, nil
	}

	clusters, err := load(ctx, "clusters")
	if err != nil {
		return nil, err
	}
	listeners, err := load(ctx, "listeners")
	if err != nil {
		return nil, err
	}
	secrets, err := load(ctx, "secrets")
	if err != nil {
		return nil, err
	}

	certs, err := s.LoadAllCertificates(ctx)
	// errors.Is also matches a wrapped sql.ErrNoRows, which a plain !=
	// comparison would misreport as a hard failure.
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	dump := &DBDump{
		Clusters:     clusters,
		Listeners:    listeners,
		Secrets:      secrets,
		Certificates: certs,
		DBDriver:     s.driver,
	}

	data, err := json.MarshalIndent(dump, "", "  ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal DB dump: %w", err)
	}
	return data, nil
}

// Restore imports database content from a JSON dump (merge or override).
//
// data must be a JSON-encoded DBDump whose DBDriver equals the current driver;
// otherwise an error is returned before anything is written.
//
// With mode == RestoreOverride the clusters, listeners, secrets and
// certificates tables are cleared first. Any other mode value leaves existing
// rows in place. In both cases every dumped row is then upserted by name with
// enabled set to true and updated_at set to the current time, and every
// certificate is stored through SaveCertificate.
//
// Rows that are nil or have an empty Name, and certificates that are nil or
// have an empty Domain, are skipped silently.
//
// NOTE(review): all statements are issued directly on s.db and no transaction
// is opened in this function, so a failure part-way through (in particular
// after the tables were cleared in override mode) presumably leaves the
// earlier changes in place — confirm whether callers rely on atomicity.
func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
	var dump DBDump
	if err := json.Unmarshal(data, &dump); err != nil {
		return fmt.Errorf("failed to parse dump JSON: %w", err)
	}

	if dump.DBDriver != s.driver {
		return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.driver)
	}

	if mode == RestoreOverride {
		for _, tbl := range []string{"clusters", "listeners", "secrets", "certificates"} {
			if err := s.clearTable(ctx, tbl); err != nil {
				return fmt.Errorf("failed to clear %s: %w", tbl, err)
			}
		}
	}

	// saveRaw upserts rows into table. The table name is interpolated into
	// the SQL text; it is only ever called with the fixed names listed below.
	saveRaw := func(ctx context.Context, table string, rows []*RawRow) error {
		if len(rows) == 0 {
			return nil
		}

		// The statement depends only on driver and table, so build it once
		// per table instead of once per row.
		var query string
		switch s.driver {
		case "postgres":
			query = fmt.Sprintf(`
					INSERT INTO %s (name, data, enabled, updated_at)
					VALUES ($1, $2, true, now())
					ON CONFLICT (name)
					DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table)
		default:
			query = fmt.Sprintf(`
					INSERT INTO %s (name, data, enabled, updated_at)
					VALUES (?, ?, 1, CURRENT_TIMESTAMP)
					ON CONFLICT(name)
					DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
		}

		for _, r := range rows {
			if r == nil || r.Name == "" {
				continue
			}
			if _, err := s.db.ExecContext(ctx, query, r.Name, string(r.Data)); err != nil {
				return fmt.Errorf("failed to upsert %s '%s': %w", table, r.Name, err)
			}
		}
		return nil
	}

	// Restore the raw xDS tables in a fixed order: clusters, listeners, secrets.
	sections := []struct {
		table string
		rows  []*RawRow
	}{
		{"clusters", dump.Clusters},
		{"listeners", dump.Listeners},
		{"secrets", dump.Secrets},
	}
	for _, sec := range sections {
		if err := saveRaw(ctx, sec.table, sec.rows); err != nil {
			return err
		}
	}

	for _, cert := range dump.Certificates {
		if cert == nil || cert.Domain == "" {
			continue
		}
		if err := s.SaveCertificate(ctx, cert); err != nil {
			return fmt.Errorf("failed to upsert certificate '%s': %w", cert.Domain, err)
		}
	}

	return nil
}

// clearTable removes every row from one of the dump-managed tables.
//
// The table name ends up inside the SQL text, so it is checked against a
// fixed allow-list (clusters, listeners, secrets, certificates) and anything
// else is rejected with an error before a statement is built. On the
// "postgres" driver the table is truncated with RESTART IDENTITY; every other
// driver gets a plain DELETE.
func (s *Storage) clearTable(ctx context.Context, table string) error {
	switch table {
	case "clusters", "listeners", "secrets", "certificates":
		// Known table, safe to interpolate.
	default:
		return fmt.Errorf("invalid table name: %s", table)
	}

	var stmt string
	if s.driver == "postgres" {
		stmt = "TRUNCATE TABLE " + table + " RESTART IDENTITY"
	} else {
		stmt = "DELETE FROM " + table
	}

	if _, err := s.db.ExecContext(ctx, stmt); err != nil {
		return fmt.Errorf("failed to clear table %s: %w", table, err)
	}
	return nil
}