package storage
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
)
// DBDump is the serialised form of the database exchanged by Dump and Restore.
//
// encoding/json emits struct fields in declaration order, so the field order
// below is also the key order of the document Dump produces.
type DBDump struct {
	// Clusters holds the raw rows of the clusters table.
	Clusters     []*RawRow      `json:"clusters,omitempty"`
	// Listeners holds the raw rows of the listeners table.
	Listeners    []*RawRow      `json:"listeners,omitempty"`
	// Secrets holds the raw rows of the secrets table.
	Secrets      []*RawRow      `json:"secrets,omitempty"`
	// Certificates holds every stored certificate.
	Certificates []*CertStorage `json:"certificates,omitempty"`
	// DBDriver names the driver the dump was taken with; Restore refuses a
	// dump whose driver differs from the current one.
	DBDriver     string         `json:"db_driver"`
}
// RawRow is one row of an xDS table (clusters, listeners or secrets): the
// row's name together with its JSON payload, carried verbatim.
type RawRow struct {
	// Name is the value of the row's name column.
	Name string `json:"name"`
	// Data is the row's data column, embedded as raw JSON rather than as a
	// quoted string.
	Data json.RawMessage `json:"data"`
}
// DBDumpRestoreMode selects how Restore combines a dump with the data already
// present in the database.
type DBDumpRestoreMode int

