accounts/cmd/accountsvc/main.go
2026-01-23 23:16:37 +08:00

804 lines
20 KiB
Go

package main
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/spf13/cobra"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"account/api"
"account/config"
"account/internal/agentmode"
"account/internal/agentproto"
"account/internal/agentserver"
"account/internal/auth"
"account/internal/mailer"
"account/internal/model"
"account/internal/service"
"account/internal/store"
"account/internal/xrayconfig"
)
// Command-line flag values; they are bound to rootCmd's flag set in init.
var (
// configPath holds the --config flag value and is passed to config.Load.
configPath string
// logLevel holds the --log-level flag value; when non-empty it overrides cfg.Log.Level.
logLevel string
)
// mailerAdapter wraps a mailer.Sender and exposes a Send method that accepts
// api.EmailMessage, so the SMTP mailer can be handed to the API layer as its
// email sender.
type mailerAdapter struct {
	sender mailer.Sender
}

// Send translates msg into a mailer.Message and forwards it to the wrapped
// sender. The recipient list is copied so the mailer never shares the
// caller's backing array. When no sender is configured the call does nothing
// and reports success.
func (m mailerAdapter) Send(ctx context.Context, msg api.EmailMessage) error {
	delivery := m.sender
	if delivery == nil {
		return nil
	}

	var outgoing mailer.Message
	outgoing.To = append([]string(nil), msg.To...)
	outgoing.Subject = msg.Subject
	outgoing.PlainBody = msg.PlainBody
	outgoing.HTMLBody = msg.HTMLBody

	return delivery.Send(ctx, outgoing)
}
// runServer wires up and runs the account HTTP(S) API, blocking until the listener stops.
//
// It builds the gin engine (CORS, panic recovery, request logging), opens the
// primary store and the admin-settings database, optionally configures SMTP
// delivery, token authentication, the agent registry and the periodic xray
// sync, registers the API routes and finally serves on cfg.Server.Addr.
//
// A nil ctx is replaced with context.Background() and a nil logger with
// slog.Default(); a nil cfg is rejected with an error.
//
// It returns nil when the listener stops with http.ErrServerClosed, and the
// underlying error for any setup or listener failure.
//
// NOTE(review): ctx is handed to store.New and syncer.Start only; nothing in
// this function calls srv.Shutdown, so cancelling ctx does not appear to stop
// the HTTP listener — confirm whether graceful shutdown is handled elsewhere.
func runServer(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
if ctx == nil {
ctx = context.Background()
}
if cfg == nil {
return errors.New("config is nil")
}
if logger == nil {
logger = slog.Default()
}
// HTTP engine: CORS first, then panic recovery, then a per-request access log.
r := gin.New()
corsConfig := buildCORSConfig(logger, cfg.Server)
if corsConfig.AllowAllOrigins {
logger.Info("configured cors", "allowAllOrigins", true)
} else {
logger.Info("configured cors", "allowedOrigins", corsConfig.AllowOrigins)
}
r.Use(cors.New(corsConfig))
r.Use(gin.Recovery())
r.Use(func(c *gin.Context) {
// Time the remaining handler chain and log the outcome once it has run.
start := time.Now()
c.Next()
logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start))
})
// Primary store; released on return through the cleanup func returned by store.New.
storeCfg := store.Config{
Driver: cfg.Store.Driver,
DSN: cfg.Store.DSN,
MaxOpenConns: cfg.Store.MaxOpenConns,
MaxIdleConns: cfg.Store.MaxIdleConns,
}
st, cleanup, err := store.New(ctx, storeCfg)
if err != nil {
return err
}
defer func() {
if cleanup == nil {
return
}
// A fresh background context is used for cleanup rather than ctx.
if err := cleanup(context.Background()); err != nil {
logger.Error("failed to close store", "err", err)
}
}()
// Email delivery is optional: an empty or placeholder (example.com) SMTP
// host leaves emailSender nil and switches email verification off.
var emailSender api.EmailSender
emailVerificationEnabled := true
smtpHost := strings.TrimSpace(cfg.SMTP.Host)
if smtpHost == "" {
emailVerificationEnabled = false
}
if smtpHost != "" && isExampleDomain(smtpHost) {
emailVerificationEnabled = false
logger.Warn("smtp host is a placeholder; disabling email delivery", "host", smtpHost)
smtpHost = ""
}
if smtpHost != "" {
tlsMode := mailer.ParseTLSMode(cfg.SMTP.TLS.Mode)
sender, err := mailer.New(mailer.Config{
Host: smtpHost,
Port: cfg.SMTP.Port,
Username: cfg.SMTP.Username,
Password: cfg.SMTP.Password,
From: cfg.SMTP.From,
ReplyTo: cfg.SMTP.ReplyTo,
Timeout: cfg.SMTP.Timeout,
TLSMode: tlsMode,
InsecureSkipVerify: cfg.SMTP.TLS.InsecureSkipVerify,
})
if err != nil {
return err
}
emailSender = mailerAdapter{sender: sender}
}
if emailSender == nil {
emailVerificationEnabled = false
}
// Initialize TokenService for authentication
// Non-positive expiries fall back to 1 hour (access) and 168 hours (refresh).
var tokenService *auth.TokenService
if cfg.Auth.Enable {
accessExpiry := cfg.Auth.Token.AccessExpiry
if accessExpiry <= 0 {
accessExpiry = 1 * time.Hour
}
refreshExpiry := cfg.Auth.Token.RefreshExpiry
if refreshExpiry <= 0 {
refreshExpiry = 168 * time.Hour // 7 days
}
tokenService = auth.NewTokenService(auth.TokenConfig{
PublicToken: cfg.Auth.Token.PublicToken,
RefreshSecret: cfg.Auth.Token.RefreshSecret,
AccessSecret: cfg.Auth.Token.AccessSecret,
AccessExpiry: accessExpiry,
RefreshExpiry: refreshExpiry,
})
logger.Info("token service initialized", "auth_enabled", cfg.Auth.Enable)
}
// Separate GORM handle: installed via service.SetDB and used as the xray client source.
gormDB, gormCleanup, err := openAdminSettingsDB(cfg.Store)
if err != nil {
return err
}
defer func() {
if gormCleanup != nil {
if err := gormCleanup(context.Background()); err != nil {
logger.Error("failed to close admin settings db", "err", err)
}
}
}()
service.SetDB(gormDB)
gormSource, err := xrayconfig.NewGormClientSource(gormDB)
if err != nil {
return err
}
// The agent registry is created only when agent credentials are configured.
var agentRegistry *agentserver.Registry
if len(cfg.Agents.Credentials) > 0 {
creds := make([]agentserver.Credential, 0, len(cfg.Agents.Credentials))
for _, c := range cfg.Agents.Credentials {
creds = append(creds, agentserver.Credential{
ID: c.ID,
Name: c.Name,
Token: c.Token,
// Copy the groups so the registry does not alias the configuration slice.
Groups: append([]string(nil), c.Groups...),
})
}
agentRegistry, err = agentserver.NewRegistry(agentserver.Config{Credentials: creds})
if err != nil {
return err
}
}
// Optional periodic xray config sync. Defaults: 5 minute interval and
// /usr/local/etc/xray/config.json as the output path.
var stopXraySync func(context.Context) error
if cfg.Xray.Sync.Enabled {
syncInterval := cfg.Xray.Sync.Interval
if syncInterval <= 0 {
syncInterval = 5 * time.Minute
}
outputPath := strings.TrimSpace(cfg.Xray.Sync.OutputPath)
if outputPath == "" {
outputPath = "/usr/local/etc/xray/config.json"
}
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
Logger: logger.With("component", "xray-sync"),
Interval: syncInterval,
Source: gormSource,
Generator: xrayconfig.Generator{Definition: xrayconfig.DefaultDefinition(), OutputPath: outputPath},
ValidateCommand: cfg.Xray.Sync.ValidateCommand,
RestartCommand: cfg.Xray.Sync.RestartCommand,
})
if err != nil {
return err
}
stop, err := syncer.Start(ctx)
if err != nil {
return err
}
logger.Info("xray periodic sync enabled", "interval", syncInterval, "output", outputPath)
stopXraySync = stop
}
if stopXraySync != nil {
// On return, give the syncer up to 5 seconds to stop.
defer func() {
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := stopXraySync(waitCtx); err != nil {
logger.Warn("xray syncer shutdown", "err", err)
}
}()
}
// Assemble the API options; optional collaborators are attached only when configured.
options := []api.Option{
api.WithStore(st),
api.WithSessionTTL(cfg.Session.TTL),
}
if emailSender != nil {
options = append(options, api.WithEmailSender(emailSender))
}
options = append(options, api.WithEmailVerification(emailVerificationEnabled))
if tokenService != nil {
options = append(options, api.WithTokenService(tokenService))
}
if agentRegistry != nil {
options = append(options, api.WithAgentStatusReader(agentRegistry))
}
api.RegisterRoutes(r, options...)
if agentRegistry != nil {
registerAgentAPIRoutes(r, agentRegistry, gormSource, logger)
}
// The listen address defaults to :8080.
addr := strings.TrimSpace(cfg.Server.Addr)
if addr == "" {
addr = ":8080"
}
tlsSettings := cfg.Server.TLS
certFile := strings.TrimSpace(tlsSettings.CertFile)
keyFile := strings.TrimSpace(tlsSettings.KeyFile)
caFile := strings.TrimSpace(tlsSettings.CAFile)
clientCAFile := strings.TrimSpace(tlsSettings.ClientCAFile)
useTLS := tlsSettings.IsEnabled()
var tlsConfig *tls.Config
if useTLS {
// TLS requires both a certificate and a key file.
if certFile == "" || keyFile == "" {
return fmt.Errorf("tls is enabled but certFile (%q) or keyFile (%q) is empty", certFile, keyFile)
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("failed to load tls certificate: %w", err)
}
if caFile != "" {
// Extend the served chain with every CERTIFICATE block from caFile that
// is not already present (compared by raw DER bytes).
caPEM, err := os.ReadFile(caFile)
if err != nil {
return fmt.Errorf("failed to read ca file %q: %w", caFile, err)
}
var block *pem.Block
existing := make(map[string]struct{}, len(cert.Certificate))
for _, c := range cert.Certificate {
existing[string(c)] = struct{}{}
}
for len(caPEM) > 0 {
block, caPEM = pem.Decode(caPEM)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Bytes) == 0 {
continue
}
if _, ok := existing[string(block.Bytes)]; ok {
continue
}
cert.Certificate = append(cert.Certificate, block.Bytes)
}
// NOTE(review): the key pair loaded successfully above, so the chain
// presumably already holds the leaf and this check looks unreachable — confirm.
if len(cert.Certificate) == 0 {
return fmt.Errorf("ca file %q did not contain any certificates", caFile)
}
}
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
if clientCAFile != "" {
// A client CA bundle turns on mandatory, verified client certificates.
caBytes, err := os.ReadFile(clientCAFile)
if err != nil {
return err
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caBytes) {
return errors.New("failed to parse client CA file")
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
} else {
// TLS is off: mention any TLS-related paths that will be ignored.
if certFile != "" || keyFile != "" {
logger.Info("TLS disabled; certificate paths will be ignored", "certFile", certFile, "keyFile", keyFile)
}
if clientCAFile != "" {
logger.Warn("client CA configured but TLS is disabled; ignoring", "clientCAFile", clientCAFile)
}
}
srv := &http.Server{
Addr: addr,
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
}
if useTLS {
srv.TLSConfig = tlsConfig
}
logger.Info("starting account service", "addr", addr, "tls", useTLS)
var listenCertFile, listenKeyFile string
if useTLS {
if tlsSettings.RedirectHTTP {
// Plain-HTTP listener that permanently redirects every request to
// https:// on the requested host. It runs in its own goroutine and
// nothing in this function shuts redirectSrv down.
go func() {
redirectAddr := deriveRedirectAddr(addr)
redirectSrv := &http.Server{
Addr: redirectAddr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := r.Host
if host == "" {
host = redirectAddr
}
target := "https://" + host + r.URL.RequestURI()
http.Redirect(w, r, target, http.StatusPermanentRedirect)
}),
}
if err := redirectSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("http redirect listener exited", "err", err)
}
}()
}
// When the certificate is already set on tlsConfig, pass empty file names
// so ListenAndServeTLS serves the certificates from srv.TLSConfig.
if tlsConfig != nil && len(tlsConfig.Certificates) > 0 {
listenCertFile = ""
listenKeyFile = ""
} else {
listenCertFile = certFile
listenKeyFile = keyFile
}
if err := srv.ListenAndServeTLS(listenCertFile, listenKeyFile); err != nil {
// http.ErrServerClosed is treated as a clean stop, not a failure.
if !errors.Is(err, http.ErrServerClosed) {
logger.Error("account service shutdown", "err", err)
return err
}
}
} else {
if err := srv.ListenAndServe(); err != nil {
// http.ErrServerClosed is treated as a clean stop, not a failure.
if !errors.Is(err, http.ErrServerClosed) {
logger.Error("account service shutdown", "err", err)
return err
}
}
}
return nil
}
// runServerAndAgent runs the agent in a background goroutine and the HTTP
// server in the foreground, within a single process.
//
// The agent receives a context derived from ctx that is cancelled as soon as
// the server returns; the function then waits for the agent to finish. A
// server error takes precedence over an agent error. If the agent has already
// returned by the time the server is about to start, the function fails
// immediately with "agent startup failed".
func runServerAndAgent(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
	if ctx == nil {
		ctx = context.Background()
	}
	if cfg == nil {
		return errors.New("config is nil")
	}

	agentCtx, stopAgent := context.WithCancel(ctx)
	defer stopAgent()

	// Buffered so the agent goroutine can always deliver its result and exit.
	agentDone := make(chan error, 1)
	go func() {
		agentDone <- runAgent(agentCtx, cfg, logger)
	}()

	// Non-blocking probe: only catches an agent that has already exited.
	select {
	case startupErr := <-agentDone:
		if startupErr == nil {
			startupErr = errors.New("agent exited unexpectedly")
		}
		return fmt.Errorf("agent startup failed: %w", startupErr)
	default:
	}

	serverErr := runServer(ctx, cfg, logger)

	// The server is done; stop the agent and collect its result.
	stopAgent()
	agentErr := <-agentDone

	if serverErr != nil {
		return serverErr
	}
	return agentErr
}
// runAgent starts agent mode through agentmode.Run using the Agent and Xray
// sections of cfg and returns whatever Run returns.
//
// A nil cfg is an error and a nil logger falls back to slog.Default(). A
// warning is logged when xray sync is disabled in the configuration; the
// agent is started regardless.
func runAgent(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
	switch {
	case cfg == nil:
		return errors.New("config is nil")
	case logger == nil:
		logger = slog.Default()
	}

	if syncEnabled := cfg.Xray.Sync.Enabled; !syncEnabled {
		logger.Warn("xray sync is disabled in configuration; agent mode will still attempt to manage xray config")
	}

	return agentmode.Run(ctx, agentmode.Options{
		Logger: logger.With("component", "agent"),
		Agent:  cfg.Agent,
		Xray:   cfg.Xray,
	})
}
// agentIdentityContextKey is the gin context key under which
// agentAuthMiddleware stores the authenticated agent identity and from which
// agentReportStatusHandler reads it back.
const agentIdentityContextKey = "xcontrol-account-agent-identity"
// registerAgentAPIRoutes mounts the agent-facing endpoints under
// /api/agent/v1, all guarded by agentAuthMiddleware:
//
//	GET  /users  - list of clients from source
//	POST /status - agent status report recorded in registry
//
// Nothing is registered when registry is nil.
func registerAgentAPIRoutes(r *gin.Engine, registry *agentserver.Registry, source xrayconfig.ClientSource, logger *slog.Logger) {
	if registry == nil {
		return
	}

	requireAgent := agentAuthMiddleware(registry)
	listUsers := agentListUsersHandler(source)
	reportStatus := agentReportStatusHandler(registry, logger)

	agentAPI := r.Group("/api/agent/v1")
	agentAPI.Use(requireAgent)
	agentAPI.GET("/users", listUsers)
	agentAPI.POST("/status", reportStatus)
}
// agentAuthMiddleware returns a gin handler that authenticates agent requests
// by the token carried in the Authorization header.
//
// Responses on failure:
//   - 503 agent_registry_unavailable when registry is nil
//   - 401 agent_token_required when no token is supplied
//   - 401 invalid_agent_token when the registry does not accept the token
//
// On success the agent identity is stored in the gin context under
// agentIdentityContextKey and the next handler runs.
func agentAuthMiddleware(registry *agentserver.Registry) gin.HandlerFunc {
	// deny aborts the request with the standard {"error", "message"} payload.
	deny := func(c *gin.Context, status int, code, message string) {
		c.AbortWithStatusJSON(status, gin.H{"error": code, "message": message})
	}

	return func(c *gin.Context) {
		if registry == nil {
			deny(c, http.StatusServiceUnavailable, "agent_registry_unavailable", "agent registry not configured")
			return
		}

		bearer := extractBearerToken(c.GetHeader("Authorization"))
		if bearer == "" {
			deny(c, http.StatusUnauthorized, "agent_token_required", "agent token is required")
			return
		}

		identity, accepted := registry.Authenticate(bearer)
		if identity == nil || !accepted {
			deny(c, http.StatusUnauthorized, "invalid_agent_token", "invalid agent token")
			return
		}

		// Store the identity by value for downstream handlers.
		c.Set(agentIdentityContextKey, *identity)
		c.Next()
	}
}
// agentListUsersHandler returns a gin handler that responds with the clients
// provided by source as an agentproto.ClientListResponse, including the
// client count and the UTC time the response was generated.
//
// It answers 503 client_source_unavailable when source is nil and
// 500 list_clients_failed when source.ListClients returns an error.
func agentListUsersHandler(source xrayconfig.ClientSource) gin.HandlerFunc {
	return func(c *gin.Context) {
		if source == nil {
			unavailable := gin.H{"error": "client_source_unavailable", "message": "client source not configured"}
			c.AbortWithStatusJSON(http.StatusServiceUnavailable, unavailable)
			return
		}

		clients, listErr := source.ListClients(c.Request.Context())
		if listErr != nil {
			failed := gin.H{"error": "list_clients_failed", "message": "failed to list clients"}
			c.AbortWithStatusJSON(http.StatusInternalServerError, failed)
			return
		}

		c.JSON(http.StatusOK, agentproto.ClientListResponse{
			Clients:     clients,
			Total:       len(clients),
			GeneratedAt: time.Now().UTC(),
		})
	}
}
// agentReportStatusHandler returns a gin handler that accepts a JSON
// agentproto.StatusReport from an authenticated agent and records it in
// registry.
//
// The agent identity is read from the gin context under
// agentIdentityContextKey. The handler answers 500 when that value is missing
// or is not an agentserver.Identity, 400 when the body cannot be bound, and
// 204 once the report has been recorded. When logger is non-nil the update is
// logged.
func agentReportStatusHandler(registry *agentserver.Registry, logger *slog.Logger) gin.HandlerFunc {
	return func(c *gin.Context) {
		raw, found := c.Get(agentIdentityContextKey)
		if !found {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "agent_identity_missing", "message": "agent identity missing"})
			return
		}

		identity, wellTyped := raw.(agentserver.Identity)
		if !wellTyped {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "agent_identity_invalid", "message": "agent identity malformed"})
			return
		}

		var report agentproto.StatusReport
		if bindErr := c.ShouldBindJSON(&report); bindErr != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid_status_payload", "message": "invalid status payload"})
			return
		}

		registry.ReportStatus(identity, report)
		if logger != nil {
			logger.Info("agent status updated", "agent", identity.ID, "healthy", report.Healthy, "clients", report.Xray.Clients)
		}

		c.Status(http.StatusNoContent)
	}
}
// extractBearerToken returns the credential carried by an Authorization
// header value.
//
// Surrounding whitespace is ignored. When the value starts with the "Bearer"
// scheme followed by a space, the scheme is stripped; the scheme is matched
// case-insensitively because HTTP authentication scheme names are
// case-insensitive. A value with no Bearer scheme is returned unchanged
// (trimmed), so a bare token is still accepted. An empty header, or a Bearer
// scheme with no credential after it, yields "".
func extractBearerToken(header string) string {
	const scheme = "Bearer"

	header = strings.TrimSpace(header)
	if header == "" {
		return ""
	}
	// "Bearer" on its own carries no credential; do not treat the scheme
	// name itself as a token.
	if strings.EqualFold(header, scheme) {
		return ""
	}
	if len(header) > len(scheme) && header[len(scheme)] == ' ' && strings.EqualFold(header[:len(scheme)], scheme) {
		return strings.TrimSpace(header[len(scheme)+1:])
	}
	return header
}
// rootCmd is the command executed by main. It loads the configuration from
// the --config path, applies the --log-level override, installs a text slog
// logger on stdout as the process default, and then dispatches on cfg.Mode:
//
//	"" or "server"                     -> runServer
//	"agent"                            -> runAgent
//	"server-agent", "all", "combined"  -> runServerAndAgent
//
// The mode is compared case-insensitively after trimming; any other value is
// rejected. Unrecognised log levels fall back to info.
var rootCmd = &cobra.Command{
	Use:   "xcontrol-account",
	Short: "Start the xcontrol account service",
	RunE: func(cmd *cobra.Command, args []string) error {
		cfg, loadErr := config.Load(configPath)
		if loadErr != nil {
			return loadErr
		}

		// The command-line flag wins over the configured level.
		if logLevel != "" {
			cfg.Log.Level = logLevel
		}

		var level slog.Level
		switch strings.ToLower(strings.TrimSpace(cfg.Log.Level)) {
		case "debug":
			level = slog.LevelDebug
		case "warn", "warning":
			level = slog.LevelWarn
		case "error":
			level = slog.LevelError
		default:
			level = slog.LevelInfo
		}

		handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
		logger := slog.New(handler)
		slog.SetDefault(logger)

		ctx := context.Background()
		switch strings.ToLower(strings.TrimSpace(cfg.Mode)) {
		case "", "server":
			return runServer(ctx, cfg, logger)
		case "agent":
			return runAgent(ctx, cfg, logger)
		case "server-agent", "all", "combined":
			return runServerAndAgent(ctx, cfg, logger)
		}
		return fmt.Errorf("unsupported mode %q", cfg.Mode)
	},
}
// openAdminSettingsDB opens the GORM database that backs admin settings and
// migrates the model.AdminSetting schema.
//
// Driver selection (case-insensitive, trimmed):
//   - "" or "memory": shared-cache in-memory SQLite
//   - "postgres", "postgresql", "pgx": PostgreSQL using cfg.DSN, which must be non-empty
//
// Any other driver is rejected. Positive cfg.MaxOpenConns / cfg.MaxIdleConns
// values are applied to the underlying connection pool.
//
// It returns the GORM handle together with a cleanup function that closes the
// underlying *sql.DB. If the schema migration fails, the pool opened here is
// closed before returning so the connection is not leaked.
func openAdminSettingsDB(cfg config.Store) (*gorm.DB, func(context.Context) error, error) {
	driver := strings.ToLower(strings.TrimSpace(cfg.Driver))

	var (
		db  *gorm.DB
		err error
	)
	switch driver {
	case "", "memory":
		db, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	case "postgres", "postgresql", "pgx":
		if strings.TrimSpace(cfg.DSN) == "" {
			return nil, nil, errors.New("admin settings database requires a dsn")
		}
		db, err = gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{})
	default:
		return nil, nil, fmt.Errorf("unsupported admin settings driver %q", cfg.Driver)
	}
	if err != nil {
		return nil, nil, err
	}

	// Fetch the pool handle before migrating so that a failed migration can
	// release the connection instead of leaking it.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, nil, err
	}

	if err := db.AutoMigrate(&model.AdminSetting{}); err != nil {
		if closeErr := sqlDB.Close(); closeErr != nil {
			err = errors.Join(err, closeErr)
		}
		return nil, nil, err
	}

	if cfg.MaxOpenConns > 0 {
		sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
	}
	if cfg.MaxIdleConns > 0 {
		sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
	}

	cleanup := func(context.Context) error {
		return sqlDB.Close()
	}
	return db, cleanup, nil
}
// init binds the --config and --log-level flags of rootCmd to the
// package-level configPath and logLevel variables.
func init() {
	flags := rootCmd.Flags()
	flags.StringVar(&configPath, "config", "", "path to xcontrol account configuration file")
	flags.StringVar(&logLevel, "log-level", "", "log level (debug, info, warn, error)")
}
// main executes rootCmd and exits with status 1 when it returns an error.
func main() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	os.Exit(1)
}
// isExampleDomain reports whether host is the placeholder domain example.com
// or one of its subdomains. The comparison ignores case and surrounding
// whitespace, and anything from the first ':' onwards (such as a port) is
// discarded before matching. An empty host is never a placeholder.
func isExampleDomain(host string) bool {
	name := strings.TrimSpace(host)
	name = strings.ToLower(name)

	// Keep only the part before the first colon.
	if colon := strings.IndexByte(name, ':'); colon >= 0 {
		name = name[:colon]
	}

	const placeholder = "example.com"
	switch {
	case name == "":
		return false
	case name == placeholder:
		return true
	default:
		return strings.HasSuffix(name, "."+placeholder)
	}
}
// buildCORSConfig produces the CORS policy for the HTTP server.
//
// The allowed methods, request headers, exposed headers and the 12 hour
// preflight cache lifetime are fixed. Origins come from
// resolveAllowedOrigins: when every origin is allowed, credentials are
// disabled; otherwise the explicit origin list is used and credentials are
// allowed.
func buildCORSConfig(logger *slog.Logger, serverCfg config.Server) cors.Config {
	methods := []string{
		http.MethodGet,
		http.MethodHead,
		http.MethodPost,
		http.MethodPut,
		http.MethodPatch,
		http.MethodDelete,
		http.MethodOptions,
	}
	requestHeaders := []string{
		"Authorization",
		"Content-Type",
		"Accept",
		"Origin",
		"X-Requested-With",
		"Cookie",
	}

	origins, wildcard := resolveAllowedOrigins(logger, serverCfg)

	policy := cors.Config{
		AllowMethods:     methods,
		AllowHeaders:     requestHeaders,
		ExposeHeaders:    []string{"Content-Length"},
		MaxAge:           12 * time.Hour,
		AllowAllOrigins:  wildcard,
		AllowCredentials: !wildcard,
	}
	if !wildcard {
		policy.AllowOrigins = origins
	}
	return policy
}
// resolveAllowedOrigins turns the configured CORS origins into a normalized,
// de-duplicated list and reports whether all origins are allowed.
//
// Blank entries are skipped, invalid entries are logged and skipped, and a
// "*" entry anywhere in the list makes the result (nil, true). When no valid
// origin is configured, the origin of serverCfg.PublicURL is used if it
// parses; failing that, the local development origins
// http://localhost:3001 and http://127.0.0.1:3001 are returned.
func resolveAllowedOrigins(logger *slog.Logger, serverCfg config.Server) ([]string, bool) {
	configured := serverCfg.AllowedOrigins

	var (
		wildcard bool
		known    = make(map[string]struct{}, len(configured))
		resolved = make([]string, 0, len(configured))
	)

	for _, entry := range configured {
		candidate := strings.TrimSpace(entry)
		switch candidate {
		case "":
			continue
		case "*":
			wildcard = true
			continue
		}

		origin, err := parseOrigin(candidate)
		if err != nil {
			logger.Warn("ignoring invalid cors origin", "origin", entry, "err", err)
			continue
		}
		if _, duplicate := known[origin]; !duplicate {
			known[origin] = struct{}{}
			resolved = append(resolved, origin)
		}
	}

	if wildcard {
		return nil, true
	}
	if len(resolved) > 0 {
		return resolved, false
	}

	// Nothing usable was configured: try the public URL next.
	if publicURL := strings.TrimSpace(serverCfg.PublicURL); publicURL != "" {
		origin, err := parseOrigin(publicURL)
		if err == nil {
			return []string{origin}, false
		}
		logger.Warn("invalid server public url; falling back to defaults", "publicUrl", publicURL, "err", err)
	}

	return []string{
		"http://localhost:3001",
		"http://127.0.0.1:3001",
	}, false
}
// parseOrigin normalizes value into a CORS origin of the form
// "scheme://host[:port]".
//
// Surrounding whitespace is trimmed, "https://" is assumed when the value has
// no "://", scheme and host are lower-cased, and any path, query or fragment
// is dropped. IPv6 literals are always emitted in brackets, with or without a
// port, so the result matches the Origin header a browser sends.
//
// An error is returned when the value is empty, cannot be parsed as a URL, or
// lacks a scheme or host.
func parseOrigin(value string) (string, error) {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return "", fmt.Errorf("origin is empty")
	}

	normalized := trimmed
	if !strings.Contains(normalized, "://") {
		normalized = "https://" + normalized
	}

	parsed, err := url.Parse(normalized)
	if err != nil {
		return "", err
	}

	scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
	if scheme == "" {
		return "", fmt.Errorf("origin must include a scheme")
	}

	hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
	if hostname == "" {
		return "", fmt.Errorf("origin must include a host")
	}

	// Hostname() strips the brackets of an IPv6 literal. JoinHostPort puts
	// them back when a port is present; without a port they must be restored
	// here, otherwise "http://[::1]" would become the invalid "http://::1".
	host := hostname
	switch port := strings.TrimSpace(parsed.Port()); {
	case port != "":
		host = net.JoinHostPort(hostname, port)
	case strings.Contains(hostname, ":"):
		host = "[" + hostname + "]"
	}

	return scheme + "://" + host, nil
}
// deriveRedirectAddr computes the listen address for the plain-HTTP redirect
// server from the TLS listen address.
//
// The host part is kept. An empty port or port 443 becomes port 80; any other
// port is kept unchanged. When addr cannot be split into host and port, the
// result is ":80", unless addr starts with ':', in which case the text after
// the colon is treated as the port under the same rule.
func deriveRedirectAddr(addr string) string {
	const fallback = ":80"
	// wantsHTTPPort reports whether the port should be replaced by 80.
	wantsHTTPPort := func(p string) bool { return p == "" || p == "443" }

	cleaned := strings.TrimSpace(addr)
	host, port, splitErr := net.SplitHostPort(cleaned)
	if splitErr == nil {
		if wantsHTTPPort(port) {
			port = "80"
		}
		return net.JoinHostPort(host, port)
	}

	// Unsplittable input: salvage a leading-colon form, otherwise give up.
	if !strings.HasPrefix(cleaned, ":") {
		return fallback
	}
	rest := cleaned[1:]
	if wantsHTTPPort(rest) {
		return fallback
	}
	return ":" + rest
}