Newer
Older
EnvoyControlPlane / internal / snapshot / resource_crud.go
package snapshot

import (
	"context"
	"fmt"
	"sort"
	"time"

	listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
	cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
	resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"

	internallog "envoy-control-plane/internal/log"
	"envoy-control-plane/internal/storage"
)

// AppendFilterChainToListener loads the named listener from the current snapshot, appends
// newFilterChain to its FilterChains, publishes a new snapshot containing the modified
// listener, and flushes the cache to the DB with storage.DeleteLogical.
//
// It returns an error if newFilterChain is nil, if the listener is missing from the cache
// or is not a *listenerv3.Listener, or if building, setting, or flushing the new snapshot
// fails.
//
// NOTE(review): the listener is modified in place before the new snapshot is set; it appears
// to be the same object held by the currently published snapshot, so a deep copy (e.g.
// proto.Clone) before editing may be safer — TODO confirm.
func (sm *SnapshotManager) AppendFilterChainToListener(ctx context.Context, listenerName string, newFilterChain *listenerv3.FilterChain) error {
	log := internallog.LogFromContext(ctx)

	// A nil chain would be appended as-is and would make the log call below
	// dereference a nil pointer, so reject it before touching the listener.
	if newFilterChain == nil {
		return fmt.Errorf("new filter chain is nil")
	}

	// 1. Get the current Listener from the cache.
	resource, err := sm.GetResourceFromCache(listenerName, resourcev3.ListenerType)
	if err != nil {
		return fmt.Errorf("failed to get listener '%s' from cache: %w", listenerName, err)
	}

	listener, ok := resource.(*listenerv3.Listener)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Listener type", listenerName)
	}

	// 2. Append the new FilterChain to the listener's list of filter chains.
	listener.FilterChains = append(listener.FilterChains, newFilterChain)
	log.Infof("Appended new filter chain (match: %v) to listener '%s'", newFilterChain.GetFilterChainMatch(), listenerName)

	// 3. Rebuild the full resource set from the current snapshot, swapping in the
	// modified listener.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}

	resources := sm.getAllResourcesFromSnapshot(snap)

	listenerList, ok := resources[resourcev3.ListenerType]
	if !ok {
		return fmt.Errorf("listener resource type not present in snapshot")
	}

	foundAndReplaced := false
	for i, res := range listenerList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == listenerName {
			listenerList[i] = listener
			foundAndReplaced = true
			break
		}
	}

	if !foundAndReplaced {
		return fmt.Errorf("failed to locate listener '%s' in current resource list for replacement", listenerName)
	}

	// 4. Create and publish the new snapshot. The nanosecond timestamp keeps the
	// version string unique across calls.
	version := fmt.Sprintf("listener-update-%s-%d", listenerName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}

	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	// 5. Persist the updated cache and surface any flush failure to the caller.
	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}

	log.Infof("Successfully updated listener '%s' in cache with new filter chain.", listenerName)

	return nil
}

// ServerNamesEqual reports whether a and b contain the same server names with the same
// multiplicities, ignoring order. server_names is a list in the Envoy API, and order
// shouldn't matter for a match.
//
// The inputs are never modified: the comparison sorts private copies. Callers pass the
// ServerNames slices of filter chains held by cached listeners, so sorting them in place
// would silently reorder live configuration.
func ServerNamesEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	// Sort copies so the caller's slices keep their original order.
	sortedA := append([]string(nil), a...)
	sortedB := append([]string(nil), b...)
	sort.Strings(sortedA)
	sort.Strings(sortedB)

	for i := range sortedA {
		if sortedA[i] != sortedB[i] {
			return false
		}
	}
	return true
}

