Newer
Older
EnvoyControlPlane / internal / pkg / storage / storage_dump.go
package storage

import (
	"context"
	"encoding/json"
	"fmt"
)

// DBDump holds the complete state of the database for dumping/restoring.
//
// Dump fills it from the database and marshals it as indented JSON; Restore
// unmarshals it and writes the contents back.
type DBDump struct {
	// Envoy xDS Resources - Use raw bytes for JSON/JSONB data from the DB.
	// Each element is the data column of one row, as returned by LoadAllRawData.
	// NOTE: encoding/json encodes a []byte as a base64 string, so in the dump
	// file these entries appear as base64 text rather than embedded JSON objects.
	Clusters  [][]byte `json:"clusters,omitempty"`
	Listeners [][]byte `json:"listeners,omitempty"`
	Secrets   [][]byte `json:"secrets,omitempty"`
	// Certificate Management Resources (Standard Go structs).
	Certificates []*CertStorage `json:"certificates,omitempty"`
	// Metadata field to capture database type for validation: Restore rejects a
	// dump whose driver differs from the current storage driver. This field has
	// no omitempty, so it is always present in the JSON.
	DBDriver string `json:"db_driver"`
}

// DBDumpRestoreMode selects how Restore combines a dump with the data that is
// already stored in the database.
type DBDumpRestoreMode int

// Supported restore strategies. The zero value is RestoreMerge.
const (
	// RestoreMerge upserts the dump's rows over the existing data; rows that
	// are not present in the dump are left in place.
	RestoreMerge DBDumpRestoreMode = 0
	// RestoreOverride empties the tables first and then inserts the dump's rows.
	RestoreOverride DBDumpRestoreMode = 1
)

// Dump serializes the whole database into one indented JSON document.
//
// Clusters, listeners and secrets are exported as the raw contents of their
// data columns (no protobuf round-trip). Certificates are exported through
// LoadAllCertificates. The current driver name is recorded so that Restore
// can refuse to load the dump into a different kind of database.
func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
	snapshot := &DBDump{DBDriver: s.driver}

	// Each raw table is loaded straight into its destination field.
	rawTables := []struct {
		table string
		dst   *[][]byte
	}{
		{"clusters", &snapshot.Clusters},
		{"listeners", &snapshot.Listeners},
		{"secrets", &snapshot.Secrets},
	}
	for _, rt := range rawTables {
		rows, err := s.LoadAllRawData(ctx, rt.table)
		if err != nil {
			return nil, fmt.Errorf("failed to load all %s for dump: %w", rt.table, err)
		}
		*rt.dst = rows
	}

	certs, err := s.LoadAllCertificates(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load all certificates for dump: %w", err)
	}
	snapshot.Certificates = certs

	encoded, err := json.MarshalIndent(snapshot, "", "  ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal database dump to JSON: %w", err)
	}
	return encoded, nil
}

// Restore imports database content from a JSON dump produced by Dump.
//
// The dump must have been produced with the same database driver as s. In
// RestoreOverride mode the clusters, listeners, secrets and certificates tables
// are emptied before the dump is written. In RestoreMerge mode the dump's rows
// are upserted on top of the existing data. Any other mode is rejected.
//
// The whole dump is validated before anything is cleared or written. A
// malformed dump (an undecodable entry, an entry without a name, or a null
// certificate) is therefore rejected without touching existing data.
//
// NOTE(review): the clear and the writes that follow are not wrapped in a single
// transaction. A database error part-way through can still leave the tables
// partially restored.
func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
	if mode != RestoreMerge && mode != RestoreOverride {
		return fmt.Errorf("unsupported restore mode: %d", mode)
	}

	var dump DBDump

	// 1. Unmarshal top-level structure using standard JSON
	if err := json.Unmarshal(data, &dump); err != nil {
		return fmt.Errorf("failed to unmarshal database dump from JSON: %w", err)
	}

	// 2. Validate metadata: raw rows are only restored into the same kind of database.
	if dump.DBDriver != s.driver {
		return fmt.Errorf("database driver mismatch: dump is for '%s', current is '%s'", dump.DBDriver, s.driver)
	}

	// 3. Validate every entry up front. Otherwise a bad entry would only surface
	// while saving, after override mode had already wiped the tables.
	rawSections := []struct {
		table   string
		entries [][]byte
	}{
		{"clusters", dump.Clusters},
		{"listeners", dump.Listeners},
		{"secrets", dump.Secrets},
	}
	for _, section := range rawSections {
		if err := validateRawDumpEntries(section.table, section.entries); err != nil {
			return err
		}
	}
	for i, cert := range dump.Certificates {
		// A JSON null in the certificates array decodes to a nil pointer.
		if cert == nil {
			return fmt.Errorf("invalid certificates entry %d in dump: null certificate", i)
		}
	}

	// --- 4. Override Mode: Clear Existing Tables ---
	if mode == RestoreOverride {
		for _, table := range []string{"clusters", "listeners", "secrets", "certificates"} {
			if err := s.clearTable(ctx, table); err != nil {
				return fmt.Errorf("failed to clear %s table for override: %w", table, err)
			}
		}
	}

	// --- 5. Insert/Upsert Data using Raw JSON ---

	// Clusters
	if err := s.SaveRawData(ctx, "clusters", dump.Clusters); err != nil {
		return fmt.Errorf("failed to save restored cluster raw data: %w", err)
	}

	// Listeners
	if err := s.SaveRawData(ctx, "listeners", dump.Listeners); err != nil {
		return fmt.Errorf("failed to save restored listener raw data: %w", err)
	}

	// Secrets
	if err := s.SaveRawData(ctx, "secrets", dump.Secrets); err != nil {
		return fmt.Errorf("failed to save restored secret raw data: %w", err)
	}

	// Certificates (Standard Go struct - SaveCertificate is defined in storage.go)
	for _, cert := range dump.Certificates {
		if err := s.SaveCertificate(ctx, cert); err != nil {
			return fmt.Errorf("failed to save restored certificate %s: %w", cert.Domain, err)
		}
	}

	return nil
}