// Restore modes. The numeric values are spelled out explicitly; the zero
// value is RestoreMerge.
const (
	// RestoreMerge upserts every row of the dump by name and keeps rows the
	// dump does not mention.
	RestoreMerge DBDumpRestoreMode = 0
	// RestoreOverride empties the tables first and then inserts the dump.
	RestoreOverride DBDumpRestoreMode = 1
)
// Dump exports all database content as indented JSON.
//
// The clusters, listeners and secrets tables are read as raw (name, data)
// rows so the stored JSON is copied verbatim; certificates come from
// LoadAllCertificates. The returned document is a marshalled DBDump whose
// DBDriver is the current driver. A sql.ErrNoRows from LoadAllCertificates,
// including a wrapped one, is treated as "no certificates".
func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
	// load reads every (name, data) row of one xDS table.
	load := func(ctx context.Context, table string) ([]*RawRow, error) {
		// table is interpolated into the SQL text, so only the known xDS
		// tables are accepted.
		if table != "clusters" && table != "listeners" && table != "secrets" {
			return nil, fmt.Errorf("invalid table for dump: %s", table)
		}
		query := fmt.Sprintf(`SELECT name, data FROM %s`, table)
		rows, err := s.db.QueryContext(ctx, query)
		if err != nil {
			return nil, fmt.Errorf("failed to query table %s: %w", table, err)
		}
		defer rows.Close()

		isPostgres := s.driver == "postgres"
		var out []*RawRow
		for rows.Next() {
			row := &RawRow{}
			if isPostgres {
				// NOTE(review): scanning straight into json.RawMessage assumes
				// the postgres driver yields the data column as bytes — confirm.
				if err := rows.Scan(&row.Name, &row.Data); err != nil {
					return nil, fmt.Errorf("scan error on %s: %w", table, err)
				}
			} else {
				var dataStr string
				if err := rows.Scan(&row.Name, &dataStr); err != nil {
					return nil, fmt.Errorf("scan error on %s: %w", table, err)
				}
				row.Data = json.RawMessage(dataStr)
			}
			out = append(out, row)
		}
		// Report iteration failures without handing back a partial row set.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to read table %s: %w", table, err)
		}
		return out, nil
	}

	clusters, err := load(ctx, "clusters")
	if err != nil {
		return nil, err
	}
	listeners, err := load(ctx, "listeners")
	if err != nil {
		return nil, err
	}
	secrets, err := load(ctx, "secrets")
	if err != nil {
		return nil, err
	}

	certs, err := s.LoadAllCertificates(ctx)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Nothing stored: discard whatever came back alongside the error so
		// the certificates field is simply omitted from the dump.
		certs = nil
	case err != nil:
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	dump := &DBDump{
		Clusters:     clusters,
		Listeners:    listeners,
		Secrets:      secrets,
		Certificates: certs,
		DBDriver:     s.driver,
	}
	data, err := json.MarshalIndent(dump, "", " ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal DB dump: %w", err)
	}
	return data, nil
}
// Restore imports database content from a JSON document produced by Dump.
//
// The dump's db_driver must equal the current driver. In RestoreMerge mode
// each clusters/listeners/secrets row is upserted by name (and re-enabled)
// while rows absent from the dump are kept. In RestoreOverride mode the
// clusters, listeners, secrets and certificates tables are emptied first.
// Any other mode value is rejected before the database is touched.
//
// Clearing and the clusters/listeners/secrets upserts run in a single
// transaction, so a failure part-way through rolls those tables back instead
// of leaving them half-wiped. Certificates are written afterwards through
// SaveCertificate, which does not take part in that transaction. A failure
// there can therefore leave the certificates table partially restored.
func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
	if mode != RestoreMerge && mode != RestoreOverride {
		return fmt.Errorf("unknown restore mode: %d", mode)
	}

	var dump DBDump
	if err := json.Unmarshal(data, &dump); err != nil {
		return fmt.Errorf("failed to parse dump JSON: %w", err)
	}
	if dump.DBDriver != s.driver {
		return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.driver)
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin restore transaction: %w", err)
	}
	// Undo everything written through tx on any early return. After a
	// successful Commit this is a no-op (it returns sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	isPostgres := s.driver == "postgres"

	if mode == RestoreOverride {
		for _, tbl := range []string{"clusters", "listeners", "secrets", "certificates"} {
			// tbl comes from the fixed list above, so interpolating it is safe.
			query := "DELETE FROM " + tbl
			if isPostgres {
				query = "TRUNCATE TABLE " + tbl + " RESTART IDENTITY"
			}
			if _, err := tx.ExecContext(ctx, query); err != nil {
				return fmt.Errorf("failed to clear %s: %w", tbl, err)
			}
		}
	}

	// upsert writes rows into table inside tx, skipping nil entries and
	// entries without a name. The statement depends only on the table, so it
	// is built once per call rather than once per row.
	upsert := func(table string, rows []*RawRow) error {
		var query string
		if isPostgres {
			query = fmt.Sprintf(`
INSERT INTO %s (name, data, enabled, updated_at)
VALUES ($1, $2, true, now())
ON CONFLICT (name)
DO UPDATE SET data = EXCLUDED.data, enabled = true, updated_at = now()`, table)
		} else {
			query = fmt.Sprintf(`
INSERT INTO %s (name, data, enabled, updated_at)
VALUES (?, ?, 1, CURRENT_TIMESTAMP)
ON CONFLICT(name)
DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
		}
		for _, r := range rows {
			if r == nil || r.Name == "" {
				continue
			}
			if _, err := tx.ExecContext(ctx, query, r.Name, string(r.Data)); err != nil {
				return fmt.Errorf("failed to upsert %s '%s': %w", table, r.Name, err)
			}
		}
		return nil
	}

	if err := upsert("clusters", dump.Clusters); err != nil {
		return err
	}
	if err := upsert("listeners", dump.Listeners); err != nil {
		return err
	}
	if err := upsert("secrets", dump.Secrets); err != nil {
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit restore: %w", err)
	}

	for _, cert := range dump.Certificates {
		if cert == nil || cert.Domain == "" {
			continue
		}
		if err := s.SaveCertificate(ctx, cert); err != nil {
			return fmt.Errorf("failed to upsert certificate '%s': %w", cert.Domain, err)
		}
	}
	return nil
}
// clearTable removes every row from one of the known xDS/certificate tables.
//
// table is interpolated into the SQL text, so it is checked against a fixed
// allow-list first. Any other name returns an error without touching the
// database. On postgres the table is TRUNCATEd with RESTART IDENTITY, which
// also resets sequences owned by its columns. Other drivers use a plain DELETE.
func (s *Storage) clearTable(ctx context.Context, table string) error {
	// A switch avoids allocating a lookup map on every call for four constants.
	switch table {
	case "clusters", "listeners", "secrets", "certificates":
	default:
		return fmt.Errorf("invalid table name: %s", table)
	}

	query := "DELETE FROM " + table
	if s.driver == "postgres" {
		query = "TRUNCATE TABLE " + table + " RESTART IDENTITY"
	}
	if _, err := s.db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("failed to clear table %s: %w", table, err)
	}
	return nil
}