Newer
Older
EnvoyControlPlane / internal / snapshot.go
package internal

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"time"

	clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
	corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
	endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
	listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
	routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
	_ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/jwt_authn/v3"
	_ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/lua/v3"
	_ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/oauth2/v3"
	_ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/router/v3"
	_ "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/listener/tls_inspector/v3"
	secretv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"

	// Runtime discovery (RTDS) protos, blank-imported for their side effect of registering message types
	_ "github.com/envoyproxy/go-control-plane/envoy/service/runtime/v3"
	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
	cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
	resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/types/known/durationpb"
	yaml "gopkg.in/yaml.v3"
)

// SnapshotManager wraps a SnapshotCache and provides file loading/modifying
// helpers for a single node.
//
// Every method reads and/or writes the snapshot stored in Cache under NodeID.
// NOTE(review): the Add*/Remove* methods do a get-modify-set on the cache with
// no locking of their own, so concurrent callers can overwrite each other's
// changes — serialize calls if that matters.
type SnapshotManager struct {
	Cache  cachev3.SnapshotCache // snapshot cache read from and written to
	NodeID string                // node key passed to every GetSnapshot/SetSnapshot call
}

// NewSnapshotManager returns a SnapshotManager that operates on the snapshot
// held in cache under nodeID.
func NewSnapshotManager(cache cachev3.SnapshotCache, nodeID string) *SnapshotManager {
	sm := new(SnapshotManager)
	sm.Cache = cache
	sm.NodeID = nodeID
	return sm
}

// YamlResources is a helper struct to unmarshal the common Envoy YAML file
// structure: a top-level "resources" list whose entries are kept as raw
// yaml.Node values for later decoding.
//
// NOTE(review): LoadSnapshotFromFile walks a generic interface{} tree and does
// not use this type — confirm whether anything else still needs it.
type YamlResources struct {
	Resources []yaml.Node `yaml:"resources"`
}

// LoadSnapshotFromFile reads an Envoy resource file (YAML, or JSON parsed as
// YAML) and returns the xDS resources it contains, grouped by type. It does
// not touch the cache; pass the result to SetSnapshot to publish it.
//
// The parsed document is walked recursively. A mapping whose "@type" is the
// Cluster, RouteConfiguration, Listener, ClusterLoadAssignment or Secret type
// URL is re-encoded as JSON and decoded with protojson; its subtree is then not
// searched any further, because it already belongs to that resource. Mappings
// with any other "@type" (for example filter typed_configs) are not decoded,
// but their children are still searched. Runtime resources are recognised but
// have no decoder here, so they are reported and skipped rather than being
// silently dropped.
//
// The "@type" key is deleted in place from each mapping that gets decoded.
// Errors are returned for unreadable files, unparsable YAML/JSON, and
// resources that protojson rejects.
func (sm *SnapshotManager) LoadSnapshotFromFile(filePath string) (map[resourcev3.Type][]types.Resource, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}

	var raw interface{}
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("failed to unmarshal YAML/JSON file %s: %w", filePath, err)
	}

	resources := make(map[resourcev3.Type][]types.Resource)

	var walk func(node interface{}) error
	walk = func(node interface{}) error {
		switch v := node.(type) {
		case map[string]interface{}:
			if typStr, ok := v["@type"].(string); ok {
				typ := resourcev3.Type(typStr)

				// Choose an empty message for the supported top-level types;
				// kind is only used in error messages.
				var (
					msg  types.Resource
					kind string
				)
				switch typ {
				case resourcev3.ClusterType:
					msg, kind = &clusterv3.Cluster{}, "Cluster"
				case resourcev3.RouteType:
					msg, kind = &routev3.RouteConfiguration{}, "RouteConfiguration"
				case resourcev3.ListenerType:
					msg, kind = &listenerv3.Listener{}, "Listener"
				case resourcev3.EndpointType:
					msg, kind = &endpointv3.ClusterLoadAssignment{}, "ClusterLoadAssignment"
				case resourcev3.SecretType:
					msg, kind = &secretv3.Secret{}, "Secret"
				case resourcev3.RuntimeType:
					// No decoder is wired up for Runtime: say so explicitly
					// instead of claiming detection and losing the resource.
					fmt.Printf("Skipping unsupported resource type: %s\n", typStr)
					return nil
				default:
					// skip nested extension/filter types (handled inside parent)
					fmt.Printf("Skipping nested type: %s\n", typStr)
				}

				if msg != nil {
					// The type marker is not a field of the target message,
					// so remove it before handing the JSON to protojson.
					delete(v, "@type")

					jsonBytes, err := json.Marshal(v)
					if err != nil {
						return fmt.Errorf("failed to marshal resource node to JSON: %w", err)
					}

					fmt.Printf("Detected resource type: %s\n", typ)

					if err := protojson.Unmarshal(jsonBytes, msg); err != nil {
						return fmt.Errorf("failed to unmarshal %s: %w", kind, err)
					}
					resources[typ] = append(resources[typ], msg)

					// The whole subtree now lives inside msg; rescanning it
					// would only produce redundant "nested type" noise.
					return nil
				}
			}

			// recurse into children
			for _, child := range v {
				if err := walk(child); err != nil {
					return err
				}
			}

		case []interface{}:
			for _, item := range v {
				if err := walk(item); err != nil {
					return err
				}
			}
		}
		return nil
	}

	if err := walk(raw); err != nil {
		return nil, err
	}

	return resources, nil
}

