package snapshot
import (
"context"
"fmt"
"sort"
"strings"
"time"
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
secretv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3"
"github.com/envoyproxy/go-control-plane/pkg/cache/types"
cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
"github.com/go-acme/lego/v4/log"
internallog "envoy-control-plane/internal/log"
internalcertapi "envoy-control-plane/internal/pkg/cert/api"
"envoy-control-plane/internal/pkg/storage"
)
// AppendFilterChainToListener adds newFilterChain to the listener named
// listenerName and publishes the result as a new snapshot.
//
// If the listener already has a filter chain with the same name, it is replaced
// when upsert is true; otherwise an error is returned. When no chain with that
// name exists the new chain is appended. After the snapshot is set, the cache
// is flushed to the DB with the DeleteLogical strategy.
//
// An error is returned when newFilterChain is nil, the listener cannot be
// found in the cache or the current snapshot, the snapshot cannot be built or
// set, or the flush to the DB fails.
//
// NOTE(review): the listener returned by GetResourceFromCache is modified in
// place before the new snapshot is set, so a failure in a later step leaves the
// modified object referenced by the existing snapshot.
func (sm *SnapshotManager) AppendFilterChainToListener(ctx context.Context, listenerName string, newFilterChain *listenerv3.FilterChain, upsert bool) error {
	log := internallog.LogFromContext(ctx)

	// Guard against a nil chain: its fields are read directly below.
	if newFilterChain == nil {
		return fmt.Errorf("new filter chain is nil")
	}

	// 1. Get the current Listener from the cache.
	resource, err := sm.GetResourceFromCache(listenerName, resourcev3.ListenerType)
	if err != nil {
		return fmt.Errorf("failed to get listener '%s' from cache: %w", listenerName, err)
	}
	listener, ok := resource.(*listenerv3.Listener)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Listener type", listenerName)
	}

	// 2. Replace the chain with the same name (upsert) or append a new one.
	replaced := false
	for i, fc := range listener.FilterChains {
		// GetName is nil-safe, so a nil entry in the list cannot panic here.
		if fc.GetName() != newFilterChain.Name {
			continue
		}
		if !upsert {
			return fmt.Errorf("filter chain with name '%s' already exists in listener '%s' and upsert is false", newFilterChain.Name, listenerName)
		}
		listener.FilterChains[i] = newFilterChain
		replaced = true
		log.Infof("Replaced new filter chain (match: %v) to listener '%s'", newFilterChain.FilterChainMatch, listenerName)
		break
	}
	if !replaced {
		listener.FilterChains = append(listener.FilterChains, newFilterChain)
		log.Infof("Appended new filter chain (match: %v) to listener '%s'", newFilterChain.FilterChainMatch, listenerName)
	}

	// 3. Rebuild the resource set from the current snapshot and swap in the
	// modified listener.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)
	listenerList, ok := resources[resourcev3.ListenerType]
	if !ok {
		return fmt.Errorf("listener resource type not present in snapshot")
	}
	foundAndReplaced := false
	for i, res := range listenerList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == listenerName {
			listenerList[i] = listener
			foundAndReplaced = true
			break
		}
	}
	if !foundAndReplaced {
		return fmt.Errorf("failed to locate listener '%s' in current resource list for replacement", listenerName)
	}

	// 4. Create and set the new snapshot; the nanosecond timestamp keeps the
	// version unique per update.
	version := fmt.Sprintf("listener-update-%s-%d", listenerName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}
	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	// 5. Persist the updated cache. The error is surfaced instead of being
	// dropped, matching RemoveResource.
	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	log.Infof("Successfully updated listener '%s' in cache with new filter chain.", listenerName)
	return nil
}
// ServerNamesEqual reports whether a and b contain the same server names with
// the same multiplicities, ignoring order. server_names is a list in the Envoy
// API, and order should not matter for a match.
//
// Neither input slice is modified: the comparison sorts private copies, so
// callers can safely pass slices that belong to cached resources.
func ServerNamesEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	// Sort copies rather than the callers' slices; sorting in place would
	// silently reorder the server names of the objects being compared.
	sortedA := append([]string(nil), a...)
	sortedB := append([]string(nil), b...)
	sort.Strings(sortedA)
	sort.Strings(sortedB)
	for i := range sortedA {
		if sortedA[i] != sortedB[i] {
			return false
		}
	}
	return true
}
// UpdateFilterChainOfListener replaces, on the listener named listenerName,
// every filter chain whose FilterChainMatch server names equal those of
// newFilterChain (order-insensitive, via ServerNamesEqual), then publishes the
// result as a new snapshot and flushes the cache to the DB with the
// DeleteLogical strategy.
//
// newFilterChain must be non-nil and must specify at least one server name;
// chains without server names are not eligible for this targeted replacement.
// An error is returned when no existing chain matches, when the listener
// cannot be found in the cache or the current snapshot, when the snapshot
// cannot be built or set, or when the flush to the DB fails.
func (sm *SnapshotManager) UpdateFilterChainOfListener(ctx context.Context, listenerName string, newFilterChain *listenerv3.FilterChain) error {
	log := internallog.LogFromContext(ctx)

	if newFilterChain == nil {
		return fmt.Errorf("new filter chain is nil")
	}

	// 1. Get the current Listener from the cache.
	resource, err := sm.GetResourceFromCache(listenerName, resourcev3.ListenerType)
	if err != nil {
		return fmt.Errorf("failed to get listener '%s' from cache: %w", listenerName, err)
	}
	listener, ok := resource.(*listenerv3.Listener)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Listener type", listenerName)
	}

	// The server names of the new chain are the key used to find its
	// predecessor. An empty list would denote a default chain, which this
	// function deliberately refuses to target.
	newServerNames := newFilterChain.GetFilterChainMatch().GetServerNames()
	if len(newServerNames) == 0 {
		return fmt.Errorf("new filter chain must specify at least one ServerName for targeted replacement")
	}

	// 2. Build the updated chain list, substituting the new chain for each match.
	// NOTE(review): ServerNamesEqual is handed the chains' own slices, so it
	// must not reorder its arguments.
	foundMatch := false
	updatedChains := make([]*listenerv3.FilterChain, 0, len(listener.FilterChains))
	for _, existingChain := range listener.FilterChains {
		existingServerNames := existingChain.GetFilterChainMatch().GetServerNames()
		if ServerNamesEqual(existingServerNames, newServerNames) {
			updatedChains = append(updatedChains, newFilterChain)
			foundMatch = true
			log.Debugf("Replaced filter chain with match: %v in listener '%s'", newServerNames, listenerName)
			continue
		}
		updatedChains = append(updatedChains, existingChain)
	}

	// 3. Nothing to update if no chain matched.
	if !foundMatch {
		return fmt.Errorf("no existing filter chain found on listener '%s' with matching server names: %v",
			listenerName, newServerNames)
	}

	// 4. Install the new chain list on the listener.
	listener.FilterChains = updatedChains

	// 5. Rebuild the resource set from the current snapshot and swap in the
	// modified listener.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)
	listenerList, ok := resources[resourcev3.ListenerType]
	if !ok {
		return fmt.Errorf("listener resource type not present in snapshot")
	}
	foundAndReplaced := false
	for i, res := range listenerList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == listenerName {
			listenerList[i] = listener
			foundAndReplaced = true
			break
		}
	}
	if !foundAndReplaced {
		// Not expected when GetResourceFromCache succeeded; kept as a safeguard.
		return fmt.Errorf("failed to locate listener '%s' in current resource list for replacement", listenerName)
	}

	// 6. Create and set the new snapshot.
	version := fmt.Sprintf("listener-update-chain-%s-%d", listenerName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}
	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	// 7. Persist the updated cache. The error is surfaced instead of being
	// dropped, matching RemoveResource.
	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	log.Infof("Successfully updated filter chain (match: %v) on listener '%s'", newServerNames, listenerName)
	return nil
}
// RemoveFilterChainFromListener removes from the listener named listenerName
// every filter chain whose FilterChainMatch server names equal serverNames
// (order-insensitive, via ServerNamesEqual), publishes the result as a new
// snapshot, and flushes the cache to the DB with the DeleteLogical strategy.
//
// As a special case, serverNames == ["(default)"] also removes every chain
// that has no server names at all.
//
// An error is returned when serverNames is empty, when no chain matches, when
// the listener cannot be found in the cache or the current snapshot, when the
// snapshot cannot be built or set, or when the flush to the DB fails.
func (sm *SnapshotManager) RemoveFilterChainFromListener(ctx context.Context, listenerName string, serverNames []string) error {
	log := internallog.LogFromContext(ctx)

	// 1. Validate input and get the current Listener from the cache.
	if len(serverNames) == 0 {
		return fmt.Errorf("target filter chain match must specify at least one ServerName for targeted removal")
	}
	resource, err := sm.GetResourceFromCache(listenerName, resourcev3.ListenerType)
	if err != nil {
		return fmt.Errorf("failed to get listener '%s' from cache: %w", listenerName, err)
	}
	listener, ok := resource.(*listenerv3.Listener)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Listener type", listenerName)
	}

	// The "(default)" sentinel targets chains without server names; the test
	// does not depend on the chain being inspected, so evaluate it once.
	removeDefault := len(serverNames) == 1 && serverNames[0] == "(default)"

	// 2. Keep every chain that does not match; matching chains are dropped.
	foundMatch := false
	updatedChains := make([]*listenerv3.FilterChain, 0, len(listener.FilterChains))
	for _, existingChain := range listener.FilterChains {
		existingServerNames := existingChain.GetFilterChainMatch().GetServerNames()
		if removeDefault && len(existingServerNames) == 0 {
			foundMatch = true
			log.Debugf("Removing default filter chain from listener '%s'", listenerName)
			continue
		}
		if ServerNamesEqual(existingServerNames, serverNames) {
			foundMatch = true
			log.Debugf("Removing filter chain with match: %v from listener '%s'", serverNames, listenerName)
			continue
		}
		updatedChains = append(updatedChains, existingChain)
	}

	// 3. Nothing to remove if no chain matched.
	if !foundMatch {
		return fmt.Errorf("no existing filter chain found on listener '%s' with matching server names: %v",
			listenerName, serverNames)
	}

	// 4. Install the reduced chain list on the listener.
	listener.FilterChains = updatedChains

	// 5. Rebuild the resource set from the current snapshot and swap in the
	// modified listener.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)
	listenerList, ok := resources[resourcev3.ListenerType]
	if !ok {
		return fmt.Errorf("listener resource type not present in snapshot")
	}
	foundAndReplaced := false
	for i, res := range listenerList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == listenerName {
			listenerList[i] = listener
			foundAndReplaced = true
			break
		}
	}
	if !foundAndReplaced {
		// Not expected when GetResourceFromCache succeeded; kept as a safeguard.
		return fmt.Errorf("failed to locate listener '%s' in current resource list for replacement during removal", listenerName)
	}

	// 6. Create and set the new snapshot.
	version := fmt.Sprintf("listener-remove-chain-%s-%d", listenerName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}
	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	// 7. Persist the updated cache. The error is surfaced instead of being
	// dropped, matching RemoveResource.
	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	log.Infof("Successfully removed filter chain (match: %v) from listener '%s'", serverNames, listenerName)
	return nil
}
// -----------------------------------------------------------------------------
// GENERIC RESOURCE CRUD (MODIFIED FOR SDS)
// -----------------------------------------------------------------------------
// AddResourceToSnapshot adds resource, of xDS type typ, to the current snapshot
// and sets the rebuilt snapshot on the cache.
//
// The resource must expose GetName(). If a resource with the same name and
// type is already present, it is replaced in place when upsert is true and an
// error is returned otherwise. Supported types are Cluster, Listener, Secret,
// Endpoint and Runtime; any other type is rejected.
//
// This function only updates the cache; it does not flush to the DB.
//
// NOTE(review): the resource set is rebuilt through getAllResourcesFromSnapshot,
// which collects only Cluster, Listener and Secret resources; an Endpoint or
// Runtime resource added here therefore starts from an empty list for its type.
func (sm *SnapshotManager) AddResourceToSnapshot(resource types.Resource, typ resourcev3.Type, upsert bool) error {
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot from cache: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)

	resourceNamer, ok := resource.(interface{ GetName() string })
	if !ok {
		return fmt.Errorf("resource of type %s does not implement GetName()", typ)
	}
	name := resourceNamer.GetName()

	// Look for an existing resource with the same name. On upsert it is
	// overwritten in place so the list never holds two entries for one name.
	replaced := false
	for i, r := range resources[typ] {
		namer, ok := r.(interface{ GetName() string })
		if !ok || namer.GetName() != name {
			continue
		}
		if !upsert {
			return fmt.Errorf("resource '%s' of type %s already exists in cache", name, typ)
		}
		log.Infof("Resource '%s' of type %s already exists. Performing update (upsert).", name, typ)
		resources[typ][i] = resource
		replaced = true
		break
	}

	switch typ {
	case resourcev3.ClusterType, resourcev3.ListenerType, resourcev3.SecretType,
		resourcev3.EndpointType, resourcev3.RuntimeType:
		// append on a missing (nil) map entry allocates the slice as needed.
		if !replaced {
			resources[typ] = append(resources[typ], resource)
		}
	default:
		return fmt.Errorf("unsupported resource type for dynamic addition: %s", typ)
	}

	// A nanosecond timestamp keeps the version unique even when the same
	// resource is updated more than once within a second.
	version := fmt.Sprintf("snap-generic-%s-%d", name, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}
	return sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap)
}
// RemoveResource removes the resource called name, of xDS type typ, according
// to strategy.
//
// With storage.DeleteActual the resource is deleted directly from the DB and
// the cache is left untouched; only Cluster, Listener and Secret types are
// supported. The caller is expected to have updated the cache beforehand.
//
// With any other strategy the resource is removed from the current snapshot, a
// new snapshot is set on the cache, and the cache is flushed to the DB using
// strategy. An error is returned if the snapshot cannot be read, if the
// resource is not present in the cache, or if any later step fails.
func (sm *SnapshotManager) RemoveResource(name string, typ resourcev3.Type, strategy storage.DeleteStrategy) error {
	// DeleteActual never modifies the cache, so it does not need the snapshot.
	if strategy == storage.DeleteActual {
		switch typ {
		case resourcev3.ClusterType:
			if err := sm.DB.RemoveCluster(context.TODO(), name); err != nil {
				return fmt.Errorf("failed to delete cluster %s from DB: %w", name, err)
			}
			return nil
		case resourcev3.ListenerType:
			if err := sm.DB.RemoveListener(context.TODO(), name); err != nil {
				return fmt.Errorf("failed to delete listener %s from DB: %w", name, err)
			}
			return nil
		case resourcev3.SecretType:
			if err := sm.DB.RemoveSecret(context.TODO(), name); err != nil {
				return fmt.Errorf("failed to delete secret %s from DB: %w", name, err)
			}
			return nil
		default:
			return fmt.Errorf("actual delete not supported for resource type: %s", typ)
		}
	}

	// For the remaining strategies the snapshot is required; surface a lookup
	// failure instead of continuing with an unusable snapshot.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot from cache: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)

	// Drop the named resource from the list for its type, if that type is tracked.
	resourceFound := false
	if targetResources, ok := resources[typ]; ok {
		resources[typ], resourceFound = filterAndCheckResourcesByName(targetResources, name)
	}
	if !resourceFound {
		return fmt.Errorf("resource %s of type %s not found in cache", name, typ)
	}

	// Rebuild and set the snapshot without the resource. The nanosecond
	// timestamp keeps the version unique when the same name is removed again
	// after having been re-added.
	version := fmt.Sprintf("snap-remove-generic-%s-%d", name, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create snapshot after removal: %w", err)
	}
	if err := sm.Cache.SetSnapshot(context.TODO(), sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set snapshot: %w", err)
	}

	// Flush the updated snapshot to the DB using the requested strategy.
	if err := sm.FlushCacheToDB(context.TODO(), strategy); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	return nil
}
// GetResourceFromCache looks up the resource called name among the resources
// of type typ in the current snapshot for sm.NodeID.
//
// It returns the snapshot lookup error unchanged when the snapshot cannot be
// read, and a "not found" error when no resource with that name exists for the
// type. The returned value is the one held by the snapshot, not a copy.
func (sm *SnapshotManager) GetResourceFromCache(name string, typ resourcev3.Type) (types.Resource, error) {
	snapshot, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return nil, err
	}
	byName := snapshot.GetResources(string(typ))
	if found, present := byName[name]; present {
		// No type assertion is done here; callers assert the concrete type.
		return found, nil
	}
	return nil, fmt.Errorf("%s resource %s not found in cache", typ, name)
}
// getAllResourcesFromSnapshot collects the Cluster, Listener and Secret
// resources of snap into a map keyed by resource type, each value being a
// slice built with mapToSlice.
//
// Endpoint and Runtime types are currently not collected.
// NOTE(review): callers build new snapshots from this map alone, so resource
// types not listed here are presumably absent from the rebuilt snapshot —
// confirm this is intended.
func (sm *SnapshotManager) getAllResourcesFromSnapshot(snap cachev3.ResourceSnapshot) map[resourcev3.Type][]types.Resource {
	// Only the types that the generic functions manipulate.
	collected := []resourcev3.Type{
		resourcev3.ClusterType,
		resourcev3.ListenerType,
		resourcev3.SecretType,
		// resourcev3.EndpointType,
		// resourcev3.RuntimeType,
	}
	byType := make(map[resourcev3.Type][]types.Resource, len(collected))
	for _, typ := range collected {
		byType[typ] = mapToSlice(snap.GetResources(string(typ)))
	}
	return byType
}
// UpdateSDSSecretByName refreshes the certificate chain and private key of the
// existing Secret called secretName, publishes the result as a new snapshot,
// and flushes the cache to the DB with the DeleteLogical strategy. The secret
// is addressed by its exact name, which is useful when the name is not derived
// from the certificate's domain.
//
// Both CertPEM and KeyPEM of internalcert are stored as inline strings.
//
// An error is returned when internalcert is nil, when the secret is missing
// from the cache or the current snapshot, when it is not a TlsCertificate
// secret, when the snapshot cannot be built or set, or when the flush to the
// DB fails.
//
// NOTE(review): the secret returned by GetResourceFromCache is modified in
// place before the new snapshot is set, so a failure in a later step leaves the
// modified object referenced by the existing snapshot.
func (sm *SnapshotManager) UpdateSDSSecretByName(ctx context.Context, secretName string, internalcert *internalcertapi.Certificate) error {
	log := internallog.LogFromContext(ctx)
	secretType := resourcev3.SecretType

	// Guard against a nil certificate: its fields are read directly below.
	if internalcert == nil {
		return fmt.Errorf("certificate data is nil")
	}

	// 1. Get the current Secret from the cache using the provided name.
	resource, err := sm.GetResourceFromCache(secretName, secretType)
	if err != nil {
		return fmt.Errorf("failed to get Secret '%s' from cache: %w", secretName, err)
	}
	secret, ok := resource.(*secretv3.Secret)
	if !ok {
		return fmt.Errorf("resource '%s' is not a Secret type", secretName)
	}

	// 2. Only TlsCertificate secrets can be refreshed.
	if secret.GetType() == nil || secret.GetTlsCertificate() == nil {
		return fmt.Errorf("secret '%s' is not a TlsCertificate secret or is malformed", secretName)
	}
	tlsCert := secret.GetTlsCertificate()
	tlsCert.CertificateChain = &corev3.DataSource{
		Specifier: &corev3.DataSource_InlineString{
			InlineString: string(internalcert.CertPEM),
		},
	}
	tlsCert.PrivateKey = &corev3.DataSource{
		Specifier: &corev3.DataSource_InlineString{
			InlineString: string(internalcert.KeyPEM),
		},
	}
	log.Debugf("Updated TlsCertificate data for secret '%s' (Domain: %s)", secretName, internalcert.Domain)

	// 3. Rebuild the resource set from the current snapshot and swap in the
	// modified secret.
	snap, err := sm.Cache.GetSnapshot(sm.NodeID)
	if err != nil {
		return fmt.Errorf("failed to get snapshot for modification: %w", err)
	}
	resources := sm.getAllResourcesFromSnapshot(snap)
	secretList, ok := resources[secretType]
	if !ok {
		return fmt.Errorf("secret resource type not present in snapshot")
	}
	foundAndReplaced := false
	for i, res := range secretList {
		if namer, ok := res.(interface{ GetName() string }); ok && namer.GetName() == secretName {
			secretList[i] = secret
			foundAndReplaced = true
			break
		}
	}
	if !foundAndReplaced {
		// Not expected when GetResourceFromCache succeeded; kept as a safeguard.
		return fmt.Errorf("failed to locate Secret '%s' in current resource list for replacement", secretName)
	}

	// 4. Create and set the new snapshot.
	version := fmt.Sprintf("secret-update-byname-%s-%d", secretName, time.Now().UnixNano())
	newSnap, err := cachev3.NewSnapshot(version, resources)
	if err != nil {
		return fmt.Errorf("failed to create new snapshot: %w", err)
	}
	if err := sm.Cache.SetSnapshot(ctx, sm.NodeID, newSnap); err != nil {
		return fmt.Errorf("failed to set new snapshot: %w", err)
	}

	// 5. Persist the updated cache. The error is surfaced instead of being
	// dropped, matching RemoveResource.
	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	log.Infof("Successfully updated Secret '%s' in cache with refreshed certificate.", secretName)
	return nil
}
// AddSDSSecretWithCert builds a TlsCertificate Secret from internalcert, adds
// it to the snapshot (replacing any secret with the same name), and flushes the
// cache to the DB with the DeleteLogical strategy.
//
// The secret name is the certificate's Domain with every "." replaced by "_",
// e.g. "abc.com" -> "abc_com". CertPEM and KeyPEM are stored as inline strings.
//
// An error is returned when internalcert is nil, when its Domain is empty
// (which would yield a secret without a name), when the snapshot update fails,
// or when the flush to the DB fails.
func (sm *SnapshotManager) AddSDSSecretWithCert(ctx context.Context, internalcert *internalcertapi.Certificate) error {
	log := internallog.LogFromContext(ctx)
	if internalcert == nil {
		return fmt.Errorf("certificate data is nil")
	}
	if internalcert.Domain == "" {
		return fmt.Errorf("certificate domain is empty")
	}

	// 1. Derive the secret name from the certificate's domain.
	secretName := strings.ReplaceAll(internalcert.Domain, ".", "_")

	// 2. Create the Secret resource carrying the certificate and key inline.
	newSecret := &secretv3.Secret{
		Name: secretName,
		Type: &secretv3.Secret_TlsCertificate{
			TlsCertificate: &secretv3.TlsCertificate{
				CertificateChain: &corev3.DataSource{
					Specifier: &corev3.DataSource_InlineString{
						InlineString: string(internalcert.CertPEM),
					},
				},
				PrivateKey: &corev3.DataSource{
					Specifier: &corev3.DataSource_InlineString{
						InlineString: string(internalcert.KeyPEM),
					},
				},
			},
		},
	}

	// 3. Add the resource to the snapshot; upsert is enabled so an existing
	// secret with the same name is overwritten.
	if err := sm.AddResourceToSnapshot(newSecret, resourcev3.SecretType, true /*upsert flag*/); err != nil {
		return fmt.Errorf("failed to add SDS Secret '%s' to snapshot: %w", secretName, err)
	}

	// 4. Persist the updated cache. The error is surfaced instead of being
	// dropped, matching RemoveResource.
	if err := sm.FlushCacheToDB(ctx, storage.DeleteLogical); err != nil {
		return fmt.Errorf("failed to flush cache to DB: %w", err)
	}
	log.Infof("Successfully added new SDS Secret '%s' for domain '%s' to cache.", secretName, internalcert.Domain)
	return nil
}
// mapToSlice returns the values of m as a slice. The keys are discarded and,
// because Go map iteration order is unspecified, the order of the returned
// elements is not defined. The result is never nil, even for an empty map.
func mapToSlice(m map[string]types.Resource) []types.Resource {
	values := make([]types.Resource, len(m))
	next := 0
	for key := range m {
		values[next] = m[key]
		next++
	}
	return values
}
// filterAndCheckResourcesByName returns the elements of resources whose name
// differs from name, together with a flag reporting whether at least one
// element carried that name. Elements that do not expose GetName() cannot be
// compared and are always kept. The returned slice is never nil.
func filterAndCheckResourcesByName(resources []types.Resource, name string) ([]types.Resource, bool) {
	type named interface{ GetName() string }

	kept := make([]types.Resource, 0)
	matched := false
	for _, candidate := range resources {
		if n, isNamed := candidate.(named); isNamed && n.GetName() == name {
			// Drop the element and remember that the name was present.
			matched = true
			continue
		}
		kept = append(kept, candidate)
	}
	return kept, matched
}