package storage
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
// routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" // REMOVED
"google.golang.org/protobuf/encoding/protojson"
)
// Storage abstracts database persistence for Envoy xDS resources (clusters
// and listeners), which are stored as protojson documents in SQL tables.
// The driver field selects the SQL dialect: "postgres" gets Postgres syntax
// and $n placeholders; every other value is treated as SQLite.
type Storage struct {
	db     *sql.DB // connection pool that every query runs against
	driver string  // SQL dialect selector, compared against "postgres"
}
// DeleteStrategy tells SaveSnapshot what to do with stored resources whose
// names do not appear among the snapshot's enabled resources.
// The zero value is DeleteNone.
type DeleteStrategy int

// Deletion strategies understood by SaveSnapshot.
const (
	DeleteNone    DeleteStrategy = 0 // upsert only; missing rows are left alone (default)
	DeleteLogical DeleteStrategy = 1 // mark missing clusters and listeners as disabled
	DeleteActual  DeleteStrategy = 2 // physically remove missing rows from the database
)
// NewStorage returns a Storage that issues every query against db, using the
// SQL dialect named by driver ("postgres", or anything else for SQLite).
func NewStorage(db *sql.DB, driver string) *Storage {
	s := new(Storage)
	s.db = db
	s.driver = driver
	return s
}
// placeholder returns the bind-parameter marker for the n-th (1-based)
// argument: "$n" for Postgres, and a positional "?" (n is ignored) for any
// other driver.
func (s *Storage) placeholder(n int) string {
	switch s.driver {
	case "postgres":
		return fmt.Sprintf("$%d", n)
	default:
		return "?"
	}
}
// InitSchema creates the clusters and listeners tables if they do not exist.
// Postgres gets SERIAL ids and JSONB payloads; any other driver is treated as
// SQLite (AUTOINCREMENT ids, TEXT payloads). The whole DDL script is sent in
// a single ExecContext call.
func (s *Storage) InitSchema(ctx context.Context) error {
	ddl := `
CREATE TABLE IF NOT EXISTS clusters (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
data TEXT NOT NULL,
enabled BOOLEAN DEFAULT 1,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- REMOVED routes table
CREATE TABLE IF NOT EXISTS listeners (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
data TEXT NOT NULL,
enabled BOOLEAN DEFAULT 1,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);`
	if s.driver == "postgres" {
		ddl = `
CREATE TABLE IF NOT EXISTS clusters (
id SERIAL PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
data JSONB NOT NULL,
enabled BOOLEAN DEFAULT true,
updated_at TIMESTAMP DEFAULT now()
);
-- REMOVED routes table
CREATE TABLE IF NOT EXISTS listeners (
id SERIAL PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
data JSONB NOT NULL,
enabled BOOLEAN DEFAULT true,
updated_at TIMESTAMP DEFAULT now()
);`
	}
	_, err := s.db.ExecContext(ctx, ddl)
	return err
}
// SaveCluster upserts cluster into the clusters table, keyed by its name, and
// stores it as protojson. The row is always (re-)enabled, so saving a cluster
// that was previously disabled (logically deleted) makes it active again.
func (s *Storage) SaveCluster(ctx context.Context, cluster *clusterv3.Cluster) error {
	encoded, err := protojson.Marshal(cluster)
	if err != nil {
		return err
	}
	stmt := `
INSERT INTO clusters (name, data, enabled, updated_at)
VALUES (?, ?, 1, CURRENT_TIMESTAMP)
ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`
	if s.driver == "postgres" {
		// $2 is referenced twice: once for the insert, once for the update.
		stmt = `
INSERT INTO clusters (name, data, enabled, updated_at)
VALUES ($1, $2, true, now())
ON CONFLICT (name) DO UPDATE SET data = $2, enabled = true, updated_at = now()`
	}
	_, err = s.db.ExecContext(ctx, stmt, cluster.GetName(), string(encoded))
	return err
}
// SaveRoute inserts or updates a route // REMOVED
// func (s *Storage) SaveRoute(ctx context.Context, route *routev3.RouteConfiguration) error {
// // ... (route logic removed)
// }
// SaveListener upserts listener into the listeners table, keyed by its name,
// and stores it as protojson. The row is always (re-)enabled, so saving a
// listener that was previously disabled (logically deleted) reactivates it.
func (s *Storage) SaveListener(ctx context.Context, listener *listenerv3.Listener) error {
	payload, err := protojson.Marshal(listener)
	if err != nil {
		return err
	}
	upsert := `
INSERT INTO listeners (name, data, enabled, updated_at)
VALUES (?, ?, 1, CURRENT_TIMESTAMP)
ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`
	if s.driver == "postgres" {
		upsert = `
INSERT INTO listeners (name, data, enabled, updated_at)
VALUES ($1, $2, true, now())
ON CONFLICT (name) DO UPDATE SET data = $2, enabled = true, updated_at = now()`
	}
	if _, err := s.db.ExecContext(ctx, upsert, listener.GetName(), string(payload)); err != nil {
		return err
	}
	return nil
}
// LoadEnabledClusters retrieves all clusters whose enabled flag is set.
func (s *Storage) LoadEnabledClusters(ctx context.Context) ([]*clusterv3.Cluster, error) {
	return s.loadClusters(ctx, `SELECT data FROM clusters`+s.enabledFilter())
}

