Newer
Older
EnvoyControlPlane / internal / snapshot.go
package internal

import (
	"context"
	"encoding/json"
	"os"
	"time"

	clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
	routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
	cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
	resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
	"google.golang.org/protobuf/types/known/durationpb"
)

// SnapshotManager wraps a SnapshotCache for a single node and provides helpers
// that load that node's snapshot from a file, add or remove clusters and route
// configurations, and save the snapshot back to a file.
type SnapshotManager struct {
	Cache  cachev3.SnapshotCache // cache that every method reads from and writes to
	NodeID string                // node whose snapshot the methods operate on
}

// NewSnapshotManager returns a SnapshotManager bound to the given cache and
// node ID. Both values are stored as-is; no snapshot is read or written.
func NewSnapshotManager(cache cachev3.SnapshotCache, nodeID string) *SnapshotManager {
	mgr := new(SnapshotManager)
	mgr.Cache = cache
	mgr.NodeID = nodeID
	return mgr
}

// LoadSnapshotFromFile reads the JSON file at filePath and installs its
// contents as the snapshot for sm.NodeID.
//
// The file must contain a JSON object whose keys are resource type strings
// (resourcev3.Type values) and whose values are arrays of resources. Only
// entries under resourcev3.ClusterType and resourcev3.RouteType are decoded;
// entries under any other key are skipped.
//
// The first error from reading the file, decoding it, building the snapshot,
// or storing it in the cache is returned unchanged.
//
// NOTE(review): resources are decoded with encoding/json, not a
// protobuf-aware decoder. Protobuf oneof and well-known-type fields probably
// do not decode this way — confirm, and consider switching to protojson.
func (sm *SnapshotManager) LoadSnapshotFromFile(filePath string) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return err
	}

	var raw map[string][]json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	resources := make(map[resourcev3.Type][]types.Resource, len(raw))
	for typStr, entries := range raw {
		typ := resourcev3.Type(typStr)
		for _, entry := range entries {
			var res types.Resource
			switch typ {
			case resourcev3.ClusterType:
				res = &clusterv3.Cluster{}
			case resourcev3.RouteType:
				res = &routev3.RouteConfiguration{}
			default:
				// Unknown resource types are skipped.
				continue
			}
			if err := json.Unmarshal(entry, res); err != nil {
				return err
			}
			resources[typ] = append(resources[typ], res)
		}
	}

	// Give every load its own version string. The snapshot cache is understood
	// to compare versions when deciding whether to push to a node (see the
	// go-control-plane SnapshotCache docs), so reusing a constant version would
	// hide the changes made by a reload.
	version := "snap-from-file-" + time.Now().UTC().Format(time.RFC3339Nano)

	snap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return err
	}
	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, snap)
}

// AddCluster appends cluster to the clusters of the current snapshot for
// sm.NodeID and stores the result as a new snapshot. Existing route
// configurations are carried over unchanged.
//
// If the cache returns an error when fetching the current snapshot, that error
// is not reported; the new snapshot then contains only the given cluster.
// Errors from building or storing the new snapshot are returned.
func (sm *SnapshotManager) AddCluster(cluster *clusterv3.Cluster) error {
	var clusters, routes []types.Resource

	// A lookup failure is treated as "no snapshot yet": start from empty lists.
	if snap, err := sm.Cache.GetSnapshot(sm.NodeID); err == nil {
		clusters = mapToSlice(snap.GetResources(string(resourcev3.ClusterType)))
		routes = mapToSlice(snap.GetResources(string(resourcev3.RouteType)))
	}

	clusters = append(clusters, cluster)

	// The timestamp keeps the version unique across calls. The snapshot cache
	// is understood to compare versions when deciding whether to push (see the
	// go-control-plane SnapshotCache docs), and "snap-<name>" alone repeats
	// when a name is reused or shared with a route configuration.
	version := "snap-" + cluster.GetName() + "-" + time.Now().UTC().Format(time.RFC3339Nano)

	newSnap, err := cachev3.NewSnapshot(
		version,
		map[resourcev3.Type][]types.Resource{
			resourcev3.ClusterType: clusters,
			resourcev3.RouteType:   routes,
		},
	)
	if err != nil {
		return err
	}

	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap)
}

