Tighten CORS policies for production domains (#417)

This commit is contained in:
shenlan 2025-10-06 11:30:19 +08:00 committed by GitHub
parent bfae1b53fc
commit 7eb0b74b70
7 changed files with 201 additions and 3 deletions

View File

@ -337,6 +337,7 @@ func buildCORSConfig(logger *slog.Logger, serverCfg config.Server) cors.Config {
"Accept",
"Origin",
"X-Requested-With",
"Cookie",
},
ExposeHeaders: []string{
"Content-Length",

View File

@ -7,6 +7,8 @@ server:
writeTimeout: 15s
publicUrl: "http://localhost:8080"
allowedOrigins:
- "https://www.svc.plus"
- "https://global-homepage.svc.plus"
- "https://account.svc.plus"
- "https://localhost:8443"
- "http://localhost:8080"

View File

@ -85,6 +85,7 @@ var rootCmd = &cobra.Command{
r := server.New(
api.RegisterRoutes(conn, cfg.Sync.Repo.Proxy),
)
server.UseCORS(r, logger, cfg.Server)
addr := cfg.Server.Addr
if addr == "" {

View File

@ -153,9 +153,11 @@ func (d Duration) String() string {
// ServerCfg contains HTTP server runtime configuration.
type ServerCfg struct {
Addr string `yaml:"addr"`
ReadTimeout Duration `yaml:"readTimeout"`
WriteTimeout Duration `yaml:"writeTimeout"`
Addr string `yaml:"addr"`
ReadTimeout Duration `yaml:"readTimeout"`
WriteTimeout Duration `yaml:"writeTimeout"`
PublicURL string `yaml:"publicUrl"`
AllowedOrigins []string `yaml:"allowedOrigins"`
}
type Config struct {

View File

@ -2,6 +2,7 @@ package config
import (
"os"
"reflect"
"testing"
"time"
)
@ -37,4 +38,19 @@ func TestLoad(t *testing.T) {
if cfg.Server.WriteTimeout.Duration != 15*time.Second {
t.Fatalf("unexpected server write timeout %s", cfg.Server.WriteTimeout)
}
if cfg.Server.PublicURL != "https://www.svc.plus" {
t.Fatalf("unexpected server public url %q", cfg.Server.PublicURL)
}
wantOrigins := []string{
"https://www.svc.plus",
"https://global-homepage.svc.plus",
"https://account.svc.plus",
"http://localhost:3000",
"http://127.0.0.1:3000",
"http://localhost:3001",
"http://127.0.0.1:3001",
}
if !reflect.DeepEqual(cfg.Server.AllowedOrigins, wantOrigins) {
t.Fatalf("unexpected server allowed origins %#v", cfg.Server.AllowedOrigins)
}
}

View File

@ -2,6 +2,15 @@ server:
addr: ":8090"
readTimeout: 15s
writeTimeout: 15s
publicUrl: "https://www.svc.plus"
allowedOrigins:
- "https://www.svc.plus"
- "https://global-homepage.svc.plus"
- "https://account.svc.plus"
- "http://localhost:3000"
- "http://127.0.0.1:3000"
- "http://localhost:3001"
- "http://127.0.0.1:3001"
global:
redis:

167
server/cors.go Normal file
View File

@ -0,0 +1,167 @@
package server
import (
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"xcontrol/server/config"
)
// UseCORS applies a restrictive CORS policy to the provided gin engine based on the
// server configuration. When the configuration specifies explicit origins the
// middleware allows credentials and mirrors the origin. When the configuration
// uses the "*" wildcard, credentials are disabled to remain compliant with the
// Fetch specification.
func UseCORS(r *gin.Engine, logger *slog.Logger, serverCfg config.ServerCfg) {
if r == nil {
return
}
if logger == nil {
logger = slog.Default()
}
corsCfg := buildCORSConfig(logger, serverCfg)
if corsCfg.AllowAllOrigins {
logger.Info("configured cors", "allowAllOrigins", true)
} else {
logger.Info("configured cors", "allowedOrigins", corsCfg.AllowOrigins)
}
r.Use(cors.New(corsCfg))
}
func buildCORSConfig(logger *slog.Logger, serverCfg config.ServerCfg) cors.Config {
allowOrigins, allowAll := resolveAllowedOrigins(logger, serverCfg)
cfg := cors.Config{
AllowMethods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
http.MethodOptions,
},
AllowHeaders: []string{
"Authorization",
"Content-Type",
"Accept",
"Origin",
"X-Requested-With",
"Cookie",
},
ExposeHeaders: []string{
"Content-Length",
},
MaxAge: 12 * time.Hour,
}
if allowAll {
cfg.AllowAllOrigins = true
cfg.AllowCredentials = false
} else {
cfg.AllowOrigins = allowOrigins
cfg.AllowCredentials = true
}
return cfg
}
func resolveAllowedOrigins(logger *slog.Logger, serverCfg config.ServerCfg) ([]string, bool) {
rawOrigins := serverCfg.AllowedOrigins
seen := make(map[string]struct{}, len(rawOrigins))
origins := make([]string, 0, len(rawOrigins))
allowAll := false
for _, origin := range rawOrigins {
trimmed := strings.TrimSpace(origin)
if trimmed == "" {
continue
}
if trimmed == "*" {
allowAll = true
continue
}
normalized, err := parseOrigin(trimmed)
if err != nil {
logger.Warn("ignoring invalid cors origin", "origin", origin, "err", err)
continue
}
if _, exists := seen[normalized]; exists {
continue
}
seen[normalized] = struct{}{}
origins = append(origins, normalized)
}
if allowAll {
return nil, true
}
if len(origins) == 0 {
publicURL := strings.TrimSpace(serverCfg.PublicURL)
if publicURL != "" {
normalized, err := parseOrigin(publicURL)
if err != nil {
logger.Warn("invalid server public url; falling back to defaults", "publicUrl", publicURL, "err", err)
} else {
origins = append(origins, normalized)
}
}
}
if len(origins) == 0 {
origins = []string{
"http://localhost:3000",
"http://127.0.0.1:3000",
"http://localhost:3001",
"http://127.0.0.1:3001",
"https://localhost:8443",
}
}
return origins, false
}
func parseOrigin(value string) (string, error) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return "", fmt.Errorf("origin is empty")
}
normalized := trimmed
if !strings.Contains(normalized, "://") {
normalized = "https://" + normalized
}
parsed, err := url.Parse(normalized)
if err != nil {
return "", err
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme == "" {
return "", fmt.Errorf("origin must include a scheme")
}
hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if hostname == "" {
return "", fmt.Errorf("origin must include a host")
}
host := hostname
if port := strings.TrimSpace(parsed.Port()); port != "" {
host = net.JoinHostPort(hostname, port)
}
return scheme + "://" + host, nil
}