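// Command EnvoyControlPlane (main.go) runs an Envoy xDS management server
// together with a REST API, persisting configuration in a SQL database.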
package main

import (
	"context"
	"database/sql"

	"flag"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
	cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
	resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
	"github.com/envoyproxy/go-control-plane/pkg/server/v3"
	"github.com/envoyproxy/go-control-plane/pkg/test/v3"
	_ "github.com/lib/pq"           // Postgres driver
	_ "github.com/mattn/go-sqlite3" // SQLite driver
	"k8s.io/klog/v2"

	"envoy-control-plane/internal"
	internallog "envoy-control-plane/internal/log"
	"envoy-control-plane/internal/snapshot"
	internalstorage "envoy-control-plane/internal/storage"
)

// logger is the application logger created in init. main stores it in the
// root context (see internallog.WithLogger) and also hands it to the
// go-control-plane snapshot cache.
var logger internallog.Logger

// Command-line configuration. Every field except dbDriver is bound to a flag
// in init; dbDriver is derived by main from the connection string.
var (
	port         uint   // -port: xDS management (gRPC) server port
	restPort     uint   // -rest-port: REST API server port
	nodeID       string // -nodeID: node ID the snapshot is stored under in the cache
	snapshotFile string // -snapshot-file: optional single initial snapshot file
	configDir    string // -config-dir: directory scanned for .yaml/.yml/.json config files
	dbConnStr    string // -db: connection string; empty selects the default SQLite file
	dbDriver     string // database/sql driver name chosen by main (not a flag)
)

// init creates the default application logger and declares every
// command-line flag: klog's own flags plus this binary's settings. The bound
// variables are only populated once main calls flag.Parse.
func init() {
	logger = internallog.NewDefaultLogger()

	// A nil flag set makes klog register its flags on flag.CommandLine.
	klog.InitFlags(nil)

	fs := flag.CommandLine

	// Listener ports.
	fs.UintVar(&port, "port", 18000, "xDS management server port")
	fs.UintVar(&restPort, "rest-port", 8080, "REST API server port")

	// Node identity and configuration sources.
	fs.StringVar(&nodeID, "nodeID", "test-id", "Node ID")
	fs.StringVar(&snapshotFile, "snapshot-file", "", "Optional initial snapshot JSON/YAML file")
	fs.StringVar(&configDir, "config-dir", "data/", "Optional directory containing multiple config files")
	fs.StringVar(&dbConnStr, "db", "", "Optional database connection string for config persistence")
}

// determineDriver picks the database/sql driver name for a connection string.
// A DSN beginning with the "postgres://" or "postgresql://" URL scheme
// (matched case-sensitively) selects "postgres"; anything else, including an
// empty string, is treated as a SQLite DSN and selects "sqlite3".
func determineDriver(dsn string) string {
	for _, scheme := range []string{"postgres://", "postgresql://"} {
		if strings.HasPrefix(dsn, scheme) {
			return "postgres"
		}
	}
	return "sqlite3"
}

// loadConfigFiles builds one snapshot from every config file directly inside
// dir and installs it on manager under the version "snap-from-file".
//
// Non-directory entries whose names end in ".yaml", ".yml" or ".json" are
// loaded. Subdirectories are skipped, not descended into. os.ReadDir returns
// entries sorted by file name, so files are merged in a deterministic order:
// resources sharing a key are appended in file-name order. The logger is
// taken from ctx.
//
// It returns a wrapped error if dir cannot be read, if any matching file fails
// to load, or if manager rejects the combined snapshot. A directory with no
// matching files still installs a snapshot built from an empty resource map.
func loadConfigFiles(ctx context.Context, manager *snapshot.SnapshotManager, dir string) error {
	log := internallog.LogFromContext(ctx)

	log.Infof("loading configuration files from directory: %s", dir)

	files, err := os.ReadDir(dir)
	if err != nil {
		return fmt.Errorf("failed to read directory %s: %w", dir, err)
	}

	resourceFiles := make(map[string][]types.Resource)
	for _, file := range files {
		if file.IsDir() {
			continue
		}
		fileName := file.Name()
		if !strings.HasSuffix(fileName, ".yaml") && !strings.HasSuffix(fileName, ".yml") && !strings.HasSuffix(fileName, ".json") {
			continue
		}

		filePath := filepath.Join(dir, fileName)
		log.Infof("  -> loading config file: %s", filePath)

		rf, err := snapshot.LoadSnapshotFromFile(ctx, filePath)
		if err != nil {
			return fmt.Errorf("failed to load snapshot from file %s: %w", filePath, err)
		}

		// rf maps each key (presumably a resource type) to a slice of resources,
		// so len(rf) would count keys rather than resources. Tally while merging.
		loaded := 0
		for k, v := range rf {
			resourceFiles[k] = append(resourceFiles[k], v...)
			loaded += len(v)
		}
		log.Infof("loaded %d resources from %s", loaded, filePath)
	}

	if err := manager.SetSnapshot(ctx, "snap-from-file", resourceFiles); err != nil {
		return fmt.Errorf("failed to set combined snapshot from files: %w", err)
	}
	return nil
}

// CORS wraps next with permissive cross-origin headers. Every response carries
// Access-Control-Allow-Origin "*", a fixed allowed-method list and a fixed
// allowed-request-header list. Any OPTIONS request (such as a CORS preflight)
// is answered immediately with 200 OK and never reaches next.
func CORS(next http.Handler) http.Handler {
	allowCrossOrigin := func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Access-Control-Allow-Headers, Authorization, X-Requested-With")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: the headers above are the whole answer.
		w.WriteHeader(http.StatusOK)
	}
	return http.HandlerFunc(allowCrossOrigin)
}

// main parses flags, runs the control plane, and turns a failure into exit
// status 1. The real work lives in runControlPlane. That function returns
// instead of calling os.Exit, so its deferred cleanup (klog flush, DB close)
// still runs. os.Exit skips deferred functions.
func main() {
	flag.Parse()
	if code := runControlPlane(); code != 0 {
		os.Exit(code)
	}
}

// runControlPlane sets up storage, picks the initial xDS snapshot, starts the
// xDS gRPC server in the background, and then blocks serving the REST API.
// It returns the process exit code. Since http.ListenAndServe only returns on
// error, any return from this function indicates a failure (1).
//
// The initial snapshot is chosen in this order:
//  1. the configuration persisted in the database, if it has any enabled
//     clusters or listeners;
//  2. otherwise -snapshot-file (if set and present), else every config file
//     in -config-dir, and the result is then written back into the database;
//  3. otherwise an empty snapshot, so the cache always holds one for nodeID.
func runControlPlane() int {
	defer klog.Flush()

	// Root context carrying the logger; helpers below retrieve it from ctx.
	ctx := internallog.WithLogger(context.Background(), logger)
	log := internallog.LogFromContext(ctx)

	// Without -db, default to a SQLite file under data/, creating the directory.
	if dbConnStr == "" {
		defaultDBPath := "data/config.db"
		if err := os.MkdirAll(filepath.Dir(defaultDBPath), 0755); err != nil {
			fmt.Fprintf(os.Stderr, "failed to create data directory: %v\n", err)
			return 1
		}
		dbConnStr = fmt.Sprintf("file:%s?_foreign_keys=on", defaultDBPath)
		dbDriver = "sqlite3"
	} else {
		dbDriver = determineDriver(dbConnStr)
	}

	// --- Database initialization ---
	// sql.Open only validates its arguments and does not dial. Ping so that an
	// unreachable database is reported as such, not later as a schema failure.
	db, err := sql.Open(dbDriver, dbConnStr)
	if err != nil {
		log.Errorf("failed to connect to DB: %v", err)
		return 1
	}
	defer db.Close()
	if err := db.PingContext(ctx); err != nil {
		log.Errorf("failed to connect to DB: %v", err)
		return 1
	}

	storage := internalstorage.NewStorage(db, dbDriver)
	if err := storage.InitSchema(ctx); err != nil {
		log.Errorf("failed to initialize DB schema: %v", err)
		return 1
	}

	// The snapshot cache takes go-control-plane's logger interface; the
	// package-level logger is passed to it as the bridge.
	cache := cachev3.NewSnapshotCache(false, cachev3.IDHash{}, logger)
	manager := snapshot.NewSnapshotManager(cache, nodeID, storage)

	loadedConfigs := false

	// Step 1: prefer the configuration persisted in the database.
	snapCfg, err := storage.RebuildSnapshot(ctx)
	switch {
	case err != nil:
		// Not fatal. We fall back to file-based configuration, but surface the
		// cause, because the files loaded below are then written into the DB.
		log.Warnf("failed to rebuild snapshot from database, falling back to config files: %v", err)
	case len(snapCfg.EnabledClusters)+len(snapCfg.EnabledListeners) > 0:
		if err := manager.SetSnapshotFromConfig(ctx, "snap-from-db", snapCfg); err != nil {
			log.Errorf("failed to set DB snapshot: %v", err)
			return 1
		}
		loadedConfigs = true
		log.Infof("loaded snapshot from database")
	}

	// Step 2: if the database had nothing, load from files and persist them.
	if !loadedConfigs {
		// An explicit -snapshot-file is tried first. -config-dir defaults to
		// "data/", so checking the directory first would make -snapshot-file
		// unreachable unless -config-dir="" were also passed.
		if snapshotFile != "" {
			if _, err := os.Stat(snapshotFile); err == nil {
				resources, err := snapshot.LoadSnapshotFromFile(ctx, snapshotFile)
				if err != nil {
					log.Errorf("failed to load snapshot from file: %v", err)
					return 1
				}
				if err := manager.SetSnapshot(ctx, "snap-from-file", resources); err != nil {
					log.Errorf("failed to set loaded snapshot: %v", err)
					return 1
				}
				loadedConfigs = true
			} else {
				log.Warnf("snapshot file not found: %s", snapshotFile)
			}
		}
		if !loadedConfigs && configDir != "" {
			if err := loadConfigFiles(ctx, manager, configDir); err != nil {
				log.Errorf("failed to load configs from directory: %v", err)
				return 1
			}
			loadedConfigs = true
		}

		if loadedConfigs {
			snapCfg, err := manager.SnapshotToConfig(ctx, nodeID)
			if err != nil {
				log.Errorf("failed to convert snapshot to DB config: %v", err)
				return 1
			}
			if err := storage.SaveSnapshot(ctx, snapCfg, internalstorage.DeleteLogical); err != nil {
				log.Errorf("failed to save initial snapshot into DB: %v", err)
				return 1
			}
			log.Infof("initial snapshot written into database")
		}
	}

	// Step 3: guarantee the cache holds a snapshot for nodeID, even if empty.
	snap, err := manager.Cache.GetSnapshot(nodeID)
	if err != nil || !loadedConfigs {
		log.Warnf("no valid snapshot found, creating empty snapshot")
		snap, err = cachev3.NewSnapshot("snap-init", map[resourcev3.Type][]types.Resource{
			resourcev3.ClusterType:  {},
			resourcev3.RouteType:    {},
			resourcev3.ListenerType: {},
		})
		if err != nil {
			log.Errorf("failed to create empty snapshot: %v", err)
			return 1
		}
		if err := cache.SetSnapshot(ctx, nodeID, snap); err != nil {
			log.Errorf("failed to set initial snapshot: %v", err)
			return 1
		}
	}

	log.Infof("xDS snapshot ready: version %s", snap.GetVersion(string(resourcev3.ClusterType)))

	// --- Start xDS gRPC server (runs in the background) ---
	cb := &test.Callbacks{Debug: true}
	srv := server.NewServer(ctx, cache, cb)
	go internal.RunServer(srv, port)

	// --- Start REST API server ---
	api := internal.NewAPI(manager)
	mux := http.NewServeMux()
	api.RegisterRoutes(mux)
	// Requests not matched by a more specific API route are served from ./static.
	mux.Handle("/", http.FileServer(http.Dir("./static")))

	restAddr := fmt.Sprintf(":%d", restPort)
	log.Infof("starting REST API server on %s", restAddr)
	// Blocks until the server fails; every response goes through CORS.
	if err := http.ListenAndServe(restAddr, CORS(mux)); err != nil {
		log.Errorf("REST server error: %v", err)
		return 1
	}
	return 0
}