// AddRoute appends route to the route configurations of the current snapshot
// for sm.NodeID and stores the result as a new snapshot. Existing clusters are
// carried over unchanged.
//
// If the cache returns an error when fetching the current snapshot, that error
// is not reported; the new snapshot then contains only the given route
// configuration. Errors from building or storing the new snapshot are
// returned.
func (sm *SnapshotManager) AddRoute(route *routev3.RouteConfiguration) error {
	var clusters, routes []types.Resource

	// A lookup failure is treated as "no snapshot yet": start from empty lists.
	if snap, err := sm.Cache.GetSnapshot(sm.NodeID); err == nil {
		clusters = mapToSlice(snap.GetResources(string(resourcev3.ClusterType)))
		routes = mapToSlice(snap.GetResources(string(resourcev3.RouteType)))
	}

	routes = append(routes, route)

	// The timestamp keeps the version unique across calls. The snapshot cache
	// is understood to compare versions when deciding whether to push (see the
	// go-control-plane SnapshotCache docs), and "snap-<name>" alone repeats
	// when a name is reused or shared with a cluster.
	version := "snap-" + route.GetName() + "-" + time.Now().UTC().Format(time.RFC3339Nano)

	newSnap, err := cachev3.NewSnapshot(
		version,
		map[resourcev3.Type][]types.Resource{
			resourcev3.ClusterType: clusters,
			resourcev3.RouteType:   routes,
		},
	)
	if err != nil {
		return err
	}

	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap)
}

// RemoveRoute stores a new snapshot for sm.NodeID that omits the route
// configuration called name. Clusters and all other route resources are
// carried over unchanged. A new snapshot is stored even when no route
// configuration matches.
//
// It returns the cache's error if the current snapshot cannot be fetched, and
// any error from building or storing the new snapshot.
func (sm *SnapshotManager) RemoveRoute(name string) error {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return err
	}

	clusters := mapToSlice(snap.GetResources(string(resourcev3.ClusterType)))

	existing := snap.GetResources(string(resourcev3.RouteType))
	routes := make([]types.Resource, 0, len(existing))
	for _, r := range existing {
		// Drop only the matching route configuration. Anything else stored
		// under the route type is kept rather than silently discarded.
		if rt, ok := r.(*routev3.RouteConfiguration); ok && rt.GetName() == name {
			continue
		}
		routes = append(routes, r)
	}

	// The timestamp keeps the version unique across calls. The snapshot cache
	// is understood to compare versions when deciding whether to push (see the
	// go-control-plane SnapshotCache docs).
	version := "snap-remove-route-" + name + "-" + time.Now().UTC().Format(time.RFC3339Nano)

	newSnap, err := cachev3.NewSnapshot(
		version,
		map[resourcev3.Type][]types.Resource{
			resourcev3.ClusterType: clusters,
			resourcev3.RouteType:   routes,
		},
	)
	if err != nil {
		return err
	}

	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap)
}

// RemoveCluster stores a new snapshot for sm.NodeID that omits the cluster
// called name. Route configurations and all other cluster resources are
// carried over unchanged. A new snapshot is stored even when no cluster
// matches.
//
// It returns the cache's error if the current snapshot cannot be fetched, and
// any error from building or storing the new snapshot.
func (sm *SnapshotManager) RemoveCluster(name string) error {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return err
	}

	routes := mapToSlice(snap.GetResources(string(resourcev3.RouteType)))

	existing := snap.GetResources(string(resourcev3.ClusterType))
	clusters := make([]types.Resource, 0, len(existing))
	for _, r := range existing {
		// Drop only the matching cluster. Anything else stored under the
		// cluster type is kept rather than silently discarded.
		if c, ok := r.(*clusterv3.Cluster); ok && c.GetName() == name {
			continue
		}
		clusters = append(clusters, r)
	}

	// The timestamp keeps the version unique across calls. The snapshot cache
	// is understood to compare versions when deciding whether to push (see the
	// go-control-plane SnapshotCache docs).
	version := "snap-remove-" + name + "-" + time.Now().UTC().Format(time.RFC3339Nano)

	newSnap, err := cachev3.NewSnapshot(
		version,
		map[resourcev3.Type][]types.Resource{
			resourcev3.ClusterType: clusters,
			resourcev3.RouteType:   routes,
		},
	)
	if err != nil {
		return err
	}

	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap)
}