// SetSnapshot builds a snapshot from resources, stamps it with version, and
// hands it to the cache for sm.NodeID. A snapshot that cannot be built is
// reported as a wrapped error and the cache is left untouched.
func (sm *SnapshotManager) SetSnapshot(ctx context.Context, version string, resources map[resourcev3.Type][]types.Resource) error {
	built, buildErr := cachev3.NewSnapshot(version, resources)
	if buildErr == nil {
		return sm.Cache.SetSnapshot(ctx, sm.NodeID, built)
	}
	return fmt.Errorf("failed to create snapshot: %w", buildErr)
}

// ---------------- Add / Remove / List ----------------

// AddCluster adds cluster to the node's snapshot and publishes the result.
//
// It delegates to AddResource with resourcev3.ClusterType so that every
// other resource type already in the snapshot (endpoints, secrets, runtimes,
// extension configs, as well as routes and listeners) is carried over rather
// than dropped; see AddResource for how the new snapshot is built.
func (sm *SnapshotManager) AddCluster(cluster *clusterv3.Cluster) error {
	return sm.AddResource(cluster, resourcev3.ClusterType)
}

// AddRoute adds route to the node's snapshot and publishes the result.
//
// It delegates to AddResource with resourcev3.RouteType so that every other
// resource type already in the snapshot (endpoints, secrets, runtimes,
// extension configs, as well as clusters and listeners) is carried over
// rather than dropped; see AddResource for how the new snapshot is built.
func (sm *SnapshotManager) AddRoute(route *routev3.RouteConfiguration) error {
	return sm.AddResource(route, resourcev3.RouteType)
}

// AddListener adds listener to the node's snapshot and publishes the result.
//
// It delegates to AddResource with resourcev3.ListenerType so that every
// other resource type already in the snapshot (endpoints, secrets, runtimes,
// extension configs, as well as clusters and routes) is carried over rather
// than dropped; see AddResource for how the new snapshot is built.
func (sm *SnapshotManager) AddListener(listener *listenerv3.Listener) error {
	return sm.AddResource(listener, resourcev3.ListenerType)
}

// RemoveCluster removes the cluster called name from the node's snapshot and
// publishes the result.
//
// It delegates to RemoveResource with resourcev3.ClusterType so that every
// other resource type already in the snapshot is carried over rather than
// dropped; see RemoveResource for the exact semantics.
func (sm *SnapshotManager) RemoveCluster(name string) error {
	return sm.RemoveResource(name, resourcev3.ClusterType)
}