// UpdateFilterChainOfListener replaces the filter chain(s) on the named listener whose
// FilterChainMatch.ServerNames equal those of newFilterChain (ignoring order), publishes a
// new snapshot containing the modified listener, and flushes the cache to the DB with
// storage.DeleteLogical.
//
// newFilterChain must be non-nil and must specify at least one server name, because the
// server names are the key used to locate the chain to replace. An error is returned if
// the listener cannot be loaded, if no existing chain matches, or if building, setting, or
// flushing the new snapshot fails.
//
// NOTE(review): the listener is modified in place before the new snapshot is set; it appears
// to be the same object held by the currently published snapshot, so a deep copy (e.g.
// proto.Clone) before editing may be safer — TODO confirm.
func (sm *SnapshotManager) UpdateFilterChainOfListener(ctx context.Context, listenerName string, newFilterChain *listenerv3.FilterChain) error {
	log := internallog.LogFromContext(ctx)

	// Validate the input before touching the cache.
	if newFilterChain == nil {
		return fmt.Errorf("new filter chain is nil")
	}

	// A chain without server names cannot be matched against a specific existing chain,
	// so a targeted replacement requires at least one.
	newServerNames := newFilterChain.GetFilterChainMatch().GetServerNames()
	if len(newServerNames) == 0 {
		return fmt.Errorf("new filter chain must specify at least one ServerName for targeted replacement")
	}

	// 1. Get the current Listener from the cache.
	resource, err := sm.GetResourceFromCache(listenerName, resourcev3.ListenerType)
	if err != nil {
		return fmt.Errorf("failed to get listener '%s' from cache: %w", listenerName, err)
	}

	listener, ok := resource.(*listenerv3.Listener)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Listener type", listenerName)
	}

	// 2. Build the new chain list, substituting newFilterChain for every chain whose
	// server names match (ServerNamesEqual compares them ignoring order).
	foundMatch := false
	updatedChains := make([]*listenerv3.FilterChain, 0, len(listener.FilterChains))
	for _, existingChain := range listener.FilterChains {
		if ServerNamesEqual(existingChain.GetFilterChainMatch().GetServerNames(), newServerNames) {
			updatedChains = append(updatedChains, newFilterChain)
			foundMatch = true
			log.Debugf("Replaced filter chain with match: %v in listener '%s'", newServerNames, listenerName)
			continue
		}
		updatedChains = append(updatedChains, existingChain)
	}

	// 3. Nothing to replace means the request targeted a chain that does not exist.
	if !foundMatch {
		return fmt.Errorf("no existing filter chain found on listener '%s' with matching server names: %v",
			listenerName, newServerNames)
	}

	// 4. Update the listener with the new slice of filter chains.
	listener.FilterChains = updatedChains

	// 5. Rebuild the full resource set from the current snapshot, swapping in the
	// modified listener.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}

	resources := sm.getAllResourcesFromSnapshot(snap)

	listenerList, ok := resources[resourcev3.ListenerType]
	if !ok {
		return fmt.Errorf("listener resource type not present in snapshot")
	}

	foundAndReplaced := false
	for i, res := range listenerList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == listenerName {
			listenerList[i] = listener
			foundAndReplaced = true
			break
		}
	}

	if !foundAndReplaced {
		// Safeguard: GetResourceFromCache found the listener, but the snapshot may have
		// changed between the two reads.
		return fmt.Errorf("failed to locate listener '%s' in current resource list for replacement", listenerName)
	}

	// 6. Create and publish the new snapshot, then persist it.
	version := fmt.Sprintf("listener-update-chain-%s-%d", listenerName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}

	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}

	log.Infof("Successfully updated filter chain (match: %v) on listener '%s'", newServerNames, listenerName)

	return nil
}

