// EnvoyControlPlane: internal/storage/storage.go

package storage

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"

	clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
	listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"

	// routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" // REMOVED

	"google.golang.org/protobuf/encoding/protojson"
)

// Storage persists clusters and listeners in a SQL database through
// database/sql. The SQL dialect is selected by driver: "postgres" uses
// Postgres syntax ($n placeholders, now(), JSONB columns); any other value is
// treated as SQLite (? placeholders, CURRENT_TIMESTAMP, TEXT columns).
type Storage struct {
	// db is the handle all statements are executed on.
	db *sql.DB
	// driver picks the SQL dialect; only the value "postgres" is special-cased.
	driver string
}

// DeleteStrategy selects what SaveSnapshot does with stored resources whose
// names are absent from the snapshot being saved.
type DeleteStrategy int

// Supported delete strategies.
const (
	// DeleteNone only upserts the resources present in the snapshot and leaves
	// every other stored row untouched. It is the zero value.
	DeleteNone DeleteStrategy = 0
	// DeleteLogical marks stored clusters and listeners that are missing from
	// the snapshot as disabled.
	DeleteLogical DeleteStrategy = 1
	// DeleteActual physically removes stored resources that are missing from
	// the snapshot.
	DeleteActual DeleteStrategy = 2
)

// NewStorage wraps an already-opened database handle. driver selects the SQL
// dialect used by every method: "postgres" for Postgres, anything else for
// SQLite.
func NewStorage(db *sql.DB, driver string) *Storage {
	s := new(Storage)
	s.db = db
	s.driver = driver
	return s
}

// placeholder returns the bind-parameter marker for the n-th (1-based)
// argument: "$n" when the driver is "postgres", and "?" for any other driver.
func (s *Storage) placeholder(n int) string {
	if s.driver != "postgres" {
		return "?"
	}
	return fmt.Sprintf("$%d", n)
}

// InitSchema creates the clusters and listeners tables if they do not exist
// yet. Postgres gets SERIAL ids and JSONB data columns; every other driver
// gets the SQLite variant with AUTOINCREMENT ids and TEXT data columns. Each
// variant is sent as a single multi-statement Exec.
func (s *Storage) InitSchema(ctx context.Context) error {
	if s.driver == "postgres" {
		_, err := s.db.ExecContext(ctx, `
        CREATE TABLE IF NOT EXISTS clusters (
            id SERIAL PRIMARY KEY,
            name TEXT UNIQUE NOT NULL,
            data JSONB NOT NULL,
            enabled BOOLEAN DEFAULT true,
            updated_at TIMESTAMP DEFAULT now()
        );
        -- REMOVED routes table
        CREATE TABLE IF NOT EXISTS listeners (
            id SERIAL PRIMARY KEY,
            name TEXT UNIQUE NOT NULL,
            data JSONB NOT NULL,
            enabled BOOLEAN DEFAULT true,
            updated_at TIMESTAMP DEFAULT now()
        );`)
		return err
	}

	// Fallback: SQLite dialect.
	_, err := s.db.ExecContext(ctx, `
        CREATE TABLE IF NOT EXISTS clusters (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL,
            data TEXT NOT NULL,
            enabled BOOLEAN DEFAULT 1,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        -- REMOVED routes table
        CREATE TABLE IF NOT EXISTS listeners (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL,
            data TEXT NOT NULL,
            enabled BOOLEAN DEFAULT 1,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );`)
	return err
}

// SaveCluster upserts cluster into the clusters table, keyed by its name. The
// cluster is stored as protojson. The row's enabled flag is always set, so
// saving a logically deleted cluster re-enables it.
func (s *Storage) SaveCluster(ctx context.Context, cluster *clusterv3.Cluster) error {
	payload, err := protojson.Marshal(cluster)
	if err != nil {
		return err
	}

	// SQLite is the fallback dialect for any non-postgres driver.
	stmt := `
            INSERT INTO clusters (name, data, enabled, updated_at)
            VALUES (?, ?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`
	if s.driver == "postgres" {
		stmt = fmt.Sprintf(`
            INSERT INTO clusters (name, data, enabled, updated_at)
            VALUES (%s, %s, true, now())
            ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`,
			s.placeholder(1), s.placeholder(2), s.placeholder(2))
	}

	_, err = s.db.ExecContext(ctx, stmt, cluster.GetName(), string(payload))
	return err
}