// RemoveRoute removes the route configuration called name from the node's
// snapshot and publishes the result.
//
// It delegates to RemoveResource with resourcev3.RouteType so that every
// other resource type already in the snapshot is carried over rather than
// dropped; see RemoveResource for the exact semantics.
func (sm *SnapshotManager) RemoveRoute(name string) error {
	return sm.RemoveResource(name, resourcev3.RouteType)
}

// RemoveListener removes the listener called name from the node's snapshot
// and publishes the result.
//
// It delegates to RemoveResource with resourcev3.ListenerType so that every
// other resource type already in the snapshot is carried over rather than
// dropped; see RemoveResource for the exact semantics.
func (sm *SnapshotManager) RemoveListener(name string) error {
	return sm.RemoveResource(name, resourcev3.ListenerType)
}

// ---------------- List ----------------

// ListClusters returns every *clusterv3.Cluster in the node's current
// snapshot; entries of any other Go type are ignored. Order follows map
// iteration and is unspecified. Errors from GetSnapshot are returned as-is.
func (sm *SnapshotManager) ListClusters() ([]*clusterv3.Cluster, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	byName := snap.GetResources(string(resourcev3.ClusterType))
	out := make([]*clusterv3.Cluster, 0, len(byName))
	for _, res := range byName {
		cluster, isCluster := res.(*clusterv3.Cluster)
		if !isCluster {
			continue
		}
		out = append(out, cluster)
	}
	return out, nil
}

// ListRoutes returns every *routev3.RouteConfiguration in the node's current
// snapshot; entries of any other Go type are ignored. Order follows map
// iteration and is unspecified. Errors from GetSnapshot are returned as-is.
func (sm *SnapshotManager) ListRoutes() ([]*routev3.RouteConfiguration, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	byName := snap.GetResources(string(resourcev3.RouteType))
	out := make([]*routev3.RouteConfiguration, 0, len(byName))
	for _, res := range byName {
		routeCfg, isRoute := res.(*routev3.RouteConfiguration)
		if !isRoute {
			continue
		}
		out = append(out, routeCfg)
	}
	return out, nil
}

// ListListeners returns every *listenerv3.Listener in the node's current
// snapshot; entries of any other Go type are ignored. Order follows map
// iteration and is unspecified. Errors from GetSnapshot are returned as-is.
func (sm *SnapshotManager) ListListeners() ([]*listenerv3.Listener, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	byName := snap.GetResources(string(resourcev3.ListenerType))
	out := make([]*listenerv3.Listener, 0, len(byName))
	for _, res := range byName {
		lis, isListener := res.(*listenerv3.Listener)
		if !isListener {
			continue
		}
		out = append(out, lis)
	}
	return out, nil
}

// ---------------- Get ----------------

// GetCluster looks up the cluster called name in the node's current snapshot.
// It returns GetSnapshot's error unchanged, a "not found" error when no
// cluster has that name, or an error when the stored value is not a
// *clusterv3.Cluster.
func (sm *SnapshotManager) GetCluster(name string) (*clusterv3.Cluster, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	res, found := snap.GetResources(string(resourcev3.ClusterType))[name]
	switch cluster, isCluster := res.(*clusterv3.Cluster); {
	case !found:
		return nil, fmt.Errorf("cluster %s not found", name)
	case !isCluster:
		return nil, fmt.Errorf("resource %s found, but is not a Cluster", name)
	default:
		return cluster, nil
	}
}

// GetRoute looks up the route configuration called name in the node's
// current snapshot. It returns GetSnapshot's error unchanged, a "not found"
// error when no route has that name, or an error when the stored value is
// not a *routev3.RouteConfiguration.
func (sm *SnapshotManager) GetRoute(name string) (*routev3.RouteConfiguration, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	res, found := snap.GetResources(string(resourcev3.RouteType))[name]
	switch routeCfg, isRoute := res.(*routev3.RouteConfiguration); {
	case !found:
		return nil, fmt.Errorf("route %s not found", name)
	case !isRoute:
		return nil, fmt.Errorf("resource %s found, but is not a RouteConfiguration", name)
	default:
		return routeCfg, nil
	}
}