// RemoveFilterChainFromListener loads the named listener from the current snapshot, drops
// every FilterChain whose FilterChainMatch.ServerNames equal serverNames (ignoring order),
// publishes a new snapshot containing the modified listener, and flushes the cache to the
// DB with storage.DeleteLogical.
//
// The special value []string{"(default)"} additionally selects chains whose match declares
// no server names at all (logged here as the "default" filter chain).
//
// An error is returned if serverNames is empty, if the listener cannot be loaded, if no
// chain matches, or if building, setting, or flushing the new snapshot fails.
//
// NOTE(review): the listener is modified in place before the new snapshot is set; it appears
// to be the same object held by the currently published snapshot, so a deep copy (e.g.
// proto.Clone) before editing may be safer — TODO confirm.
func (sm *SnapshotManager) RemoveFilterChainFromListener(ctx context.Context, listenerName string, serverNames []string) error {
	log := internallog.LogFromContext(ctx)

	// 1. Validate input (server names are the matching key, consistent with
	// UpdateFilterChainOfListener) and get the current Listener from the cache.
	if len(serverNames) == 0 {
		return fmt.Errorf("target filter chain match must specify at least one ServerName for targeted removal")
	}

	resource, err := sm.GetResourceFromCache(listenerName, resourcev3.ListenerType)
	if err != nil {
		return fmt.Errorf("failed to get listener '%s' from cache: %w", listenerName, err)
	}

	listener, ok := resource.(*listenerv3.Listener)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Listener type", listenerName)
	}

	// 2. Keep every chain that does not match; matching chains are dropped.
	removeDefault := len(serverNames) == 1 && serverNames[0] == "(default)"
	foundMatch := false
	var updatedChains []*listenerv3.FilterChain

	for _, existingChain := range listener.FilterChains {
		existingServerNames := existingChain.GetFilterChainMatch().GetServerNames()
		if removeDefault && len(existingServerNames) == 0 {
			foundMatch = true
			log.Debugf("Removing default filter chain from listener '%s'", listenerName)
			continue
		}
		if ServerNamesEqual(existingServerNames, serverNames) {
			foundMatch = true
			log.Debugf("Removing filter chain with match: %v from listener '%s'", serverNames, listenerName)
			continue
		}

		updatedChains = append(updatedChains, existingChain)
	}

	// 3. Nothing removed means the request targeted a chain that does not exist.
	if !foundMatch {
		return fmt.Errorf("no existing filter chain found on listener '%s' with matching server names: %v",
			listenerName, serverNames)
	}

	// 4. Update the listener with the remaining filter chains.
	listener.FilterChains = updatedChains

	// 5. Rebuild the full resource set from the current snapshot, swapping in the
	// modified listener.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}

	resources := sm.getAllResourcesFromSnapshot(snap)

	listenerList, ok := resources[resourcev3.ListenerType]
	if !ok {
		return fmt.Errorf("listener resource type not present in snapshot")
	}

	foundAndReplaced := false
	for i, res := range listenerList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == listenerName {
			listenerList[i] = listener
			foundAndReplaced = true
			break
		}
	}

	if !foundAndReplaced {
		// Safeguard: GetResourceFromCache found the listener, but the snapshot may have
		// changed between the two reads.
		return fmt.Errorf("failed to locate listener '%s' in current resource list for replacement during removal", listenerName)
	}

	// 6. Create and publish the new snapshot, then persist it.
	version := fmt.Sprintf("listener-remove-chain-%s-%d", listenerName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}

	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}

	log.Infof("Successfully removed filter chain (match: %v) from listener '%s'", serverNames, listenerName)

	return nil
}

// AddResourceToSnapshot appends resource to the resources of type typ in the node's current
// snapshot and publishes the result as a new snapshot. It does not flush to the DB.
//
// Supported types are clusters, listeners, endpoints, secrets, and runtimes; any other type
// yields an error. The resource's name (GetName(), or GetClusterName() for endpoint
// assignments) is embedded in the snapshot version together with a nanosecond timestamp.
func (sm *SnapshotManager) AddResourceToSnapshot(resource types.Resource, typ resourcev3.Type) error {
	// Validate everything up front so nothing is built for a request that will fail.
	switch typ {
	case resourcev3.ClusterType, resourcev3.ListenerType,
		resourcev3.EndpointType, resourcev3.SecretType, resourcev3.RuntimeType:
	default:
		return fmt.Errorf("unsupported resource type: %s", typ)
	}

	var name string
	switch r := resource.(type) {
	case interface{ GetName() string }:
		name = r.GetName()
	case interface{ GetClusterName() string }:
		// Endpoint assignments (ClusterLoadAssignment) are named by cluster and expose
		// no GetName(); without this branch EndpointType could never be added.
		name = r.GetClusterName()
	default:
		return fmt.Errorf("resource of type %s does not implement GetName()", typ)
	}

	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot from cache: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)
	resources[typ] = append(resources[typ], resource)

	// A nanosecond timestamp keeps versions distinct even when same-named resources of
	// different types are added within the same second.
	version := fmt.Sprintf("snap-generic-%s-%d", name, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}
	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap)
}

