// EnvoyControlPlane/internal/rest_api.go

package internal

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
	resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
	"github.com/google/uuid"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/reflect/protoreflect"
)

// API exposes a SnapshotManager over HTTP; its handler methods read from and
// write to Manager.
type API struct {
	// Manager is the snapshot store that the REST handlers add resources to,
	// remove resources from, and query.
	Manager *SnapshotManager
}

// JSON request payloads for the REST management API. The json struct tags
// define the field names expected on the wire.
type (
	// AddClusterRequest is the payload for adding a cluster. An empty Name
	// lets the server pick one.
	AddClusterRequest struct {
		Name string `json:"name"`
	}

	// RemoveClusterRequest is the payload for removing a cluster by name.
	RemoveClusterRequest struct {
		Name string `json:"name"`
	}

	// AddRouteRequest is the payload for adding a route: its name, the
	// cluster it targets, and the path prefix it matches.
	AddRouteRequest struct {
		Name       string `json:"name"`
		Cluster    string `json:"cluster"`
		PathPrefix string `json:"path_prefix"`
	}

	// RemoveRouteRequest is the payload for removing a route by name.
	RemoveRouteRequest struct {
		Name string `json:"name"`
	}

	// SnapshotFileRequest is the payload naming a file path to load a
	// snapshot from or save a snapshot to.
	SnapshotFileRequest struct {
		Path string `json:"path"`
	}

	// AddListenerRequest is the payload for adding a listener with a name
	// and port.
	AddListenerRequest struct {
		Name string `json:"name"`
		Port uint32 `json:"port"`
	}

	// RemoveListenerRequest is the payload for removing a listener by name.
	RemoveListenerRequest struct {
		Name string `json:"name"`
	}
)

// NewAPI wraps sm in an API whose handlers can be mounted on a mux with
// RegisterRoutes.
func NewAPI(sm *SnapshotManager) *API {
	api := new(API)
	api.Manager = sm
	return api
}

// RegisterRoutes mounts the REST endpoints for clusters, routes, and
// listeners on mux.
//
// Every resource kind gets the same four endpoints:
//
//	/add-<kind>      POST, kind-specific JSON body
//	/remove-<kind>   POST, {"name": ...}
//	/list-<kinds>    GET
//	/get-<kind>      GET, ?name=...
//
// The remove/list/get endpoints are served by the generic handlers, bound to
// the kind's resourcev3 type URL.
func (api *API) RegisterRoutes(mux *http.ServeMux) {
	// bind fixes the resource type of a generic handler. typ is a function
	// parameter rather than a captured loop variable, so each returned
	// closure keeps its own value regardless of Go version.
	bind := func(handle func(http.ResponseWriter, *http.Request, resourcev3.Type), typ resourcev3.Type) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			handle(w, r, typ)
		}
	}

	kinds := []struct {
		singular, plural string
		typ              resourcev3.Type
		add              http.HandlerFunc
	}{
		{singular: "cluster", plural: "clusters", typ: resourcev3.ClusterType, add: api.addCluster},
		{singular: "route", plural: "routes", typ: resourcev3.RouteType, add: api.addRoute},
		{singular: "listener", plural: "listeners", typ: resourcev3.ListenerType, add: api.addListener},
	}

	for _, k := range kinds {
		// Management handlers (add / remove).
		mux.HandleFunc("/add-"+k.singular, k.add)
		mux.HandleFunc("/remove-"+k.singular, bind(api.removeResourceHandler, k.typ))
		// Query handlers (list / get).
		mux.HandleFunc("/list-"+k.plural, bind(api.listResourceHandler, k.typ))
		mux.HandleFunc("/get-"+k.singular, bind(api.getResourceHandler, k.typ))
	}
}

// ---------------- Cluster / Route / Listener Handlers (built on the generic handlers) ----------------

// addCluster handles the add-cluster endpoint. The cluster is built by
// NewCluster from the request's name; when the name is omitted a random UUID
// is used instead.
func (api *API) addCluster(w http.ResponseWriter, r *http.Request) {
	build := func(payload interface{}) types.Resource {
		req := payload.(*AddClusterRequest)
		if req.Name != "" {
			return NewCluster(req.Name)
		}
		return NewCluster(uuid.NewString())
	}
	api.addResourceHandler(w, r, resourcev3.ClusterType, build)
}

// addRoute handles the add-route endpoint. The route is built by NewRoute
// from the request's name, cluster, and path prefix. An empty name is
// replaced with a random UUID, the same defaulting addCluster applies.
func (api *API) addRoute(w http.ResponseWriter, r *http.Request) {
	api.addResourceHandler(w, r, resourcev3.RouteType, func(req interface{}) types.Resource {
		rr := req.(*AddRouteRequest)
		name := rr.Name
		if name == "" {
			// An unnamed route could never be fetched or removed again:
			// the get and remove handlers both reject an empty name.
			name = uuid.NewString()
		}
		return NewRoute(name, rr.Cluster, rr.PathPrefix)
	})
}

// addListener handles the add-listener endpoint. The listener is built by
// NewListener from the request's name and port. An empty name is replaced
// with a random UUID, the same defaulting addCluster applies.
func (api *API) addListener(w http.ResponseWriter, r *http.Request) {
	api.addResourceHandler(w, r, resourcev3.ListenerType, func(req interface{}) types.Resource {
		lr := req.(*AddListenerRequest)
		name := lr.Name
		if name == "" {
			// An unnamed listener could never be fetched or removed again:
			// the get and remove handlers both reject an empty name.
			name = uuid.NewString()
		}
		return NewListener(name, lr.Port)
	})
}

// ---------------- Generic REST Handlers ----------------