// GetListener looks up the listener called name in the node's current
// snapshot. It returns GetSnapshot's error unchanged, a "not found" error
// when no listener has that name, or an error when the stored value is not a
// *listenerv3.Listener.
func (sm *SnapshotManager) GetListener(name string) (*listenerv3.Listener, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	res, found := snap.GetResources(string(resourcev3.ListenerType))[name]
	switch lis, isListener := res.(*listenerv3.Listener); {
	case !found:
		return nil, fmt.Errorf("listener %s not found", name)
	case !isListener:
		return nil, fmt.Errorf("resource %s found, but is not a Listener", name)
	default:
		return lis, nil
	}
}

// ---------------- Save ----------------

// SaveSnapshotToFile writes the node's current snapshot to filePath as
// indented JSON in the layout LoadSnapshotFromFile reads back:
//
//	{"resources": [{"@type": "<xDS type URL>", ...resource fields...}, ...]}
//
// Each resource is encoded with protojson, the canonical protobuf JSON
// mapping; encoding/json does not understand oneofs, Any or well-known types
// such as Duration, so its output could not be decoded again. Cluster, route,
// listener, endpoint, secret, runtime and extension-config resources are
// saved, grouped by type in that order; the order within one type follows map
// iteration and is unspecified. The file is written with mode 0644.
func (sm *SnapshotManager) SaveSnapshotToFile(filePath string) error {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return err
	}

	savedTypes := []resourcev3.Type{
		resourcev3.ClusterType,
		resourcev3.RouteType,
		resourcev3.ListenerType,
		resourcev3.EndpointType,
		resourcev3.SecretType,
		resourcev3.RuntimeType,
		resourcev3.ExtensionConfigType,
	}

	// Non-nil so an empty snapshot is written as "resources": [] not null.
	entries := []map[string]json.RawMessage{}
	for _, typ := range savedTypes {
		typeURL, err := json.Marshal(string(typ))
		if err != nil {
			return err
		}
		for name, r := range snap.GetResources(string(typ)) {
			encoded, err := protojson.Marshal(r)
			if err != nil {
				return fmt.Errorf("failed to marshal %s %q: %w", typ, name, err)
			}
			// Split only the top level so every field value stays exactly as
			// protojson produced it (no float64 round-trip of numbers).
			var fields map[string]json.RawMessage
			if err := json.Unmarshal(encoded, &fields); err != nil {
				return fmt.Errorf("failed to re-read %s %q as JSON: %w", typ, name, err)
			}
			if fields == nil {
				fields = map[string]json.RawMessage{}
			}
			fields["@type"] = typeURL
			entries = append(entries, fields)
		}
	}

	data, err := json.MarshalIndent(map[string]interface{}{"resources": entries}, "", "  ")
	if err != nil {
		return err
	}

	return os.WriteFile(filePath, data, 0644)
}

// ---------------- Helpers ----------------

// mapToSlice flattens a name-keyed resource map into a slice of its values.
// The result is never nil (a nil or empty map yields an empty slice) and its
// order follows map iteration, so it is unspecified. The slice is pre-sized
// to len(m) to avoid repeated growth.
func mapToSlice(m map[string]types.Resource) []types.Resource {
	out := make([]types.Resource, 0, len(m))
	for _, r := range m {
		out = append(out, r)
	}
	return out
}