// RemoveResource removes the resource called name of type typ.
//
// With storage.DeleteActual the resource must already be absent from the cache (for
// example after an earlier logical delete); clusters and listeners are then deleted from
// the DB and the cache is left unchanged, and other types are rejected. With any other
// strategy the resource is removed from the cache, a new snapshot is published, and the
// cache is flushed to the DB using that strategy.
func (sm *SnapshotManager) RemoveResource(name string, typ resourcev3.Type, strategy storage.DeleteStrategy) error {
	// A missing snapshot used to be ignored here and led to a nil-interface panic below.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot from cache: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)

	// Filter the target type, noting whether the named resource was present.
	resourceFound := false
	if targetResources, ok := resources[typ]; ok {
		resources[typ], resourceFound = filterAndCheckResourcesByName(targetResources, name)
	}

	if strategy == storage.DeleteActual {
		if resourceFound {
			return fmt.Errorf("actual delete requested but resource %s of type %s still exists in cache", name, typ)
		}
		switch typ {
		case resourcev3.ClusterType:
			if err := sm.DB.RemoveCluster(context.TODO(), name); err != nil {
				return fmt.Errorf("failed to delete cluster %s from DB: %w", name, err)
			}
			return nil
		case resourcev3.ListenerType:
			if err := sm.DB.RemoveListener(context.TODO(), name); err != nil {
				return fmt.Errorf("failed to delete listener %s from DB: %w", name, err)
			}
			return nil
		}
		return fmt.Errorf("actual delete not supported for resource type: %s", typ)
	}

	if !resourceFound {
		return fmt.Errorf("resource %s of type %s not found in cache", name, typ)
	}

	// The timestamp keeps versions distinct when same-named resources of different types
	// are removed back to back.
	version := fmt.Sprintf("snap-remove-generic-%s-%d", name, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}

	if err := sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set snapshot: %w", err)
	}

	if err := sm.FlushCacheToDB(context.TODO(), strategy); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	return nil
}

// GetResourceFromCache looks up the resource called name among the resources of type typ
// in the node's current snapshot. It returns the snapshot lookup error unchanged, or a
// "not found" error when the snapshot holds no resource with that name.
func (sm *SnapshotManager) GetResourceFromCache(name string, typ resourcev3.Type) (types.Resource, error) {
	current, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}

	byName := current.GetResources(string(typ))
	if found, exists := byName[name]; exists {
		// No type check here: the caller is trusted to pass the matching typ.
		return found, nil
	}
	return nil, fmt.Errorf("%s resource %s not found in cache", typ, name)
}

// getAllResourcesFromSnapshot copies the resources held by snap into a map keyed by type
// URL, suitable for passing to cachev3.NewSnapshot.
//
// Clusters and listeners are always present in the result (possibly as empty slices),
// because callers such as the filter chain helpers require the listener entry to exist.
// Every other known type is carried over whenever the snapshot currently holds resources
// of it. Callers rebuild the node's snapshot from this map, so a type omitted here would
// be silently dropped from the next snapshot they publish.
func (sm *SnapshotManager) getAllResourcesFromSnapshot(snap cachev3.ResourceSnapshot) map[resourcev3.Type][]types.Resource {
	resources := map[resourcev3.Type][]types.Resource{
		resourcev3.ClusterType:  mapToSlice(snap.GetResources(string(resourcev3.ClusterType))),
		resourcev3.ListenerType: mapToSlice(snap.GetResources(string(resourcev3.ListenerType))),
	}

	// Extend this list if the control plane starts serving additional xDS types.
	for _, typ := range []resourcev3.Type{
		resourcev3.EndpointType,
		resourcev3.RouteType,
		resourcev3.ScopedRouteType,
		resourcev3.SecretType,
		resourcev3.RuntimeType,
		resourcev3.ExtensionConfigType,
	} {
		if existing := snap.GetResources(string(typ)); len(existing) > 0 {
			resources[typ] = mapToSlice(existing)
		}
	}
	return resources
}

// mapToSlice returns the values of m as a slice, in Go's unspecified map iteration order.
// The result is never nil, even when m is empty or nil.
func mapToSlice(m map[string]types.Resource) []types.Resource {
	values := make([]types.Resource, len(m))
	i := 0
	for _, res := range m {
		values[i] = res
		i++
	}
	return values
}

// filterAndCheckResourcesByName returns a new slice containing every element of resources
// except those whose GetName() equals name, plus whether any element was dropped.
// Resources that do not implement GetName() are always kept. The input slice is not
// modified, and the returned slice is never nil.
func filterAndCheckResourcesByName(resources []types.Resource, name string) ([]types.Resource, bool) {
	kept := make([]types.Resource, 0, len(resources))
	removed := false
	for _, res := range resources {
		namer, hasName := res.(interface{ GetName() string })
		if hasName && namer.GetName() == name {
			removed = true
			continue
		}
		kept = append(kept, res)
	}
	return kept, removed
}