// addResourceHandler implements the shared POST flow of the add-* endpoints.
//
// It decodes the JSON body into the request struct matching typ
// (*AddClusterRequest, *AddRouteRequest or *AddListenerRequest) and passes
// that pointer to createFn, which builds the resource. It then stores the
// resource via Manager.AddResource. On success it replies 201 Created with
// {"name": <resource name>}, so callers learn any server-generated name.
//
// Error responses:
//   - 405 for a non-POST request.
//   - 400 for an unsupported typ or an undecodable body.
//   - 500 if createFn returns nil or Manager.AddResource fails.
func (api *API) addResourceHandler(w http.ResponseWriter, r *http.Request, typ resourcev3.Type, createFn func(interface{}) types.Resource) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")

	// Pick the payload struct for typ; createFn type-asserts to this pointer.
	var req interface{}
	switch typ {
	case resourcev3.ClusterType:
		req = &AddClusterRequest{}
	case resourcev3.RouteType:
		req = &AddRouteRequest{}
	case resourcev3.ListenerType:
		req = &AddListenerRequest{}
	default:
		http.Error(w, "unsupported type", http.StatusBadRequest)
		return
	}

	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		http.Error(w, fmt.Sprintf("invalid request: %v", err), http.StatusBadRequest)
		return
	}

	res := createFn(req)
	if res == nil {
		http.Error(w, "failed to build resource", http.StatusInternalServerError)
		return
	}
	if err := api.Manager.AddResource(res, typ); err != nil {
		http.Error(w, fmt.Sprintf("failed to add resource: %v", err), http.StatusInternalServerError)
		return
	}

	// Checked assertion: the resource is already stored at this point, so a
	// resource without a GetName accessor must not panic the handler; it is
	// reported with an empty name instead.
	var name string
	if named, ok := res.(interface{ GetName() string }); ok {
		name = named.GetName()
	}

	w.WriteHeader(http.StatusCreated)
	// The status line has already been sent, so an encode failure (e.g. the
	// client disconnected) cannot be reported back; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(map[string]string{"name": name})
}

// removeResourceHandler implements the shared POST flow of the remove-*
// endpoints. It reads {"name": "..."} from the body and deletes the resource
// of type typ with that name via Manager.RemoveResource.
//
// Responses:
//   - 405 for a non-POST request.
//   - 400 when the body is not valid JSON or the name is empty.
//   - 500 if the manager reports an error.
//   - Otherwise 200 OK with {"name": <removed name>}.
func (api *API) removeResourceHandler(w http.ResponseWriter, r *http.Request, typ resourcev3.Type) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")

	var req struct {
		Name string `json:"name"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Name == "" {
		http.Error(w, "name required", http.StatusBadRequest)
		return
	}

	if err := api.Manager.RemoveResource(req.Name, typ); err != nil {
		http.Error(w, fmt.Sprintf("failed to remove resource: %v", err), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	// Previously the body was empty despite the JSON Content-Type, which
	// breaks clients that parse the response. Echo the name, mirroring the
	// add-* responses. Once the status is sent, an encode failure cannot be
	// reported, so it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(map[string]string{"name": req.Name})
}

// listResourceHandler serves the GET list-* endpoints for resource type typ.
// It replies 200 OK with a JSON array containing every protobuf resource
// returned by Manager.ListResources, each rendered with protojson. An empty
// result is encoded as [] rather than null.
//
// Resources that are not protobuf messages are skipped. Errors from the
// manager, and protojson marshalling failures, produce a 500 response.
// Previously a marshalling failure silently became a null array entry.
func (api *API) listResourceHandler(w http.ResponseWriter, r *http.Request, typ resourcev3.Type) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")

	resources, err := api.Manager.ListResources(typ)
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to list resources: %v", err), http.StatusInternalServerError)
		return
	}

	// Non-nil so that an empty listing encodes as [] instead of null.
	out := make([]json.RawMessage, 0, len(resources))
	for _, res := range resources {
		pb, ok := res.(interface{ ProtoReflect() protoreflect.Message })
		if !ok {
			continue
		}
		data, mErr := protojson.Marshal(pb)
		if mErr != nil {
			http.Error(w, fmt.Sprintf("failed to marshal protobuf: %v", mErr), http.StatusInternalServerError)
			return
		}
		out = append(out, data)
	}

	w.WriteHeader(http.StatusOK)
	// The status line has already been sent, so an encode failure cannot be
	// reported back; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(out)
}

// getResourceHandler serves the GET get-*?name=<name> endpoints for resource
// type typ. It replies with the single matching resource rendered by
// protojson, so all nested fields are included.
//
// Responses:
//   - 405 for a non-GET request.
//   - 400 when the name query parameter is missing.
//   - 404 for any error from Manager.GetResource.
//   - 500 if the resource is not a protobuf message or cannot be marshalled.
//   - Otherwise 200 OK.
func (api *API) getResourceHandler(w http.ResponseWriter, r *http.Request, typ resourcev3.Type) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")

	query := r.URL.Query()
	name := query.Get("name")
	if name == "" {
		http.Error(w, "name query parameter required", http.StatusBadRequest)
		return
	}

	res, lookupErr := api.Manager.GetResource(name, typ)
	if lookupErr != nil {
		http.Error(w, fmt.Sprintf("resource not found: %v", lookupErr), http.StatusNotFound)
		return
	}

	msg, isProto := res.(interface{ ProtoReflect() protoreflect.Message })
	if !isProto {
		// Only protobuf resources can be rendered by protojson.
		http.Error(w, "resource is not a protobuf message", http.StatusInternalServerError)
		return
	}

	body, marshalErr := protojson.Marshal(msg)
	if marshalErr != nil {
		http.Error(w, fmt.Sprintf("failed to marshal protobuf: %v", marshalErr), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(body)
}