// NewCluster returns a minimal Cluster called name: discovery type EDS,
// round-robin load balancing and a 5s connect timeout.
//
// NOTE(review): no EdsClusterConfig is set, so the EDS config source is left
// empty — confirm the caller or Envoy setup supplies one.
func NewCluster(name string) *clusterv3.Cluster {
	c := &clusterv3.Cluster{Name: name}
	c.ConnectTimeout = durationpb.New(5 * time.Second)
	c.ClusterDiscoveryType = &clusterv3.Cluster_Type{Type: clusterv3.Cluster_EDS}
	c.LbPolicy = clusterv3.Cluster_ROUND_ROBIN
	return c
}

// NewRoute returns a RouteConfiguration called name with one virtual host
// ("vh-" + name) that matches every domain and forwards requests whose path
// starts with prefix to clusterName.
func NewRoute(name, clusterName, prefix string) *routev3.RouteConfiguration {
	forward := &routev3.Route_Route{
		Route: &routev3.RouteAction{
			ClusterSpecifier: &routev3.RouteAction_Cluster{Cluster: clusterName},
		},
	}
	match := &routev3.RouteMatch{
		PathSpecifier: &routev3.RouteMatch_Prefix{Prefix: prefix},
	}
	vhost := &routev3.VirtualHost{
		Name:    "vh-" + name,
		Domains: []string{"*"},
		Routes:  []*routev3.Route{{Match: match, Action: forward}},
	}
	return &routev3.RouteConfiguration{
		Name:         name,
		VirtualHosts: []*routev3.VirtualHost{vhost},
	}
}

// NewListener returns a Listener called name bound to TCP 0.0.0.0:port.
//
// NOTE(review): no filter chains are configured; confirm callers add them
// before pushing this listener to Envoy.
func NewListener(name string, port uint32) *listenerv3.Listener {
	sock := &corev3.SocketAddress{
		Protocol:      corev3.SocketAddress_TCP,
		Address:       "0.0.0.0",
		PortSpecifier: &corev3.SocketAddress_PortValue{PortValue: port},
	}
	return &listenerv3.Listener{
		Name: name,
		Address: &corev3.Address{
			Address: &corev3.Address_SocketAddress{SocketAddress: sock},
		},
	}
}

// ---------------- Generic Helpers for all xDS types ----------------

// AddResource adds resource, of xDS type typ, to the node's snapshot and
// publishes the result via SetSnapshot.
//
// Every resource of the managed types (Cluster, Route, Listener, Endpoint,
// Secret, Runtime, ExtensionConfig) already in the snapshot is carried over,
// except that an existing resource of the same type and name is replaced.
// If the node has no snapshot yet, the new snapshot starts empty. A nil
// resource or an unsupported type returns an error without touching the
// cache.
//
// The version string carries a nanosecond timestamp so that replacing a
// resource with an updated copy produces a new version; re-using the previous
// version string would make the update look already delivered to clients.
func (sm *SnapshotManager) AddResource(resource types.Resource, typ resourcev3.Type) error {
	if resource == nil {
		return fmt.Errorf("cannot add a nil %s resource", typ)
	}

	// GetSnapshot fails when the node has no snapshot yet; adding the first
	// resource is legitimate, so start from the empty per-type slices.
	current, _ := sm.managedResourcesByType()

	existing, supported := current[typ]
	if !supported {
		return fmt.Errorf("unsupported resource type: %s", typ)
	}

	// GetResourceName also handles ClusterLoadAssignment, which has no
	// GetName() method (its key is cluster_name).
	name := cachev3.GetResourceName(resource)
	current[typ] = append(withoutResourceNamed(existing, name), resource)

	return sm.SetSnapshot(context.TODO(), snapshotVersion("snap-generic-"+name), current)
}

