diff --git a/data/config.db b/data/config.db
index 8bbe6e6..38afb96 100644
--- a/data/config.db
+++ b/data/config.db
Binary files differ
diff --git a/internal/pkg/storage/storage_dump.go b/internal/pkg/storage/storage_dump.go
index 8a04a6d..7d75294 100644
--- a/internal/pkg/storage/storage_dump.go
+++ b/internal/pkg/storage/storage_dump.go
@@ -2,20 +2,24 @@
 
 import (
 	"context"
+	"database/sql"
 	"encoding/json"
 	"fmt"
 )
 
 // DBDump holds the complete state of the database for dumping/restoring.
 type DBDump struct {
-	// Envoy xDS Resources - Use raw bytes for JSON/JSONB data from the DB
-	Clusters  [][]byte `json:"clusters,omitempty"`
-	Listeners [][]byte `json:"listeners,omitempty"`
-	Secrets   [][]byte `json:"secrets,omitempty"`
-	// Certificate Management Resources (Standard Go structs)
+	Clusters     []*RawRow      `json:"clusters,omitempty"`
+	Listeners    []*RawRow      `json:"listeners,omitempty"`
+	Secrets      []*RawRow      `json:"secrets,omitempty"`
 	Certificates []*CertStorage `json:"certificates,omitempty"`
-	// Metadata field to capture database type for validation
-	DBDriver string `json:"db_driver"`
+	DBDriver     string         `json:"db_driver"`
+}
+
+// RawRow represents a single row of xDS data with name + JSON body.
+type RawRow struct {
+	Name string          `json:"name"`
+	Data json.RawMessage `json:"data"`
 }
 
 // DBDumpRestoreMode defines the strategy for restoring the data.
@@ -28,199 +32,162 @@
 	RestoreOverride
 )
 
-// Dump exports all database content as raw JSON data into a single JSON byte slice.
+// Dump exports all database content as JSON including certificates.
 func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
-	// 1. Load raw JSON data for Envoy resources (avoids PB conversion)
-	rawClusters, err := s.LoadAllRawData(ctx, "clusters")
-	if err != nil {
-		return nil, fmt.Errorf("failed to load all clusters for dump: %w", err)
-	}
-	rawListeners, err := s.LoadAllRawData(ctx, "listeners")
-	if err != nil {
-		return nil, fmt.Errorf("failed to load all listeners for dump: %w", err)
-	}
-	rawSecrets, err := s.LoadAllRawData(ctx, "secrets")
-	if err != nil {
-		return nil, fmt.Errorf("failed to load all secrets for dump: %w", err)
+	load := func(ctx context.Context, table string) ([]*RawRow, error) {
+		if table != "clusters" && table != "listeners" && table != "secrets" {
+			return nil, fmt.Errorf("invalid table for dump: %s", table)
+		}
+
+		query := fmt.Sprintf(`SELECT name, data FROM %s`, table)
+		rows, err := s.db.QueryContext(ctx, query)
+		if err != nil {
+			return nil, fmt.Errorf("failed to query table %s: %w", table, err)
+		}
+		defer rows.Close()
+
+		var out []*RawRow
+		for rows.Next() {
+			row := &RawRow{}
+			if s.driver == "postgres" {
+				if err := rows.Scan(&row.Name, &row.Data); err != nil {
+					return nil, fmt.Errorf("scan error on %s: %w", table, err)
+				}
+			} else {
+				var dataStr string
+				if err := rows.Scan(&row.Name, &dataStr); err != nil {
+					return nil, fmt.Errorf("scan error on %s: %w", table, err)
+				}
+				row.Data = []byte(dataStr)
+			}
+			out = append(out, row)
+		}
+		return out, rows.Err()
 	}
 
-	// 2. Load Certificates (already standard Go structs)
-	// Note: LoadAllCertificates is defined in storage.go
-	protoCertificates, err := s.LoadAllCertificates(ctx)
+	clusters, err := load(ctx, "clusters")
 	if err != nil {
-		return nil, fmt.Errorf("failed to load all certificates for dump: %w", err)
+		return nil, err
+	}
+	listeners, err := load(ctx, "listeners")
+	if err != nil {
+		return nil, err
+	}
+	secrets, err := load(ctx, "secrets")
+	if err != nil {
+		return nil, err
 	}
 
-	// 3. Assemble the single top-level JSON message
+	certs, err := s.LoadAllCertificates(ctx)
+	if err != nil && err != sql.ErrNoRows {
+		return nil, fmt.Errorf("failed to load certificates: %w", err)
+	}
+
 	dump := &DBDump{
-		Clusters:     rawClusters,
-		Listeners:    rawListeners,
-		Secrets:      rawSecrets,
-		Certificates: protoCertificates,
-		DBDriver:     s.driver, // Add metadata
+		Clusters:     clusters,
+		Listeners:    listeners,
+		Secrets:      secrets,
+		Certificates: certs,
+		DBDriver:     s.driver,
 	}
 
-	// 4. Marshal the entire structure directly to JSON
 	data, err := json.MarshalIndent(dump, "", " ")
 	if err != nil {
-		return nil, fmt.Errorf("failed to marshal database dump to JSON: %w", err)
+		return nil, fmt.Errorf("failed to marshal DB dump: %w", err)
 	}
-
 	return data, nil
 }
 
-// Restore imports database content from a JSON byte slice, inserting raw data.
+// Restore imports database content from a JSON dump (merge or override).
 func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
 	var dump DBDump
-
-	// 1. Unmarshal top-level structure using standard JSON
 	if err := json.Unmarshal(data, &dump); err != nil {
-		return fmt.Errorf("failed to unmarshal database dump from JSON: %w", err)
+		return fmt.Errorf("failed to parse dump JSON: %w", err)
 	}
 
-	// 2. Validate Metadata
 	if dump.DBDriver != s.driver {
-		return fmt.Errorf("database driver mismatch: dump is for '%s', current is '%s'", dump.DBDriver, s.driver)
+		return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.driver)
 	}
 
-	// --- 3. Override Mode: Clear Existing Tables ---
 	if mode == RestoreOverride {
-		if err := s.clearTable(ctx, "clusters"); err != nil {
-			return fmt.Errorf("failed to clear clusters table for override: %w", err)
-		}
-		if err := s.clearTable(ctx, "listeners"); err != nil {
-			return fmt.Errorf("failed to clear listeners table for override: %w", err)
-		}
-		if err := s.clearTable(ctx, "secrets"); err != nil {
-			return fmt.Errorf("failed to clear secrets table for override: %w", err)
-		}
-		if err := s.clearTable(ctx, "certificates"); err != nil {
-			return fmt.Errorf("failed to clear certificates table for override: %w", err)
+		for _, tbl := range []string{"clusters", "listeners", "secrets", "certificates"} {
+			if err := s.clearTable(ctx, tbl); err != nil {
+				return fmt.Errorf("failed to clear %s: %w", tbl, err)
+			}
 		}
 	}
 
-	// --- 4. Insert/Upsert Data using Raw JSON ---
+	saveRaw := func(ctx context.Context, table string, rows []*RawRow) error {
+		for _, r := range rows {
+			if r == nil || r.Name == "" {
+				continue
+			}
 
-	// Clusters
-	if err := s.SaveRawData(ctx, "clusters", dump.Clusters); err != nil {
-		return fmt.Errorf("failed to save restored cluster raw data: %w", err)
+			var query string
+			switch s.driver {
+			case "postgres":
+				query = fmt.Sprintf(`
+					INSERT INTO %s (name, data, enabled, updated_at)
+					VALUES ($1, $2, true, now())
+					ON CONFLICT (name)
+					DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table)
+			default:
+				query = fmt.Sprintf(`
+					INSERT INTO %s (name, data, enabled, updated_at)
+					VALUES (?, ?, 1, CURRENT_TIMESTAMP)
+					ON CONFLICT(name)
+					DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
+			}
+
+			if _, err := s.db.ExecContext(ctx, query, r.Name, string(r.Data)); err != nil {
+				return fmt.Errorf("failed to upsert %s '%s': %w", table, r.Name, err)
+			}
+		}
+		return nil
 	}
 
-	// Listeners
-	if err := s.SaveRawData(ctx, "listeners", dump.Listeners); err != nil {
-		return fmt.Errorf("failed to save restored listener raw data: %w", err)
+	if err := saveRaw(ctx, "clusters", dump.Clusters); err != nil {
+		return err
+	}
+	if err := saveRaw(ctx, "listeners", dump.Listeners); err != nil {
+		return err
+	}
+	if err := saveRaw(ctx, "secrets", dump.Secrets); err != nil {
+		return err
 	}
 
-	// Secrets
-	if err := s.SaveRawData(ctx, "secrets", dump.Secrets); err != nil {
-		return fmt.Errorf("failed to save restored secret raw data: %w", err)
-	}
-
-	// Certificates (Standard Go struct - SaveCertificate is defined in storage.go)
 	for _, cert := range dump.Certificates {
+		if cert == nil || cert.Domain == "" {
+			continue
+		}
 		if err := s.SaveCertificate(ctx, cert); err != nil {
-			return fmt.Errorf("failed to save restored certificate %s: %w", cert.Domain, err)
+			return fmt.Errorf("failed to upsert certificate '%s': %w", cert.Domain, err)
 		}
 	}
 
 	return nil
 }
 
-// LoadAllRawData retrieves the raw data column ([]byte/JSONB) from a table.
-func (s *Storage) LoadAllRawData(ctx context.Context, table string) ([][]byte, error) {
-	if table != "clusters" && table != "listeners" && table != "secrets" {
-		return nil, fmt.Errorf("invalid table name: %s", table)
-	}
-
-	query := fmt.Sprintf(`SELECT data FROM %s`, table)
-
-	rows, err := s.db.QueryContext(ctx, query)
-	if err != nil {
-		return nil, err
-	}
-	defer rows.Close()
-
-	var data [][]byte
-	for rows.Next() {
-		var rawData []byte
-		// Handle Postgres JSONB (scans to []byte) vs SQLite TEXT (scans to string, then convert)
-		if s.driver == "postgres" {
-			if err := rows.Scan(&rawData); err != nil {
-				return nil, fmt.Errorf("failed to scan raw postgres data from %s: %w", table, err)
-			}
-		} else { // SQLite
-			var dataStr string
-			if err := rows.Scan(&dataStr); err != nil {
-				return nil, fmt.Errorf("failed to scan raw sqlite data from %s: %w", table, err)
-			}
-			rawData = []byte(dataStr)
-		}
-		data = append(data, rawData)
-	}
-	if err := rows.Err(); err != nil {
-		return nil, err
-	}
-
-	return data, nil
-}
-
-// SaveRawData handles UPSERT for raw JSON data for clusters, listeners, and secrets.
-func (s *Storage) SaveRawData(ctx context.Context, table string, rawData [][]byte) error {
-	if table != "clusters" && table != "listeners" && table != "secrets" {
-		return fmt.Errorf("invalid table name for raw data save: %s", table)
-	}
-
-	for _, data := range rawData {
-		// To get the name (for UPSERT), we must minimally unmarshal the name field.
-		var nameExtractor struct {
-			Name string `json:"name"`
-		}
-		if err := json.Unmarshal(data, &nameExtractor); err != nil {
-			return fmt.Errorf("failed to extract name from raw data for table %s: %w", table, err)
-		}
-
-		// Use the same UPSERT logic as original Save* methods
-		var query string
-		switch s.driver {
-		case "postgres":
-			query = fmt.Sprintf(`
-				INSERT INTO %s (name, data, enabled, updated_at)
-				VALUES ($1, $2, true, now())
-				ON CONFLICT (name) DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table)
-		default: // SQLite
-			query = fmt.Sprintf(`
-				INSERT INTO %s (name, data, enabled, updated_at)
-				VALUES (?, ?, 1, CURRENT_TIMESTAMP)
-				ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
-		}
-
-		_, err := s.db.ExecContext(ctx, query, nameExtractor.Name, string(data))
-		if err != nil {
-			return fmt.Errorf("failed to upsert raw data for %s: %w", table, err)
-		}
-	}
-
-	return nil
-}
-
-// clearTable is a helper function to delete all rows from a table.
+// clearTable deletes all rows from the given table (safe against SQLi).
 func (s *Storage) clearTable(ctx context.Context, table string) error {
-	// Simple validation to prevent SQL injection on table names
-	if table != "clusters" && table != "listeners" && table != "secrets" && table != "certificates" {
-		return fmt.Errorf("invalid table name for clearing: %s", table)
+	valid := map[string]bool{
+		"clusters":     true,
+		"listeners":    true,
+		"secrets":      true,
+		"certificates": true,
+	}
+	if !valid[table] {
+		return fmt.Errorf("invalid table name: %s", table)
 	}
 
-	var query string
+	query := fmt.Sprintf("DELETE FROM %s", table)
 	if s.driver == "postgres" {
 		query = fmt.Sprintf("TRUNCATE TABLE %s RESTART IDENTITY", table)
-	} else {
-		query = fmt.Sprintf("DELETE FROM %s", table)
 	}
 
-	// Assuming s.db is your database connection pool
 	_, err := s.db.ExecContext(ctx, query)
 	if err != nil {
-		return fmt.Errorf("error clearing table %s: %w", table, err)
+		return fmt.Errorf("failed to clear table %s: %w", table, err)
 	}
 	return nil
 }
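Usage sketch (reviewer note, not part of the patch): one way a caller could exercise the new Dump/Restore round trip. The storage.New constructor, the module import path, and the backup file name are illustrative assumptions; only Dump, Restore, and RestoreOverride come from the code above.

package main

import (
	"context"
	"log"
	"os"

	"example.com/project/internal/pkg/storage" // hypothetical module path
)

func main() {
	ctx := context.Background()

	// Assumed constructor; substitute the project's real way of opening Storage.
	s, err := storage.New("sqlite", "data/config.db")
	if err != nil {
		log.Fatal(err)
	}

	// Export the current state to a JSON backup file.
	data, err := s.Dump(ctx)
	if err != nil {
		log.Fatal(err)
	}
	if err := os.WriteFile("backup.json", data, 0o600); err != nil {
		log.Fatal(err)
	}

	// Later: wipe the xDS and certificate tables and restore from the backup.
	backup, err := os.ReadFile("backup.json")
	if err != nil {
		log.Fatal(err)
	}
	if err := s.Restore(ctx, backup, storage.RestoreOverride); err != nil {
		log.Fatal(err)
	}
}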