Newer
Older
EnvoyControlPlane / main.go
package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	_ "github.com/lib/pq"           // Postgres driver
	_ "github.com/mattn/go-sqlite3" // SQLite driver

	"github.com/envoyproxy/go-control-plane/pkg/cache/types"
	cachev3 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
	resourcev3 "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
	"github.com/envoyproxy/go-control-plane/pkg/server/v3"
	"github.com/envoyproxy/go-control-plane/pkg/test/v3"
	"k8s.io/klog/v2"

	"envoy-control-plane/internal"
)

var (
    logger       *internal.DefaultLogger
    port         uint
    nodeID       string
    restPort     uint
    snapshotFile string
    configDir    string
    dbConnStr    string
    dbDriver     string
)

func init() {
    logger = internal.NewDefaultLogger()
    klog.InitFlags(nil)

    flag.UintVar(&port, "port", 18000, "xDS management server port")
    flag.StringVar(&nodeID, "nodeID", "test-id", "Node ID")
    flag.UintVar(&restPort, "rest-port", 8080, "REST API server port")
    flag.StringVar(&snapshotFile, "snapshot-file", "", "Optional initial snapshot JSON/YAML file")
    flag.StringVar(&configDir, "config-dir", "data/", "Optional directory containing multiple config files")
    flag.StringVar(&dbConnStr, "db", "", "Optional database connection string for config persistence")
}

// determineDriver returns driver name from connection string
func determineDriver(dsn string) string {
    if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
        return "postgres"
    }
    return "sqlite3"
}

// loadConfigFiles iterates over a directory and loads all .yaml/.json files
func loadConfigFiles(manager *internal.SnapshotManager, dir string) error {
    logger.Infof("loading configuration files from directory: %s", dir)

    files, err := os.ReadDir(dir)
    if err != nil {
        return fmt.Errorf("failed to read directory %s: %w", dir, err)
    }

    resourceFiles := make(map[string][]types.Resource)
    for _, file := range files {
        if file.IsDir() {
            continue
        }
        fileName := file.Name()
        if strings.HasSuffix(fileName, ".yaml") || strings.HasSuffix(fileName, ".yml") || strings.HasSuffix(fileName, ".json") {
            filePath := filepath.Join(dir, fileName)
            logger.Infof("  -> loading config file: %s", filePath)

            rf, err := manager.LoadSnapshotFromFile(filePath)
            if err != nil {
                return fmt.Errorf("failed to load snapshot from file %s: %w", filePath, err)
            }
            for k, v := range rf {
                resourceFiles[k] = append(resourceFiles[k], v...)
            }
            logger.Infof("loaded %d resources from %s", len(rf), filePath)
        }
    }

    if err := manager.SetSnapshot(context.TODO(), "snap-from-file", resourceFiles); err != nil {
        return fmt.Errorf("failed to set combined snapshot from files: %w", err)
    }
    return nil
}

// CORS is a middleware that sets the Access-Control-Allow-Origin header to * (all origins).
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Set CORS headers for all domains
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Access-Control-Allow-Headers, Authorization, X-Requested-With")

		// Handle preflight requests
		if r.Method == "OPTIONS" {
			w.WriteHeader(http.StatusOK)
			return
		}

		next.ServeHTTP(w, r)
	})
}


