package storage
import (
"context"
"encoding/json"
"fmt"
)
// DBDump holds the complete state of the database for dumping/restoring.
// It is the top-level document that Dump marshals and Restore unmarshals.
type DBDump struct {
// Envoy xDS Resources - Use raw bytes for JSON/JSONB data from the DB
// Each element is the verbatim contents of one row's data column (see
// LoadAllRawData). encoding/json encodes a []byte as a base64 string, so in
// the dump file these fields are arrays of base64 strings, not embedded JSON.
Clusters [][]byte `json:"clusters,omitempty"`
Listeners [][]byte `json:"listeners,omitempty"`
Secrets [][]byte `json:"secrets,omitempty"`
// Certificate Management Resources (Standard Go structs)
Certificates []*CertStorage `json:"certificates,omitempty"`
// Metadata field to capture database type for validation: Restore refuses a
// dump whose DBDriver differs from the driver of the target Storage.
DBDriver string `json:"db_driver"`
}
// DBDumpRestoreMode selects how Restore combines a dump with the rows that
// already exist in the database.
type DBDumpRestoreMode int

// Restore strategies. The numeric values are written out explicitly so it is
// obvious that RestoreMerge is the zero value.
const (
	// RestoreMerge writes the dumped rows on top of the existing data without
	// clearing any table first (UPSERT).
	RestoreMerge DBDumpRestoreMode = 0
	// RestoreOverride empties the tables and then inserts the dumped rows.
	RestoreOverride DBDumpRestoreMode = 1
)
// Dump serializes the contents of the clusters, listeners, secrets and
// certificates tables, plus the current driver name, into one indented JSON
// document suitable for Restore.
//
// Envoy resources are copied verbatim from each row's data column, which
// avoids a protobuf round-trip. Certificates are loaded through
// LoadAllCertificates. The first load or marshal error aborts the dump.
func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
	dump := &DBDump{DBDriver: s.driver}

	// Each raw table feeds exactly one field of the dump, in this order.
	rawTargets := []struct {
		table string
		dst   *[][]byte
	}{
		{"clusters", &dump.Clusters},
		{"listeners", &dump.Listeners},
		{"secrets", &dump.Secrets},
	}
	for _, target := range rawTargets {
		rows, err := s.LoadAllRawData(ctx, target.table)
		if err != nil {
			return nil, fmt.Errorf("failed to load all %s for dump: %w", target.table, err)
		}
		*target.dst = rows
	}

	certs, err := s.LoadAllCertificates(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load all certificates for dump: %w", err)
	}
	dump.Certificates = certs

	encoded, err := json.MarshalIndent(dump, "", " ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal database dump to JSON: %w", err)
	}
	return encoded, nil
}
// Restore loads a JSON document produced by Dump back into the database.
//
// The dump's DBDriver must equal the current driver. mode must be
// RestoreMerge or RestoreOverride:
//   - In RestoreOverride mode the clusters, listeners, secrets and
//     certificates tables are emptied first.
//   - In RestoreMerge mode rows are written on top of the existing data.
//
// Clusters, listeners and secrets are written through SaveRawData, and
// certificates through SaveCertificate.
//
// Every entry is validated before anything is cleared or written. A malformed
// dump is therefore rejected without touching the database.
//
// NOTE(review): the clear and insert steps are not wrapped in a transaction.
// A database error part-way through (e.g. after the override clear) can still
// leave the tables partially restored.
func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
	// Reject unknown modes instead of silently treating them as a merge.
	if mode != RestoreMerge && mode != RestoreOverride {
		return fmt.Errorf("unknown restore mode: %d", mode)
	}

	var dump DBDump
	if err := json.Unmarshal(data, &dump); err != nil {
		return fmt.Errorf("failed to unmarshal database dump from JSON: %w", err)
	}

	// A dump taken with a different driver is refused outright.
	if dump.DBDriver != s.driver {
		return fmt.Errorf("database driver mismatch: dump is for '%s', current is '%s'", dump.DBDriver, s.driver)
	}

	// Validate before the first destructive step. Otherwise a bad entry is
	// only discovered by SaveRawData/SaveCertificate after the override clear
	// has already emptied the tables (or after part of a merge was written).
	if err := validateDBDumpForRestore(&dump); err != nil {
		return fmt.Errorf("invalid database dump: %w", err)
	}

	if mode == RestoreOverride {
		for _, table := range []string{"clusters", "listeners", "secrets", "certificates"} {
			if err := s.clearTable(ctx, table); err != nil {
				return fmt.Errorf("failed to clear %s table for override: %w", table, err)
			}
		}
	}

	if err := s.SaveRawData(ctx, "clusters", dump.Clusters); err != nil {
		return fmt.Errorf("failed to save restored cluster raw data: %w", err)
	}
	if err := s.SaveRawData(ctx, "listeners", dump.Listeners); err != nil {
		return fmt.Errorf("failed to save restored listener raw data: %w", err)
	}
	if err := s.SaveRawData(ctx, "secrets", dump.Secrets); err != nil {
		return fmt.Errorf("failed to save restored secret raw data: %w", err)
	}

	// Certificates are standard Go structs; validation above guarantees none
	// is nil, so cert.Domain is safe to read in the error path.
	for _, cert := range dump.Certificates {
		if err := s.SaveCertificate(ctx, cert); err != nil {
			return fmt.Errorf("failed to save restored certificate %s: %w", cert.Domain, err)
		}
	}
	return nil
}