// SaveRoute inserts or updates a route // REMOVED
// func (s *Storage) SaveRoute(ctx context.Context, route *routev3.RouteConfiguration) error {
// 	// ... (route logic removed)
// }

// SaveListener upserts listener into the listeners table, keyed by its name.
// The listener is stored as protojson. The row's enabled flag is always set,
// so saving a logically deleted listener re-enables it.
func (s *Storage) SaveListener(ctx context.Context, listener *listenerv3.Listener) error {
	payload, err := protojson.Marshal(listener)
	if err != nil {
		return err
	}

	// SQLite is the fallback dialect for any non-postgres driver.
	stmt := `
            INSERT INTO listeners (name, data, enabled, updated_at)
            VALUES (?, ?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`
	if s.driver == "postgres" {
		stmt = fmt.Sprintf(`
            INSERT INTO listeners (name, data, enabled, updated_at)
            VALUES (%s, %s, true, now())
            ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`,
			s.placeholder(1), s.placeholder(2), s.placeholder(2))
	}

	_, err = s.db.ExecContext(ctx, stmt, listener.GetName(), string(payload))
	return err
}

// LoadEnabledClusters returns every stored cluster whose enabled flag is set.
func (s *Storage) LoadEnabledClusters(ctx context.Context) ([]*clusterv3.Cluster, error) {
	query := `SELECT data FROM clusters WHERE enabled = 1`
	if s.driver == "postgres" {
		query = `SELECT data FROM clusters WHERE enabled = true`
	}
	return s.queryClusters(ctx, query)
}

// LoadAllClusters returns every stored cluster, regardless of its enabled flag.
func (s *Storage) LoadAllClusters(ctx context.Context) ([]*clusterv3.Cluster, error) {
	return s.queryClusters(ctx, `SELECT data FROM clusters`)
}