func main() {
    flag.Parse()
    defer klog.Flush()

    // Default DB to SQLite file if none provided
    if dbConnStr == "" {
        defaultDBPath := "data/config.db"
        if err := os.MkdirAll(filepath.Dir(defaultDBPath), 0755); err != nil {
            fmt.Fprintf(os.Stderr, "failed to create data directory: %v\n", err)
            os.Exit(1)
        }
        dbConnStr = fmt.Sprintf("file:%s?_foreign_keys=on", defaultDBPath)
        dbDriver = "sqlite3"
    } else {
        dbDriver = determineDriver(dbConnStr)
    }
    // --- Database initialization ---
    db, err := sql.Open(dbDriver, dbConnStr)
    if err != nil {
        logger.Errorf("failed to connect to DB: %v", err)
        os.Exit(1)
    }
    defer db.Close()

    storage := internal.NewStorage(db, dbDriver)
    if err := storage.InitSchema(context.Background()); err != nil {
        logger.Errorf("failed to initialize DB schema: %v", err)
        os.Exit(1)
    }

    // Create snapshot cache and manager
    cache := cachev3.NewSnapshotCache(false, cachev3.IDHash{}, logger)
    manager := internal.NewSnapshotManager(cache, nodeID, storage)

    loadedConfigs := false

    // Step 1: Try to load snapshot from DB
    snapCfg, err := storage.RebuildSnapshot(context.Background())
    if err == nil && len(snapCfg.EnabledClusters)+len(snapCfg.EnabledListeners) > 0 {
        if err := manager.SetSnapshotFromConfig(context.Background(), "snap-from-db", snapCfg); err != nil {
            logger.Errorf("failed to set DB snapshot: %v", err)
            os.Exit(1)
        }
        loadedConfigs = true
        logger.Infof("loaded snapshot from database")
    }

    // Step 2: If DB empty, load from files and persist into DB
    if !loadedConfigs {
        if configDir != "" {
            if err := loadConfigFiles(manager, configDir); err != nil {
                logger.Errorf("failed to load configs from directory: %v", err)
                os.Exit(1)
            }
            loadedConfigs = true
        } else if snapshotFile != "" {
            if _, err := os.Stat(snapshotFile); err == nil {
                resources, err := manager.LoadSnapshotFromFile(snapshotFile)
                if err != nil {
                    logger.Errorf("failed to load snapshot from file: %v", err)
                    os.Exit(1)
                }
                if err := manager.SetSnapshot(context.TODO(), "snap-from-file", resources); err != nil {
                    logger.Errorf("failed to set loaded snapshot: %v", err)
                    os.Exit(1)
                }
                loadedConfigs = true
            } else {
                logger.Warnf("snapshot file not found: %s", snapshotFile)
            }
        }

        // Persist loaded snapshot into DB
        if loadedConfigs {
            snapCfg, err := manager.SnapshotToConfig(context.Background(), nodeID)
            if err != nil {
                logger.Errorf("failed to convert snapshot to DB config: %v", err)
                os.Exit(1)
            }
            if err := storage.SaveSnapshot(context.Background(), snapCfg, internal.DeleteLogical); err != nil {
                logger.Errorf("failed to save initial snapshot into DB: %v", err)
                os.Exit(1)
            }
            logger.Infof("initial snapshot written into database")
        }
    }

    // Step 3: Ensure snapshot exists in cache
    snap, err := manager.Cache.GetSnapshot(nodeID)
    if err != nil || !loadedConfigs {
        logger.Warnf("no valid snapshot found, creating empty snapshot")
        snap, _ = cachev3.NewSnapshot("snap-init", map[resourcev3.Type][]types.Resource{
            resourcev3.ClusterType:  {},
            resourcev3.RouteType:    {},
            resourcev3.ListenerType: {},
        })
        if err := cache.SetSnapshot(context.Background(), nodeID, snap); err != nil {
            logger.Errorf("failed to set initial snapshot: %v", err)
            os.Exit(1)
        }
    }

    logger.Infof("xDS snapshot ready: version %s", snap.GetVersion(string(resourcev3.ClusterType)))

    // --- Start xDS gRPC server ---
    ctx := context.Background()
    cb := &test.Callbacks{Debug: true}
    srv := server.NewServer(ctx, cache, cb)
    go internal.RunServer(srv, port)

    // --- Start REST API server ---
    api := internal.NewAPI(manager)
    mux := http.NewServeMux()
    api.RegisterRoutes(mux)

    // Wrap the main multiplexer with the CORS handler
    corsHandler := CORS(mux) 

    // NEW: Serve the index.html file and any other static assets
    mux.Handle("/", http.FileServer(http.Dir("./static"))) // Assuming 'web' is the folder

    restAddr := fmt.Sprintf(":%d", restPort)
    logger.Infof("starting REST API server on %s", restAddr)
    if err := http.ListenAndServe(restAddr, corsHandler); err != nil { // Use corsHandler
        logger.Errorf("REST server error: %v", err)
        os.Exit(1)
    }
}