// validateRawDumpEntries checks that every raw entry destined for table decodes
// into an object with a non-empty string "name" field. That name is the
// conflict key the raw-data upsert relies on.
func validateRawDumpEntries(table string, entries [][]byte) error {
	for i, entry := range entries {
		var probe struct {
			Name string `json:"name"`
		}
		if err := json.Unmarshal(entry, &probe); err != nil {
			return fmt.Errorf("invalid %s entry %d in dump: %w", table, i, err)
		}
		if probe.Name == "" {
			return fmt.Errorf("invalid %s entry %d in dump: missing name", table, i)
		}
	}
	return nil
}

// LoadAllRawData returns the data column of every row in table, unmodified.
//
// Only "clusters", "listeners" and "secrets" are accepted. The name is checked
// against that list because it is spliced directly into the SQL text. Rows are
// returned in whatever order the database yields them (the query has no
// ORDER BY), and the result is nil when the table has no rows.
func (s *Storage) LoadAllRawData(ctx context.Context, table string) ([][]byte, error) {
	switch table {
	case "clusters", "listeners", "secrets":
		// allowed
	default:
		return nil, fmt.Errorf("invalid table name: %s", table)
	}

	rows, err := s.db.QueryContext(ctx, fmt.Sprintf(`SELECT data FROM %s`, table))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out [][]byte
	for rows.Next() {
		switch s.driver {
		case "postgres":
			// JSONB values scan directly into a byte slice.
			var raw []byte
			if err := rows.Scan(&raw); err != nil {
				return nil, fmt.Errorf("failed to scan raw postgres data from %s: %w", table, err)
			}
			out = append(out, raw)
		default:
			// Other drivers (SQLite) hold the document as TEXT; scan as a string.
			var text string
			if err := rows.Scan(&text); err != nil {
				return nil, fmt.Errorf("failed to scan raw sqlite data from %s: %w", table, err)
			}
			out = append(out, []byte(text))
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// SaveRawData upserts raw JSON documents into the clusters, listeners or
// secrets table.
//
// Each document is stored verbatim in the data column, keyed by its top-level
// "name" field, which must be a non-empty string. A document without a name is
// rejected rather than stored under the empty name, where nameless documents
// would silently overwrite one another. Upserted rows are marked enabled and
// their updated_at is refreshed.
//
// Documents are written one statement at a time without a transaction. An
// error part-way through therefore leaves the earlier documents saved.
func (s *Storage) SaveRawData(ctx context.Context, table string, rawData [][]byte) error {
	// Whitelist the table name: it is interpolated into the SQL text below.
	if table != "clusters" && table != "listeners" && table != "secrets" {
		return fmt.Errorf("invalid table name for raw data save: %s", table)
	}

	// The statement depends only on the driver and the table, so build it once
	// instead of on every iteration. It uses the same UPSERT logic as the Save*
	// methods.
	var query string
	switch s.driver {
	case "postgres":
		query = fmt.Sprintf(`
			INSERT INTO %s (name, data, enabled, updated_at)
			VALUES ($1, $2, true, now())
			ON CONFLICT (name) DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table)
	default: // SQLite
		query = fmt.Sprintf(`
			INSERT INTO %s (name, data, enabled, updated_at)
			VALUES (?, ?, 1, CURRENT_TIMESTAMP)
			ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
	}

	for i, data := range rawData {
		// Only the name is needed (as the conflict key); the document itself is
		// stored unmodified.
		var nameExtractor struct {
			Name string `json:"name"`
		}
		if err := json.Unmarshal(data, &nameExtractor); err != nil {
			return fmt.Errorf("failed to extract name from raw data for table %s: %w", table, err)
		}
		if nameExtractor.Name == "" {
			return fmt.Errorf("raw data entry %d for table %s has no name", i, table)
		}

		if _, err := s.db.ExecContext(ctx, query, nameExtractor.Name, string(data)); err != nil {
			return fmt.Errorf("failed to upsert raw data for %s: %w", table, err)
		}
	}

	return nil
}

// clearTable removes every row from one of the dump-managed tables.
//
// The table name is checked against a fixed whitelist because it is spliced
// into the SQL text. On Postgres the table is truncated with RESTART IDENTITY,
// which also resets sequences owned by its columns. Any other driver falls back
// to a plain DELETE FROM.
func (s *Storage) clearTable(ctx context.Context, table string) error {
	switch table {
	case "clusters", "listeners", "secrets", "certificates":
		// allowed
	default:
		return fmt.Errorf("invalid table name for clearing: %s", table)
	}

	stmt := "DELETE FROM " + table
	if s.driver == "postgres" {
		stmt = "TRUNCATE TABLE " + table + " RESTART IDENTITY"
	}

	if _, err := s.db.ExecContext(ctx, stmt); err != nil {
		return fmt.Errorf("error clearing table %s: %w", table, err)
	}
	return nil
}