// LoadAllClusters retrieves all clusters, regardless of their enabled status.
func (s *Storage) LoadAllClusters(ctx context.Context) ([]*clusterv3.Cluster, error) {
	return s.loadClusters(ctx, `SELECT data FROM clusters`)
}

// LoadEnabledRoutes retrieves all enabled routes // REMOVED
// func (s *Storage) LoadEnabledRoutes(ctx context.Context) ([]*routev3.RouteConfiguration, error) {
// // ... (route logic removed)
// }
// LoadAllRoutes retrieves all routes, regardless of their enabled status // REMOVED
// func (s *Storage) LoadAllRoutes(ctx context.Context) ([]*routev3.RouteConfiguration, error) {
// // ... (route logic removed)
// }

// LoadEnabledListeners retrieves all listeners whose enabled flag is set.
func (s *Storage) LoadEnabledListeners(ctx context.Context) ([]*listenerv3.Listener, error) {
	return s.loadListeners(ctx, `SELECT data FROM listeners`+s.enabledFilter())
}

// LoadAllListeners retrieves all listeners, regardless of their enabled status.
func (s *Storage) LoadAllListeners(ctx context.Context) ([]*listenerv3.Listener, error) {
	return s.loadListeners(ctx, `SELECT data FROM listeners`)
}

// enabledFilter returns the WHERE clause selecting enabled rows, spelled with
// the boolean literal of the driver's dialect (true for Postgres, 1 for SQLite).
func (s *Storage) enabledFilter() string {
	if s.driver == "postgres" {
		return ` WHERE enabled = true`
	}
	return ` WHERE enabled = 1`
}