// queryClusters runs query, which must select a single column holding a
// protojson-encoded Cluster, and decodes every returned row. It returns a nil
// slice when no rows match and stops at the first scan, decode or iteration
// error.
func (s *Storage) queryClusters(ctx context.Context, query string) ([]*clusterv3.Cluster, error) {
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var clusters []*clusterv3.Cluster
	for rows.Next() {
		// Scan into a string: database/sql converts both []byte and string
		// driver values into it. One code path therefore serves the Postgres
		// JSONB column and the SQLite TEXT column, whichever representation
		// the driver reports.
		var data string
		if err := rows.Scan(&data); err != nil {
			return nil, err
		}
		cluster := new(clusterv3.Cluster)
		if err := protojson.Unmarshal(json.RawMessage(data), cluster); err != nil {
			return nil, err
		}
		clusters = append(clusters, cluster)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails. Without this check, a mid-stream error would yield a
	// silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return clusters, nil
}

// LoadEnabledRoutes retrieves all enabled routes // REMOVED
// func (s *Storage) LoadEnabledRoutes(ctx context.Context) ([]*routev3.RouteConfiguration, error) {
// 	// ... (route logic removed)
// }

// LoadAllRoutes retrieves all routes, regardless of their enabled status // REMOVED
// func (s *Storage) LoadAllRoutes(ctx context.Context) ([]*routev3.RouteConfiguration, error) {
// 	// ... (route logic removed)
// }

// LoadEnabledListeners returns every stored listener whose enabled flag is set.
func (s *Storage) LoadEnabledListeners(ctx context.Context) ([]*listenerv3.Listener, error) {
	query := `SELECT data FROM listeners WHERE enabled = 1`
	if s.driver == "postgres" {
		query = `SELECT data FROM listeners WHERE enabled = true`
	}
	return s.queryListeners(ctx, query)
}

// LoadAllListeners returns every stored listener, regardless of its enabled
// flag.
func (s *Storage) LoadAllListeners(ctx context.Context) ([]*listenerv3.Listener, error) {
	return s.queryListeners(ctx, `SELECT data FROM listeners`)
}

// queryListeners runs query, which must select a single column holding a
// protojson-encoded Listener, and decodes every returned row. It returns a nil
// slice when no rows match and stops at the first scan, decode or iteration
// error.
func (s *Storage) queryListeners(ctx context.Context, query string) ([]*listenerv3.Listener, error) {
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var listeners []*listenerv3.Listener
	for rows.Next() {
		// Scan into a string: database/sql converts both []byte and string
		// driver values into it. One code path therefore serves the Postgres
		// JSONB column and the SQLite TEXT column, whichever representation
		// the driver reports.
		var data string
		if err := rows.Scan(&data); err != nil {
			return nil, err
		}
		listener := new(listenerv3.Listener)
		if err := protojson.Unmarshal(json.RawMessage(data), listener); err != nil {
			return nil, err
		}
		listeners = append(listeners, listener)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails. Without this check, a mid-stream error would yield a
	// silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return listeners, nil
}

// RebuildSnapshot reads clusters and listeners back from the database. It
// returns each type split into enabled resources (the ones to serve over xDS)
// and disabled ones.
//
// It issues four separate queries in this order: enabled clusters, enabled
// listeners, all clusters, all listeners. It returns the first error
// encountered. A stored resource is classified as disabled when its name does
// not appear in the corresponding enabled result.
//
// NOTE(review): the queries do not share a transaction, so a concurrent writer
// could make the enabled and "all" views disagree slightly.
func (s *Storage) RebuildSnapshot(ctx context.Context) (*SnapshotConfig, error) {
	liveClusters, err := s.LoadEnabledClusters(ctx)
	if err != nil {
		return nil, err
	}
	liveListeners, err := s.LoadEnabledListeners(ctx)
	if err != nil {
		return nil, err
	}
	storedClusters, err := s.LoadAllClusters(ctx)
	if err != nil {
		return nil, err
	}
	storedListeners, err := s.LoadAllListeners(ctx)
	if err != nil {
		return nil, err
	}

	// Index the enabled names so each stored resource is classified in O(1).
	isLiveCluster := make(map[string]bool, len(liveClusters))
	for _, c := range liveClusters {
		isLiveCluster[c.GetName()] = true
	}
	var offClusters []*clusterv3.Cluster
	for _, c := range storedClusters {
		if !isLiveCluster[c.GetName()] {
			offClusters = append(offClusters, c)
		}
	}

	isLiveListener := make(map[string]bool, len(liveListeners))
	for _, l := range liveListeners {
		isLiveListener[l.GetName()] = true
	}
	var offListeners []*listenerv3.Listener
	for _, l := range storedListeners {
		if !isLiveListener[l.GetName()] {
			offListeners = append(offListeners, l)
		}
	}

	snap := &SnapshotConfig{
		EnabledClusters:   liveClusters,
		EnabledListeners:  liveListeners,
		DisabledClusters:  offClusters,
		DisabledListeners: offListeners,
	}
	return snap, nil
}

// SnapshotConfig groups the xDS resources produced by RebuildSnapshot and
// consumed by SaveSnapshot, split by their stored enabled flag.
type SnapshotConfig struct {
	// Enabled resources (for xDS serving). SaveSnapshot upserts these and
	// marks them enabled.
	EnabledClusters []*clusterv3.Cluster
	// Routes are no longer stored; the former EnabledRoutes field was removed.
	EnabledListeners []*listenerv3.Listener

	// Disabled resources (for display). SaveSnapshot does not write these.
	DisabledClusters []*clusterv3.Cluster
	// Routes are no longer stored; the former DisabledRoutes field was removed.
	DisabledListeners []*listenerv3.Listener
}

// EnableCluster sets the enabled flag of the cluster called name to enabled
// and refreshes its updated_at timestamp. If no cluster has that name, nothing
// is changed and no error is returned.
func (s *Storage) EnableCluster(ctx context.Context, name string, enabled bool) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `UPDATE clusters SET enabled = $1, updated_at = now() WHERE name = $2`
	default:
		stmt = `UPDATE clusters SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}

// EnableRoute toggles a route // REMOVED
// func (s *Storage) EnableRoute(ctx context.Context, name string, enabled bool) error {
// 	// ... (route logic removed)
// }

// EnableListener sets the enabled flag of the listener called name to enabled
// and refreshes its updated_at timestamp. If no listener has that name,
// nothing is changed and no error is returned.
func (s *Storage) EnableListener(ctx context.Context, name string, enabled bool) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `UPDATE listeners SET enabled = $1, updated_at = now() WHERE name = $2`
	default:
		stmt = `UPDATE listeners SET enabled = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, enabled, name); err != nil {
		return err
	}
	return nil
}