// RemoveResource deletes every resource of type typ whose name is name from
// the node's snapshot and publishes the result via SetSnapshot. All other
// managed resources are carried over. Removing a name that is not present
// still publishes a new snapshot.
//
// Names are read with cachev3.GetResourceName, so endpoints
// (ClusterLoadAssignment, keyed by cluster_name) can be removed too.
// It returns an error if the node has no snapshot, if typ is not a managed
// type, or if the new snapshot cannot be built or stored.
func (sm *SnapshotManager) RemoveResource(name string, typ resourcev3.Type) error {
	current, err := sm.managedResourcesByType()
	if err != nil {
		return fmt.Errorf("cannot remove %s %s: %w", typ, name, err)
	}

	existing, supported := current[typ]
	if !supported {
		return fmt.Errorf("unsupported resource type: %s", typ)
	}
	current[typ] = withoutResourceNamed(existing, name)

	return sm.SetSnapshot(context.TODO(), snapshotVersion("snap-remove-generic-"+name), current)
}

// managedResourcesByType copies the node's current snapshot into one slice
// per managed xDS type. The map always has an (possibly empty) entry for
// every managed type; when GetSnapshot fails its error is returned and all
// slices are empty.
func (sm *SnapshotManager) managedResourcesByType() (map[resourcev3.Type][]types.Resource, error) {
	managed := []resourcev3.Type{
		resourcev3.ClusterType,
		resourcev3.RouteType,
		resourcev3.ListenerType,
		resourcev3.EndpointType,
		resourcev3.SecretType,
		resourcev3.RuntimeType,
		resourcev3.ExtensionConfigType,
	}

	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	out := make(map[resourcev3.Type][]types.Resource, len(managed))
	for _, typ := range managed {
		if err != nil || snap == nil {
			out[typ] = []types.Resource{}
			continue
		}
		out[typ] = mapToSlice(snap.GetResources(string(typ)))
	}
	return out, err
}

// withoutResourceNamed returns the resources whose name (as reported by
// cachev3.GetResourceName) differs from name, preserving order.
func withoutResourceNamed(resources []types.Resource, name string) []types.Resource {
	kept := make([]types.Resource, 0, len(resources))
	for _, r := range resources {
		if cachev3.GetResourceName(r) != name {
			kept = append(kept, r)
		}
	}
	return kept
}

// snapshotVersion appends a nanosecond timestamp to prefix so that successive
// snapshots get distinct version strings (barring two calls within the same
// clock tick).
func snapshotVersion(prefix string) string {
	return fmt.Sprintf("%s-%d", prefix, time.Now().UnixNano())
}

// ListResources returns every resource of type typ in the node's current
// snapshot, in unspecified order. Errors from GetSnapshot are returned as-is.
func (sm *SnapshotManager) ListResources(typ resourcev3.Type) ([]types.Resource, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err == nil {
		return mapToSlice(snap.GetResources(string(typ))), nil
	}
	return nil, err
}

// GetResource looks up the resource called name among the node's resources
// of type typ. It returns GetSnapshot's error unchanged, or a
// "<type> <name> not found" error when there is no such resource.
func (sm *SnapshotManager) GetResource(name string, typ resourcev3.Type) (types.Resource, error) {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	byName := snap.GetResources(string(typ))
	if found, ok := byName[name]; ok {
		return found, nil
	}
	return nil, fmt.Errorf("%s %s not found", typ, name)
}

// ---------------- Generic filter helper ----------------

// filterResourcesByName returns the resources whose name differs from name,
// preserving order. A resource's name comes from GetName(), or from
// GetClusterName() for types such as ClusterLoadAssignment that have no
// GetName(); resources exposing neither are always kept.
//
// The type parameter T is not used by the implementation; it is kept only so
// existing call sites such as filterResourcesByName[*clusterv3.Cluster] keep
// compiling.
func filterResourcesByName[T any](resources []types.Resource, name string) []types.Resource {
	filtered := make([]types.Resource, 0, len(resources))
	for _, r := range resources {
		switch named := r.(type) {
		case interface{ GetName() string }:
			if named.GetName() == name {
				continue
			}
		case interface{ GetClusterName() string }:
			// Endpoints (ClusterLoadAssignment) are keyed by cluster name.
			if named.GetClusterName() == name {
				continue
			}
		}
		filtered = append(filtered, r)
	}
	return filtered
}