// loadClusters runs query, which must select exactly one column holding a
// protojson-encoded Cluster, and decodes every returned row. It returns nil
// and the first error encountered (query, scan, decode, or iteration).
func (s *Storage) loadClusters(ctx context.Context, query string) ([]*clusterv3.Cluster, error) {
	var clusters []*clusterv3.Cluster
	err := s.forEachDataRow(ctx, query, func(raw json.RawMessage) error {
		c := new(clusterv3.Cluster)
		if err := protojson.Unmarshal(raw, c); err != nil {
			return err
		}
		clusters = append(clusters, c)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return clusters, nil
}

// loadListeners is loadClusters for protojson-encoded Listener rows.
func (s *Storage) loadListeners(ctx context.Context, query string) ([]*listenerv3.Listener, error) {
	var listeners []*listenerv3.Listener
	err := s.forEachDataRow(ctx, query, func(raw json.RawMessage) error {
		l := new(listenerv3.Listener)
		if err := protojson.Unmarshal(raw, l); err != nil {
			return err
		}
		listeners = append(listeners, l)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return listeners, nil
}

// forEachDataRow executes query and calls fn with the single column of every
// row, stopping at the first error from the query, Scan, fn, or row iteration.
func (s *Storage) forEachDataRow(ctx context.Context, query string, fn func(raw json.RawMessage) error) error {
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		// Scan through *[]byte: database/sql stores both []byte and string
		// driver values into it (Postgres JSONB vs. SQLite TEXT), whereas it
		// refuses to store a string into *json.RawMessage. Scan copies the
		// bytes, so raw stays valid after the next call to rows.Next.
		var raw []byte
		if err := rows.Scan(&raw); err != nil {
			return err
		}
		if err := fn(json.RawMessage(raw)); err != nil {
			return err
		}
	}
	// rows.Next returns false both at the end of the result set and on
	// failure (e.g. context cancellation); only Err tells them apart.
	return rows.Err()
}
// RebuildSnapshot reads every stored cluster and listener and splits them into
// enabled resources (to be served over xDS) and disabled ones (everything
// stored whose name is not in the enabled set, e.g. for UI display).
//
// The four loads run in this order: enabled clusters, enabled listeners, all
// clusters, all listeners. They are separate calls and are not wrapped in a
// transaction here, so concurrent writes between them can skew the split.
// On the first failing load it returns nil and that error.
func (s *Storage) RebuildSnapshot(ctx context.Context) (*SnapshotConfig, error) {
	cfg := &SnapshotConfig{}
	var err error

	if cfg.EnabledClusters, err = s.LoadEnabledClusters(ctx); err != nil {
		return nil, err
	}
	if cfg.EnabledListeners, err = s.LoadEnabledListeners(ctx); err != nil {
		return nil, err
	}
	allClusters, err := s.LoadAllClusters(ctx)
	if err != nil {
		return nil, err
	}
	allListeners, err := s.LoadAllListeners(ctx)
	if err != nil {
		return nil, err
	}

	// Anything stored but absent from the enabled set (by name) is disabled.
	liveClusters := make(map[string]bool, len(cfg.EnabledClusters))
	for _, c := range cfg.EnabledClusters {
		liveClusters[c.GetName()] = true
	}
	for _, c := range allClusters {
		if !liveClusters[c.GetName()] {
			cfg.DisabledClusters = append(cfg.DisabledClusters, c)
		}
	}

	liveListeners := make(map[string]bool, len(cfg.EnabledListeners))
	for _, l := range cfg.EnabledListeners {
		liveListeners[l.GetName()] = true
	}
	for _, l := range allListeners {
		if !liveListeners[l.GetName()] {
			cfg.DisabledListeners = append(cfg.DisabledListeners, l)
		}
	}

	return cfg, nil
}
// SnapshotConfig aggregates the xDS resources persisted by Storage.
// Enabled* hold resources that should be served over xDS; Disabled* hold
// stored resources that are not enabled. SaveSnapshot only upserts the
// Enabled* slices. Route configurations are no longer part of the snapshot.
type SnapshotConfig struct {
	// Enabled resources (for xDS serving).
	EnabledClusters  []*clusterv3.Cluster
	EnabledListeners []*listenerv3.Listener
	// Disabled resources (for UI display).
	DisabledClusters  []*clusterv3.Cluster
	DisabledListeners []*listenerv3.Listener
}
// EnableCluster sets the enabled flag of the cluster with the given name and
// refreshes its updated_at timestamp. Naming a cluster that is not stored is
// not an error; the UPDATE simply affects no rows.
func (s *Storage) EnableCluster(ctx context.Context, name string, enabled bool) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `UPDATE clusters SET enabled = $1, updated_at = now() WHERE name = $2`
	default:
		stmt = `UPDATE clusters SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}
// EnableRoute toggles a route // REMOVED
// func (s *Storage) EnableRoute(ctx context.Context, name string, enabled bool) error {
// // ... (route logic removed)
// }
// EnableListener sets the enabled flag of the listener with the given name and
// refreshes its updated_at timestamp. Naming a listener that is not stored is
// not an error; the UPDATE simply affects no rows.
func (s *Storage) EnableListener(ctx context.Context, name string, enabled bool) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `UPDATE listeners SET enabled = $1, updated_at = now() WHERE name = $2`
	default:
		stmt = `UPDATE listeners SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}
// sqlExecer is the part of the *sql.DB / *sql.Tx API needed to run write
// statements, so the helpers below can execute either directly on the
// connection pool or inside SaveSnapshot's transaction.
type sqlExecer interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// isResourceTable reports whether table is one of the two resource tables
// created by InitSchema. Table names are interpolated into SQL text below, so
// this whitelist keeps them from being caller-controlled.
func isResourceTable(table string) bool {
	return table == "clusters" || table == "listeners"
}

// notInClause builds "WHERE name NOT IN (...)" with one driver-specific
// placeholder per name, plus the matching bind arguments. For an empty names
// slice it returns an empty clause and no arguments, so the statement it is
// appended to applies to every row of the table.
//
// NOTE(review): one bind parameter is used per name; very large snapshots
// could exceed the database's bind-parameter limit.
func (s *Storage) notInClause(names []string) (string, []interface{}) {
	if len(names) == 0 {
		return "", nil
	}
	placeholders := make([]string, len(names))
	args := make([]interface{}, len(names))
	for i, name := range names {
		placeholders[i] = s.placeholder(i + 1)
		args[i] = name
	}
	return fmt.Sprintf("WHERE name NOT IN (%s)", strings.Join(placeholders, ", ")), args
}

// disableMissingResources sets enabled to false on every row of table whose
// name is not in inputNames; with an empty inputNames it disables every row.
// Only "clusters" and "listeners" are accepted. It runs directly on the
// connection pool, outside any transaction.
func (s *Storage) disableMissingResources(ctx context.Context, table string, inputNames []string) error {
	return s.disableMissingIn(ctx, s.db, table, inputNames)
}

// disableMissingIn is disableMissingResources executing its UPDATE on ex.
// The UPDATE has no enabled filter, so rows that are already disabled also get
// their updated_at refreshed.
func (s *Storage) disableMissingIn(ctx context.Context, ex sqlExecer, table string, inputNames []string) error {
	if !isResourceTable(table) {
		return fmt.Errorf("logical delete (disable) is only supported for tables with an 'enabled' column (clusters, listeners)")
	}
	disabledValue, updateTime := "0", "CURRENT_TIMESTAMP"
	if s.driver == "postgres" {
		disabledValue, updateTime = "false", "now()"
	}
	whereClause, args := s.notInClause(inputNames)
	query := fmt.Sprintf(`
UPDATE %s
SET enabled = %s, updated_at = %s
%s`,
		table, disabledValue, updateTime, whereClause)
	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// deleteMissingResources physically deletes every row of table whose name is
// not in inputNames; with an empty inputNames it deletes every row.
// Only "clusters" and "listeners" are accepted. It runs directly on the
// connection pool, outside any transaction.
func (s *Storage) deleteMissingResources(ctx context.Context, table string, inputNames []string) error {
	return s.deleteMissingIn(ctx, s.db, table, inputNames)
}

// deleteMissingIn is deleteMissingResources executing its DELETE on ex.
func (s *Storage) deleteMissingIn(ctx context.Context, ex sqlExecer, table string, inputNames []string) error {
	if !isResourceTable(table) {
		return fmt.Errorf("physical delete is only supported for tables: clusters, listeners")
	}
	whereClause, args := s.notInClause(inputNames)
	query := fmt.Sprintf(`
DELETE FROM %s
%s`,
		table, whereClause)
	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// upsertResourceIn inserts (name, data) into table via ex, or overwrites the
// existing row with that name. Either way the row ends up enabled with a fresh
// updated_at. The SQL matches the statements used by SaveCluster and
// SaveListener, parameterized by table.
func (s *Storage) upsertResourceIn(ctx context.Context, ex sqlExecer, table, name string, data []byte) error {
	if !isResourceTable(table) {
		return fmt.Errorf("upsert is only supported for tables: clusters, listeners")
	}
	var query string
	if s.driver == "postgres" {
		query = fmt.Sprintf(`
INSERT INTO %s (name, data, enabled, updated_at)
VALUES ($1, $2, true, now())
ON CONFLICT (name) DO UPDATE SET data = $2, enabled = true, updated_at = now()`, table)
	} else {
		query = fmt.Sprintf(`
INSERT INTO %s (name, data, enabled, updated_at)
VALUES (?, ?, 1, CURRENT_TIMESTAMP)
ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`, table)
	}
	_, err := ex.ExecContext(ctx, query, name, string(data))
	return err
}

// SaveSnapshot writes cfg to the database in a single transaction. Either the
// whole snapshot is applied, or (on any error) nothing is.
//
// Every resource in cfg.EnabledClusters and cfg.EnabledListeners is upserted
// and (re-)enabled; cfg.DisabledClusters and cfg.DisabledListeners are not
// written. strategy then decides what happens to stored rows whose names are
// not among the enabled ones:
//   - DeleteNone leaves them untouched;
//   - DeleteLogical disables them;
//   - DeleteActual deletes them.
//
// With DeleteLogical or DeleteActual, an empty Enabled* slice therefore
// disables or deletes every row of that type. A nil cfg or an unknown strategy
// is rejected before anything is written.
func (s *Storage) SaveSnapshot(ctx context.Context, cfg *SnapshotConfig, strategy DeleteStrategy) error {
	if cfg == nil {
		return fmt.Errorf("SnapshotConfig is nil")
	}
	switch strategy {
	case DeleteNone, DeleteLogical, DeleteActual:
	default:
		return fmt.Errorf("unknown delete strategy: %d", strategy)
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Undo all writes on any early return. After a successful Commit this
	// Rollback is a no-op returning sql.ErrTxDone, which is safe to ignore.
	defer func() { _ = tx.Rollback() }()

	// Upsert clusters inside the transaction and remember their names.
	clusterNames := make([]string, 0, len(cfg.EnabledClusters))
	for _, c := range cfg.EnabledClusters {
		data, err := protojson.Marshal(c)
		if err == nil {
			err = s.upsertResourceIn(ctx, tx, "clusters", c.GetName(), data)
		}
		if err != nil {
			return fmt.Errorf("failed to save cluster %s: %w", c.GetName(), err)
		}
		clusterNames = append(clusterNames, c.GetName())
	}

	// Upsert listeners inside the transaction and remember their names.
	listenerNames := make([]string, 0, len(cfg.EnabledListeners))
	for _, l := range cfg.EnabledListeners {
		data, err := protojson.Marshal(l)
		if err == nil {
			err = s.upsertResourceIn(ctx, tx, "listeners", l.GetName(), data)
		}
		if err != nil {
			return fmt.Errorf("failed to save listener %s: %w", l.GetName(), err)
		}
		listenerNames = append(listenerNames, l.GetName())
	}

	// Apply the deletion strategy to rows not present in this snapshot.
	switch strategy {
	case DeleteLogical:
		if err := s.disableMissingIn(ctx, tx, "clusters", clusterNames); err != nil {
			return fmt.Errorf("failed to logically delete missing clusters: %w", err)
		}
		if err := s.disableMissingIn(ctx, tx, "listeners", listenerNames); err != nil {
			return fmt.Errorf("failed to logically delete missing listeners: %w", err)
		}
	case DeleteActual:
		if err := s.deleteMissingIn(ctx, tx, "clusters", clusterNames); err != nil {
			return fmt.Errorf("failed to physically delete missing clusters: %w", err)
		}
		if err := s.deleteMissingIn(ctx, tx, "listeners", listenerNames); err != nil {
			return fmt.Errorf("failed to physically delete missing listeners: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// RemoveListener physically deletes the listener row with the given name.
// Removing a name that is not stored is not an error.
func (s *Storage) RemoveListener(ctx context.Context, name string) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `DELETE FROM listeners WHERE name = $1`
	default:
		stmt = `DELETE FROM listeners WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, name); err != nil {
		return err
	}
	return nil
}
// RemoveCluster physically deletes the cluster row with the given name.
// Removing a name that is not stored is not an error.
func (s *Storage) RemoveCluster(ctx context.Context, name string) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `DELETE FROM clusters WHERE name = $1`
	default:
		stmt = `DELETE FROM clusters WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, name); err != nil {
		return err
	}
	return nil
}