// storageExecer is the ExecContext method shared by *sql.DB and *sql.Tx. The
// write helpers below accept it so the same statements can run either on the
// connection pool or inside a caller's transaction.
type storageExecer interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// disableMissingResources sets enabled = false and refreshes updated_at on
// every row of table whose name is NOT in inputNames. An empty inputNames
// disables every row in the table. Only "clusters" and "listeners" are
// accepted. The statement runs directly on s.db.
func (s *Storage) disableMissingResources(ctx context.Context, table string, inputNames []string) error {
	return s.disableMissingResourcesWith(ctx, s.db, table, inputNames)
}

// disableMissingResourcesWith is disableMissingResources executed through ex.
func (s *Storage) disableMissingResourcesWith(ctx context.Context, ex storageExecer, table string, inputNames []string) error {
	// The table name is formatted into the SQL, so it must come from this
	// fixed allow-list.
	if table != "clusters" && table != "listeners" {
		return fmt.Errorf("logical delete (disable) is only supported for tables with an 'enabled' column (clusters, listeners)")
	}

	// Same boolean literal and clock function the rest of the file uses per
	// dialect.
	disabledValue, updateTime := "0", "CURRENT_TIMESTAMP"
	if s.driver == "postgres" {
		disabledValue, updateTime = "false", "now()"
	}

	whereClause, args := s.notInNamesClause(inputNames)
	query := fmt.Sprintf(`
        UPDATE %s
        SET enabled = %s, updated_at = %s
        %s`,
		table, disabledValue, updateTime, whereClause)

	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// deleteMissingResources physically deletes every row of table whose name is
// NOT in inputNames. An empty inputNames deletes every row in the table. Only
// "clusters" and "listeners" are accepted. The statement runs directly on s.db.
func (s *Storage) deleteMissingResources(ctx context.Context, table string, inputNames []string) error {
	return s.deleteMissingResourcesWith(ctx, s.db, table, inputNames)
}

// deleteMissingResourcesWith is deleteMissingResources executed through ex.
func (s *Storage) deleteMissingResourcesWith(ctx context.Context, ex storageExecer, table string, inputNames []string) error {
	// The table name is formatted into the SQL, so it must come from this
	// fixed allow-list.
	if table != "clusters" && table != "listeners" {
		return fmt.Errorf("physical delete is only supported for tables: clusters, listeners")
	}

	whereClause, args := s.notInNamesClause(inputNames)
	query := fmt.Sprintf(`
        DELETE FROM %s
        %s`,
		table, whereClause)

	_, err := ex.ExecContext(ctx, query, args...)
	return err
}

// notInNamesClause returns "WHERE name NOT IN (...)" with one driver-specific
// placeholder per name, together with the names as bind arguments. For an
// empty names slice it returns an empty clause and no arguments, so the
// statement it is appended to applies to the whole table.
func (s *Storage) notInNamesClause(names []string) (string, []interface{}) {
	if len(names) == 0 {
		return "", nil
	}
	placeholders := make([]string, len(names))
	args := make([]interface{}, len(names))
	for i, name := range names {
		placeholders[i] = s.placeholder(i + 1)
		args[i] = name
	}
	return fmt.Sprintf("WHERE name NOT IN (%s)", strings.Join(placeholders, ", ")), args
}

// upsertResourceWith inserts or updates the row called name in table through
// ex, storing data and setting enabled. It issues the same upsert as
// SaveCluster and SaveListener, but lets the caller choose the executor (for
// example a transaction).
func (s *Storage) upsertResourceWith(ctx context.Context, ex storageExecer, table, name string, data []byte) error {
	// The table name is formatted into the SQL, so it must come from this
	// fixed allow-list.
	if table != "clusters" && table != "listeners" {
		return fmt.Errorf("upsert is only supported for tables: clusters, listeners")
	}

	var query string
	if s.driver == "postgres" {
		query = fmt.Sprintf(`
            INSERT INTO %s (name, data, enabled, updated_at)
            VALUES (%s, %s, true, now())
            ON CONFLICT (name) DO UPDATE SET data = %s, enabled = true, updated_at = now()`,
			table, s.placeholder(1), s.placeholder(2), s.placeholder(2))
	} else {
		query = fmt.Sprintf(`
            INSERT INTO %s (name, data, enabled, updated_at)
            VALUES (?, ?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(name) DO UPDATE SET data=excluded.data, enabled=1, updated_at=CURRENT_TIMESTAMP`,
			table)
	}

	_, err := ex.ExecContext(ctx, query, name, string(data))
	return err
}

// SaveSnapshot writes cfg to the database inside a single transaction.
//
// Every cluster in cfg.EnabledClusters and every listener in
// cfg.EnabledListeners is upserted and marked enabled. cfg.DisabledClusters
// and cfg.DisabledListeners are not written. strategy then decides what
// happens to stored rows whose name is not among the enabled ones:
//
//   - DeleteNone (or any unrecognized value): they are left untouched.
//   - DeleteLogical: they are marked disabled.
//   - DeleteActual: they are deleted.
//
// With DeleteLogical or DeleteActual, an empty enabled list therefore affects
// every row of that table.
//
// All statements run on the same transaction, so on any error (or panic)
// nothing is applied. A failed commit is returned to the caller. A nil entry
// in either enabled list is rejected.
func (s *Storage) SaveSnapshot(ctx context.Context, cfg *SnapshotConfig, strategy DeleteStrategy) (err error) {
	if cfg == nil {
		return fmt.Errorf("SnapshotConfig is nil")
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback()
			panic(p)
		}
		if err != nil {
			// The error that aborted the snapshot is more useful to the caller
			// than a secondary rollback failure, so the latter is dropped.
			_ = tx.Rollback()
			return
		}
		if cerr := tx.Commit(); cerr != nil {
			err = fmt.Errorf("failed to commit transaction: %w", cerr)
		}
	}()

	// Upsert clusters on tx, remembering their names for the delete strategy.
	clusterNames := make([]string, 0, len(cfg.EnabledClusters))
	for i, c := range cfg.EnabledClusters {
		if c == nil {
			return fmt.Errorf("enabled cluster at index %d is nil", i)
		}
		data, merr := protojson.Marshal(c)
		if merr != nil {
			return fmt.Errorf("failed to save cluster %s: %w", c.GetName(), merr)
		}
		if err = s.upsertResourceWith(ctx, tx, "clusters", c.GetName(), data); err != nil {
			return fmt.Errorf("failed to save cluster %s: %w", c.GetName(), err)
		}
		clusterNames = append(clusterNames, c.GetName())
	}

	// Upsert listeners on tx, remembering their names for the delete strategy.
	listenerNames := make([]string, 0, len(cfg.EnabledListeners))
	for i, l := range cfg.EnabledListeners {
		if l == nil {
			return fmt.Errorf("enabled listener at index %d is nil", i)
		}
		data, merr := protojson.Marshal(l)
		if merr != nil {
			return fmt.Errorf("failed to save listener %s: %w", l.GetName(), merr)
		}
		if err = s.upsertResourceWith(ctx, tx, "listeners", l.GetName(), data); err != nil {
			return fmt.Errorf("failed to save listener %s: %w", l.GetName(), err)
		}
		listenerNames = append(listenerNames, l.GetName())
	}

	// Apply the deletion strategy to rows missing from the snapshot.
	var (
		pruneMissing func(context.Context, storageExecer, string, []string) error
		mode         string
	)
	switch strategy {
	case DeleteLogical:
		pruneMissing, mode = s.disableMissingResourcesWith, "logically"
	case DeleteActual:
		pruneMissing, mode = s.deleteMissingResourcesWith, "physically"
	default:
		// DeleteNone (and any unknown strategy) leaves missing rows alone.
		return nil
	}

	if err = pruneMissing(ctx, tx, "clusters", clusterNames); err != nil {
		return fmt.Errorf("failed to %s delete missing clusters: %w", mode, err)
	}
	if err = pruneMissing(ctx, tx, "listeners", listenerNames); err != nil {
		return fmt.Errorf("failed to %s delete missing listeners: %w", mode, err)
	}
	return nil
}

// RemoveListener permanently deletes the listener row called name. Removing a
// name that does not exist is not an error.
func (s *Storage) RemoveListener(ctx context.Context, name string) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `DELETE FROM listeners WHERE name = $1`
	default:
		stmt = `DELETE FROM listeners WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, name); err != nil {
		return err
	}
	return nil
}

// RemoveCluster permanently deletes the cluster row called name. Removing a
// name that does not exist is not an error.
func (s *Storage) RemoveCluster(ctx context.Context, name string) error {
	var stmt string
	switch s.driver {
	case "postgres":
		stmt = `DELETE FROM clusters WHERE name = $1`
	default:
		stmt = `DELETE FROM clusters WHERE name = ?`
	}
	if _, err := s.db.ExecContext(ctx, stmt, name); err != nil {
		return err
	}
	return nil
}