// SaveSnapshotToFile writes the clusters and route configurations of the
// current snapshot for sm.NodeID to filePath as indented JSON. The output is
// an object keyed by resource type string; a key appears only when at least
// one resource of that type exists. The file is written with mode 0644.
//
// It returns the cache's error if the snapshot cannot be fetched, or the
// error from JSON encoding or from writing the file.
//
// NOTE(review): resources are encoded with encoding/json, not a
// protobuf-aware encoder, and their order follows map iteration. Confirm
// the output can be read back by LoadSnapshotFromFile; protojson may be the
// better fit.
func (sm *SnapshotManager) SaveSnapshotToFile(filePath string) error {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return err
	}

	clusterKey := string(resourcev3.ClusterType)
	routeKey := string(resourcev3.RouteType)
	out := map[string][]interface{}{}

	for _, res := range snap.GetResources(clusterKey) {
		c, ok := res.(*clusterv3.Cluster)
		if !ok {
			continue
		}
		out[clusterKey] = append(out[clusterKey], c)
	}

	for _, res := range snap.GetResources(routeKey) {
		rt, ok := res.(*routev3.RouteConfiguration)
		if !ok {
			continue
		}
		out[routeKey] = append(out[routeKey], rt)
	}

	encoded, err := json.MarshalIndent(out, "", "  ")
	if err != nil {
		return err
	}

	return os.WriteFile(filePath, encoded, 0644)
}

// ----------------- Helpers -----------------

// mapToSlice flattens a name-keyed resource map into a slice of its values.
// Element order follows Go map iteration and is therefore unspecified. A nil
// or empty map yields an empty, non-nil slice.
func mapToSlice(m map[string]types.Resource) []types.Resource {
	values := make([]types.Resource, len(m))
	i := 0
	for _, res := range m {
		values[i] = res
		i++
	}
	return values
}

// NewCluster returns a cluster with the given name, a 5-second connect
// timeout, EDS as its discovery type, and the round-robin load-balancing
// policy.
func NewCluster(name string) *clusterv3.Cluster {
	cluster := new(clusterv3.Cluster)
	cluster.Name = name
	cluster.ConnectTimeout = durationpb.New(5 * time.Second)
	cluster.ClusterDiscoveryType = &clusterv3.Cluster_Type{Type: clusterv3.Cluster_EDS}
	cluster.LbPolicy = clusterv3.Cluster_ROUND_ROBIN
	return cluster
}

// NewRoute returns a route configuration called name with a single virtual
// host, "vh-"+name, that matches every domain ("*"). The virtual host holds
// one route that sends requests whose path starts with prefix to the cluster
// called clusterName.
func NewRoute(name, clusterName, prefix string) *routev3.RouteConfiguration {
	// Match on the path prefix.
	match := &routev3.RouteMatch{
		PathSpecifier: &routev3.RouteMatch_Prefix{Prefix: prefix},
	}

	// Forward matching requests to the named cluster.
	action := &routev3.Route_Route{
		Route: &routev3.RouteAction{
			ClusterSpecifier: &routev3.RouteAction_Cluster{Cluster: clusterName},
		},
	}

	vhost := &routev3.VirtualHost{
		Name:    "vh-" + name,
		Domains: []string{"*"},
		Routes:  []*routev3.Route{{Match: match, Action: action}},
	}

	return &routev3.RouteConfiguration{
		Name:         name,
		VirtualHosts: []*routev3.VirtualHost{vhost},
	}
}