package storage
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
)
// NOTE: This file assumes the SQLStrategy interface and its concrete implementations
// (PostgresStrategy, SQLiteStrategy) are defined elsewhere in this package.
// DBDump holds the complete state of the database for dumping/restoring.
// It is the value that Dump marshals to JSON and that Restore unmarshals.
type DBDump struct {
// Clusters holds the rows Dump reads from the "clusters" table.
Clusters []*RawRow `json:"clusters,omitempty"`
// Listeners holds the rows Dump reads from the "listeners" table.
Listeners []*RawRow `json:"listeners,omitempty"`
// Secrets holds the rows Dump reads from the "secrets" table.
Secrets []*RawRow `json:"secrets,omitempty"`
// Certificates holds the result of LoadAllCertificates at dump time.
Certificates []*CertStorage `json:"certificates,omitempty"`
// DBDriver is the strategy's driver name at dump time; Restore rejects a
// dump whose value differs from the current strategy's driver name.
DBDriver string `json:"db_driver"`
}
// DBDumpRestoreMode selects how Restore treats data already in the database.
type DBDumpRestoreMode int

// Restore modes. The numeric values are spelled out so they stay stable if
// further modes are added later.
const (
	// RestoreMerge keeps existing data and merges the incoming rows into it
	// (UPSERT).
	RestoreMerge DBDumpRestoreMode = 0

	// RestoreOverride deletes all existing data before inserting the
	// incoming rows.
	RestoreOverride DBDumpRestoreMode = 1
)
// -----------------------------------------------------------------------------
// DUMP METHODS (REFACTORED)
// -----------------------------------------------------------------------------
// Dump exports the clusters, listeners and secrets tables, together with all
// stored certificates, as a single indented JSON document.
//
// The document also records the driver name reported by the storage strategy,
// which Restore compares against the current driver before importing.
//
// It returns the encoded dump, or the first error encountered while querying,
// scanning, loading certificates or marshalling.
func (s *Storage) Dump(ctx context.Context) ([]byte, error) {
	// load reads every row of one raw table. Column selection and row
	// scanning are delegated to the strategy.
	load := func(ctx context.Context, table string) ([]*RawRow, error) {
		// The table name is interpolated into the query text, so only the
		// known raw tables are accepted.
		if table != "clusters" && table != "listeners" && table != "secrets" {
			return nil, fmt.Errorf("invalid table for dump: %s", table)
		}

		fields := s.strategy.DumpSelectFields(table)
		query := fmt.Sprintf(`SELECT %s FROM %s`, fields, table)

		rows, err := s.db.QueryContext(ctx, query)
		if err != nil {
			return nil, fmt.Errorf("failed to query table %s: %w", table, err)
		}
		defer rows.Close()

		var out []*RawRow
		for rows.Next() {
			row := &RawRow{}
			if err := s.strategy.ScanRawRow(rows, row, table); err != nil {
				return nil, fmt.Errorf("scan error on %s: %w", table, err)
			}
			out = append(out, row)
		}
		// An iteration error means the result set is incomplete: report it
		// with the table name and do not hand back a partial slice.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to iterate rows of %s: %w", table, err)
		}
		return out, nil
	}

	clusters, err := load(ctx, "clusters")
	if err != nil {
		return nil, err
	}
	listeners, err := load(ctx, "listeners")
	if err != nil {
		return nil, err
	}
	secrets, err := load(ctx, "secrets")
	if err != nil {
		return nil, err
	}

	// sql.ErrNoRows is tolerated so that a database without certificates can
	// still be dumped. errors.Is also matches the sentinel when it has been
	// wrapped with additional context.
	certs, err := s.LoadAllCertificates(ctx)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("failed to load certificates: %w", err)
	}

	dump := &DBDump{
		Clusters:     clusters,
		Listeners:    listeners,
		Secrets:      secrets,
		Certificates: certs,
		DBDriver:     s.strategy.DriverName(),
	}

	data, err := json.MarshalIndent(dump, "", " ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal DB dump: %w", err)
	}
	return data, nil
}
// -----------------------------------------------------------------------------
// RESTORE METHODS (REFACTORED)
// -----------------------------------------------------------------------------
// Restore imports database content from a JSON dump (merge or override).
//
// data must decode into a DBDump whose DBDriver equals the current strategy's
// driver name; otherwise an error is returned before anything is modified.
//
// When mode is RestoreOverride, the clusters, listeners, secrets and
// certificates tables are cleared first. Any other mode value skips the clear
// step. In both cases every cluster, listener and secret row is then written
// with the statement returned by the strategy's RestoreRawRowSQL (expected to
// be an UPSERT), and every certificate is written through SaveCertificate.
// Rows that are nil or have an empty Name (Domain for certificates) are
// skipped.
//
// NOTE(review): no transaction is opened in this method. Unless s.db provides
// atomicity by other means, a failure part-way through leaves earlier changes
// in place, including tables already cleared in override mode.
func (s *Storage) Restore(ctx context.Context, data []byte, mode DBDumpRestoreMode) error {
	var dump DBDump
	if err := json.Unmarshal(data, &dump); err != nil {
		return fmt.Errorf("failed to parse dump JSON: %w", err)
	}

	if dump.DBDriver != s.strategy.DriverName() {
		return fmt.Errorf("database driver mismatch: dump='%s', current='%s'", dump.DBDriver, s.strategy.DriverName())
	}

	if mode == RestoreOverride {
		for _, tbl := range []string{"clusters", "listeners", "secrets", "certificates"} {
			if err := s.clearTable(ctx, tbl); err != nil {
				return fmt.Errorf("failed to clear %s: %w", tbl, err)
			}
		}
	}

	// saveRaw writes rows into table using the strategy-provided statement.
	saveRaw := func(ctx context.Context, table string, rows []*RawRow) error {
		if len(rows) == 0 {
			return nil
		}

		// The statement depends only on the table, so it is fetched once
		// rather than once per row.
		query := s.strategy.RestoreRawRowSQL(table)

		for _, r := range rows {
			if r == nil || r.Name == "" {
				continue
			}

			// All tables bind (name, data); secrets binds domain as a third
			// positional argument.
			args := []interface{}{r.Name, string(r.Data)}
			if table == "secrets" {
				args = append(args, r.Domain)
			}

			if _, err := s.db.ExecContext(ctx, query, args...); err != nil {
				return fmt.Errorf("failed to upsert %s '%s': %w", table, r.Name, err)
			}
		}
		return nil
	}

	if err := saveRaw(ctx, "clusters", dump.Clusters); err != nil {
		return err
	}
	if err := saveRaw(ctx, "listeners", dump.Listeners); err != nil {
		return err
	}
	if err := saveRaw(ctx, "secrets", dump.Secrets); err != nil {
		return err
	}

	// Certificates go through SaveCertificate rather than a raw statement.
	for _, cert := range dump.Certificates {
		if cert == nil || cert.Domain == "" {
			continue
		}
		if err := s.SaveCertificate(ctx, cert); err != nil {
			return fmt.Errorf("failed to upsert certificate '%s': %w", cert.Domain, err)
		}
	}
	return nil
}
// clearTable runs the strategy's clear statement for table.
//
// Only "clusters", "listeners", "secrets" and "certificates" are accepted;
// any other name is rejected with an error before the database is touched.
// An execution failure is returned wrapped with the table name.
func (s *Storage) clearTable(ctx context.Context, table string) error {
	switch table {
	case "clusters", "listeners", "secrets", "certificates":
		// Known table: fall through to the delete below.
	default:
		return fmt.Errorf("invalid table name: %s", table)
	}

	// The exact statement is chosen by the strategy.
	if _, execErr := s.db.ExecContext(ctx, s.strategy.ClearTableSQL(table)); execErr != nil {
		return fmt.Errorf("failed to clear table %s: %w", table, execErr)
	}
	return nil
}