// validateDBDumpForRestore reports the first entry in dump that Restore would
// otherwise fail on mid-way. There are two such cases:
//   - a cluster, listener or secret document that json.Unmarshal cannot
//     decode into a {"name": string} object, which is the same decode
//     SaveRawData performs to find the upsert key;
//   - a null certificate.
func validateDBDumpForRestore(dump *DBDump) error {
	sections := []struct {
		table   string
		entries [][]byte
	}{
		{"clusters", dump.Clusters},
		{"listeners", dump.Listeners},
		{"secrets", dump.Secrets},
	}
	for _, section := range sections {
		for i, raw := range section.entries {
			var probe struct {
				Name string `json:"name"`
			}
			if err := json.Unmarshal(raw, &probe); err != nil {
				return fmt.Errorf("%s entry %d cannot be decoded: %w", section.table, i, err)
			}
		}
	}
	for i, cert := range dump.Certificates {
		if cert == nil {
			return fmt.Errorf("certificates entry %d is null", i)
		}
	}
	return nil
}
// LoadAllRawData returns the contents of the data column for every row of
// table, ordered by name.
//
// Only "clusters", "listeners" and "secrets" are accepted. The table name is
// interpolated into the SQL text, so this allow-list is what keeps the query
// safe from injection.
//
// Every row is returned regardless of its enabled flag, and the flag itself
// is not part of the result.
func (s *Storage) LoadAllRawData(ctx context.Context, table string) ([][]byte, error) {
	switch table {
	case "clusters", "listeners", "secrets":
	default:
		return nil, fmt.Errorf("invalid table name: %s", table)
	}

	// ORDER BY name makes Dump output deterministic, so two dumps of the same
	// data are byte-identical and diff cleanly.
	query := fmt.Sprintf(`SELECT data FROM %s ORDER BY name`, table)
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to query raw data from %s: %w", table, err)
	}
	defer rows.Close()

	var data [][]byte
	for rows.Next() {
		var rawData []byte
		// Postgres JSONB is scanned directly into []byte; for SQLite the TEXT
		// value is scanned as a string and then converted.
		if s.driver == "postgres" {
			if err := rows.Scan(&rawData); err != nil {
				return nil, fmt.Errorf("failed to scan raw postgres data from %s: %w", table, err)
			}
		} else {
			var dataStr string
			if err := rows.Scan(&dataStr); err != nil {
				return nil, fmt.Errorf("failed to scan raw sqlite data from %s: %w", table, err)
			}
			rawData = []byte(dataStr)
		}
		data = append(data, rawData)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate raw data from %s: %w", table, err)
	}
	return data, nil
}
// SaveRawData upserts raw JSON documents into table, one row per element of
// rawData.
//
// table must be "clusters", "listeners" or "secrets". It is interpolated into
// the SQL text, so the allow-list is what prevents injection.
//
// Each document's top-level "name" field is decoded and used as the conflict
// key; a document without one is stored under the empty name. Every written
// row is marked enabled and its updated_at is set to the current time.
//
// Rows are written one at a time without a transaction, so an error can leave
// earlier rows already written.
func (s *Storage) SaveRawData(ctx context.Context, table string, rawData [][]byte) error {
	switch table {
	case "clusters", "listeners", "secrets":
	default:
		return fmt.Errorf("invalid table name for raw data save: %s", table)
	}

	// The statement depends only on the driver and the table, so build it once
	// instead of on every iteration. It uses the same UPSERT logic as the
	// original Save* methods.
	var query string
	switch s.driver {
	case "postgres":
		query = fmt.Sprintf(`
INSERT INTO %s (name, data, enabled, updated_at)
VALUES ($1, $2, true, now())
ON CONFLICT (name) DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table)
	default: // SQLite
		query = fmt.Sprintf(`
INSERT INTO %s (name, data, enabled, updated_at)
VALUES (?, ?, 1, CURRENT_TIMESTAMP)
ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
	}

	for _, data := range rawData {
		// Only the name is needed (as the upsert key), so decode just that field.
		var nameExtractor struct {
			Name string `json:"name"`
		}
		if err := json.Unmarshal(data, &nameExtractor); err != nil {
			return fmt.Errorf("failed to extract name from raw data for table %s: %w", table, err)
		}
		if _, err := s.db.ExecContext(ctx, query, nameExtractor.Name, string(data)); err != nil {
			return fmt.Errorf("failed to upsert raw data for %s (name %q): %w", table, nameExtractor.Name, err)
		}
	}
	return nil
}
// clearTable removes every row from one of the tables managed by Dump/Restore.
//
// Only "clusters", "listeners", "secrets" and "certificates" are accepted,
// because the name is spliced directly into the SQL statement. With the
// postgres driver the table is truncated with RESTART IDENTITY; with any other
// driver a plain DELETE is issued.
func (s *Storage) clearTable(ctx context.Context, table string) error {
	switch table {
	case "clusters", "listeners", "secrets", "certificates":
		// allowed
	default:
		return fmt.Errorf("invalid table name for clearing: %s", table)
	}

	stmt := "DELETE FROM " + table
	if s.driver == "postgres" {
		stmt = "TRUNCATE TABLE " + table + " RESTART IDENTITY"
	}

	if _, err := s.db.ExecContext(ctx, stmt); err != nil {
		return fmt.Errorf("error clearing table %s: %w", table, err)
	}
	return nil
}