From 7eb0b74b70b7b17bc9d6dfaf11b560a9462bfe73 Mon Sep 17 00:00:00 2001 From: shenlan Date: Mon, 6 Oct 2025 11:30:19 +0800 Subject: [PATCH] Tighten CORS policies for production domains (#417) --- account/cmd/accountsvc/main.go | 1 + account/config/account.yaml | 2 + cmd/xcontrol-server/main.go | 1 + server/config/config.go | 8 +- server/config/config_test.go | 16 ++++ server/config/server.yaml | 9 ++ server/cors.go | 167 +++++++++++++++++++++++++++++++++ 7 files changed, 201 insertions(+), 3 deletions(-) create mode 100644 server/cors.go diff --git a/account/cmd/accountsvc/main.go b/account/cmd/accountsvc/main.go index 197e8e6..c83fb46 100644 --- a/account/cmd/accountsvc/main.go +++ b/account/cmd/accountsvc/main.go @@ -337,6 +337,7 @@ func buildCORSConfig(logger *slog.Logger, serverCfg config.Server) cors.Config { "Accept", "Origin", "X-Requested-With", + "Cookie", }, ExposeHeaders: []string{ "Content-Length", diff --git a/account/config/account.yaml b/account/config/account.yaml index 28c503e..d8a3ebc 100644 --- a/account/config/account.yaml +++ b/account/config/account.yaml @@ -7,6 +7,8 @@ server: writeTimeout: 15s publicUrl: "http://localhost:8080" allowedOrigins: + - "https://www.svc.plus" + - "https://global-homepage.svc.plus" - "https://account.svc.plus" - "https://localhost:8443" - "http://localhost:8080" diff --git a/cmd/xcontrol-server/main.go b/cmd/xcontrol-server/main.go index 5881f05..325774b 100644 --- a/cmd/xcontrol-server/main.go +++ b/cmd/xcontrol-server/main.go @@ -85,6 +85,7 @@ var rootCmd = &cobra.Command{ r := server.New( api.RegisterRoutes(conn, cfg.Sync.Repo.Proxy), ) + server.UseCORS(r, logger, cfg.Server) addr := cfg.Server.Addr if addr == "" { diff --git a/server/config/config.go b/server/config/config.go index c25a7a3..4232112 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -153,9 +153,11 @@ func (d Duration) String() string { // ServerCfg contains HTTP server runtime configuration. type ServerCfg struct { - Addr string `yaml:"addr"` - ReadTimeout Duration `yaml:"readTimeout"` - WriteTimeout Duration `yaml:"writeTimeout"` + Addr string `yaml:"addr"` + ReadTimeout Duration `yaml:"readTimeout"` + WriteTimeout Duration `yaml:"writeTimeout"` + PublicURL string `yaml:"publicUrl"` + AllowedOrigins []string `yaml:"allowedOrigins"` } type Config struct { diff --git a/server/config/config_test.go b/server/config/config_test.go index 9123442..02ad081 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -2,6 +2,7 @@ package config import ( "os" + "reflect" "testing" "time" ) @@ -37,4 +38,19 @@ func TestLoad(t *testing.T) { if cfg.Server.WriteTimeout.Duration != 15*time.Second { t.Fatalf("unexpected server write timeout %s", cfg.Server.WriteTimeout) } + if cfg.Server.PublicURL != "https://www.svc.plus" { + t.Fatalf("unexpected server public url %q", cfg.Server.PublicURL) + } + wantOrigins := []string{ + "https://www.svc.plus", + "https://global-homepage.svc.plus", + "https://account.svc.plus", + "http://localhost:3000", + "http://127.0.0.1:3000", + "http://localhost:3001", + "http://127.0.0.1:3001", + } + if !reflect.DeepEqual(cfg.Server.AllowedOrigins, wantOrigins) { + t.Fatalf("unexpected server allowed origins %#v", cfg.Server.AllowedOrigins) + } } diff --git a/server/config/server.yaml b/server/config/server.yaml index ac06f54..665b855 100644 --- a/server/config/server.yaml +++ b/server/config/server.yaml @@ -2,6 +2,15 @@ server: addr: ":8090" readTimeout: 15s writeTimeout: 15s + publicUrl: "https://www.svc.plus" + allowedOrigins: + - "https://www.svc.plus" + - "https://global-homepage.svc.plus" + - "https://account.svc.plus" + - "http://localhost:3000" + - "http://127.0.0.1:3000" + - "http://localhost:3001" + - "http://127.0.0.1:3001" global: redis: diff --git a/server/cors.go b/server/cors.go new file mode 100644 index 0000000..e2261d0 --- /dev/null +++ b/server/cors.go @@ -0,0 +1,167 @@ +package server + +import ( + "fmt" + "log/slog" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" + + "xcontrol/server/config" +) + +// UseCORS applies a restrictive CORS policy to the provided gin engine based on the +// server configuration. When the configuration specifies explicit origins the +// middleware allows credentials and mirrors the origin. When the configuration +// uses the "*" wildcard, credentials are disabled to remain compliant with the +// Fetch specification. +func UseCORS(r *gin.Engine, logger *slog.Logger, serverCfg config.ServerCfg) { + if r == nil { + return + } + if logger == nil { + logger = slog.Default() + } + + corsCfg := buildCORSConfig(logger, serverCfg) + if corsCfg.AllowAllOrigins { + logger.Info("configured cors", "allowAllOrigins", true) + } else { + logger.Info("configured cors", "allowedOrigins", corsCfg.AllowOrigins) + } + r.Use(cors.New(corsCfg)) +} + +func buildCORSConfig(logger *slog.Logger, serverCfg config.ServerCfg) cors.Config { + allowOrigins, allowAll := resolveAllowedOrigins(logger, serverCfg) + + cfg := cors.Config{ + AllowMethods: []string{ + http.MethodGet, + http.MethodHead, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + http.MethodOptions, + }, + AllowHeaders: []string{ + "Authorization", + "Content-Type", + "Accept", + "Origin", + "X-Requested-With", + "Cookie", + }, + ExposeHeaders: []string{ + "Content-Length", + }, + MaxAge: 12 * time.Hour, + } + + if allowAll { + cfg.AllowAllOrigins = true + cfg.AllowCredentials = false + } else { + cfg.AllowOrigins = allowOrigins + cfg.AllowCredentials = true + } + + return cfg +} + +func resolveAllowedOrigins(logger *slog.Logger, serverCfg config.ServerCfg) ([]string, bool) { + rawOrigins := serverCfg.AllowedOrigins + seen := make(map[string]struct{}, len(rawOrigins)) + origins := make([]string, 0, len(rawOrigins)) + allowAll := false + + for _, origin := range rawOrigins { + trimmed := strings.TrimSpace(origin) + if trimmed == "" { + continue + } + if trimmed == "*" { + allowAll = true + continue + } + + normalized, err := parseOrigin(trimmed) + if err != nil { + logger.Warn("ignoring invalid cors origin", "origin", origin, "err", err) + continue + } + if _, exists := seen[normalized]; exists { + continue + } + seen[normalized] = struct{}{} + origins = append(origins, normalized) + } + + if allowAll { + return nil, true + } + + if len(origins) == 0 { + publicURL := strings.TrimSpace(serverCfg.PublicURL) + if publicURL != "" { + normalized, err := parseOrigin(publicURL) + if err != nil { + logger.Warn("invalid server public url; falling back to defaults", "publicUrl", publicURL, "err", err) + } else { + origins = append(origins, normalized) + } + } + } + + if len(origins) == 0 { + origins = []string{ + "http://localhost:3000", + "http://127.0.0.1:3000", + "http://localhost:3001", + "http://127.0.0.1:3001", + "https://localhost:8443", + } + } + + return origins, false +} + +func parseOrigin(value string) (string, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "", fmt.Errorf("origin is empty") + } + + normalized := trimmed + if !strings.Contains(normalized, "://") { + normalized = "https://" + normalized + } + + parsed, err := url.Parse(normalized) + if err != nil { + return "", err + } + + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if scheme == "" { + return "", fmt.Errorf("origin must include a scheme") + } + + hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if hostname == "" { + return "", fmt.Errorf("origin must include a host") + } + + host := hostname + if port := strings.TrimSpace(parsed.Port()); port != "" { + host = net.JoinHostPort(hostname, port) + } + + return scheme + "://" + host, nil +}