Allow combined server and agent mode (#595)

This commit is contained in:
shenlan 2025-10-27 21:02:03 +08:00 committed by GitHub
parent 41ba7c3cc0
commit 7a4b2f0e00
17 changed files with 1558 additions and 274 deletions

View File

@ -0,0 +1,69 @@
package api
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"xcontrol/account/internal/agentserver"
)
type agentStatusReader interface {
Statuses() []agentserver.StatusSnapshot
}
type agentStatusEntry struct {
ID string `json:"id"`
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Healthy bool `json:"healthy"`
Message string `json:"message,omitempty"`
Users int `json:"users"`
SyncRevision string `json:"syncRevision,omitempty"`
UpdatedAt time.Time `json:"updatedAt"`
Xray agentXraySummary `json:"xray"`
}
type agentXraySummary struct {
Running bool `json:"running"`
Clients int `json:"clients"`
LastSync *time.Time `json:"lastSync,omitempty"`
}
func (h *handler) adminAgentStatus(c *gin.Context) {
if h.agentStatusReader == nil {
respondError(c, http.StatusServiceUnavailable, "agent_status_unavailable", "agent registry is not configured")
return
}
if _, ok := h.requireAdminOrOperator(c); !ok {
return
}
snapshots := h.agentStatusReader.Statuses()
entries := make([]agentStatusEntry, 0, len(snapshots))
for _, snapshot := range snapshots {
entry := agentStatusEntry{
ID: snapshot.Agent.ID,
Name: snapshot.Agent.Name,
Groups: append([]string(nil), snapshot.Agent.Groups...),
Healthy: snapshot.Report.Healthy,
Message: snapshot.Report.Message,
Users: snapshot.Report.Users,
SyncRevision: snapshot.Report.SyncRevision,
UpdatedAt: snapshot.UpdatedAt,
Xray: agentXraySummary{
Running: snapshot.Report.Xray.Running,
Clients: snapshot.Report.Xray.Clients,
},
}
if snapshot.Report.Xray.LastSync != nil {
last := *snapshot.Report.Xray.LastSync
entry.Xray.LastSync = &last
}
entries = append(entries, entry)
}
c.JSON(http.StatusOK, gin.H{"agents": entries})
}

View File

@ -85,4 +85,5 @@ func (h *handler) resolveSessionToken(c *gin.Context) string {
func registerAdminRoutes(group *gin.RouterGroup, h *handler) {
admin := group.Group("/admin")
admin.GET("/users/metrics", h.adminUsersMetrics)
admin.GET("/agents/status", h.adminAgentStatus)
}

View File

@ -57,6 +57,7 @@ type handler struct {
passwordResets map[string]passwordReset
resetMu sync.RWMutex
metricsProvider service.UserMetricsProvider
agentStatusReader agentStatusReader
}
type mfaChallenge struct {
@ -137,6 +138,15 @@ func WithUserMetricsProvider(provider service.UserMetricsProvider) Option {
}
}
// WithAgentStatusReader wires the agent status reader used by admin endpoints.
func WithAgentStatusReader(reader agentStatusReader) Option {
return func(h *handler) {
if reader != nil {
h.agentStatusReader = reader
}
}
}
// WithPasswordResetTTL overrides the default TTL for password reset tokens.
func WithPasswordResetTTL(ttl time.Duration) Option {
return func(h *handler) {

View File

@ -25,6 +25,9 @@ import (
"xcontrol/account/api"
"xcontrol/account/config"
"xcontrol/account/internal/agentmode"
"xcontrol/account/internal/agentproto"
"xcontrol/account/internal/agentserver"
"xcontrol/account/internal/mailer"
"xcontrol/account/internal/model"
"xcontrol/account/internal/service"
@ -54,6 +57,474 @@ func (m mailerAdapter) Send(ctx context.Context, msg api.EmailMessage) error {
return m.sender.Send(ctx, mail)
}
func runServer(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
if ctx == nil {
ctx = context.Background()
}
if cfg == nil {
return errors.New("config is nil")
}
if logger == nil {
logger = slog.Default()
}
r := gin.New()
corsConfig := buildCORSConfig(logger, cfg.Server)
if corsConfig.AllowAllOrigins {
logger.Info("configured cors", "allowAllOrigins", true)
} else {
logger.Info("configured cors", "allowedOrigins", corsConfig.AllowOrigins)
}
r.Use(cors.New(corsConfig))
r.Use(gin.Recovery())
r.Use(func(c *gin.Context) {
start := time.Now()
c.Next()
logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start))
})
storeCfg := store.Config{
Driver: cfg.Store.Driver,
DSN: cfg.Store.DSN,
MaxOpenConns: cfg.Store.MaxOpenConns,
MaxIdleConns: cfg.Store.MaxIdleConns,
}
st, cleanup, err := store.New(ctx, storeCfg)
if err != nil {
return err
}
defer func() {
if cleanup == nil {
return
}
if err := cleanup(context.Background()); err != nil {
logger.Error("failed to close store", "err", err)
}
}()
var emailSender api.EmailSender
emailVerificationEnabled := true
smtpHost := strings.TrimSpace(cfg.SMTP.Host)
if smtpHost == "" {
emailVerificationEnabled = false
}
if smtpHost != "" && isExampleDomain(smtpHost) {
emailVerificationEnabled = false
logger.Warn("smtp host is a placeholder; disabling email delivery", "host", smtpHost)
smtpHost = ""
}
if smtpHost != "" {
tlsMode := mailer.ParseTLSMode(cfg.SMTP.TLS.Mode)
sender, err := mailer.New(mailer.Config{
Host: smtpHost,
Port: cfg.SMTP.Port,
Username: cfg.SMTP.Username,
Password: cfg.SMTP.Password,
From: cfg.SMTP.From,
ReplyTo: cfg.SMTP.ReplyTo,
Timeout: cfg.SMTP.Timeout,
TLSMode: tlsMode,
InsecureSkipVerify: cfg.SMTP.TLS.InsecureSkipVerify,
})
if err != nil {
return err
}
emailSender = mailerAdapter{sender: sender}
}
if emailSender == nil {
emailVerificationEnabled = false
}
gormDB, gormCleanup, err := openAdminSettingsDB(cfg.Store)
if err != nil {
return err
}
defer func() {
if gormCleanup != nil {
if err := gormCleanup(context.Background()); err != nil {
logger.Error("failed to close admin settings db", "err", err)
}
}
}()
service.SetDB(gormDB)
gormSource, err := xrayconfig.NewGormClientSource(gormDB)
if err != nil {
return err
}
var agentRegistry *agentserver.Registry
if len(cfg.Agents.Credentials) > 0 {
creds := make([]agentserver.Credential, 0, len(cfg.Agents.Credentials))
for _, c := range cfg.Agents.Credentials {
creds = append(creds, agentserver.Credential{
ID: c.ID,
Name: c.Name,
Token: c.Token,
Groups: append([]string(nil), c.Groups...),
})
}
agentRegistry, err = agentserver.NewRegistry(agentserver.Config{Credentials: creds})
if err != nil {
return err
}
}
var stopXraySync func(context.Context) error
if cfg.Xray.Sync.Enabled {
syncInterval := cfg.Xray.Sync.Interval
if syncInterval <= 0 {
syncInterval = 5 * time.Minute
}
templatePath := strings.TrimSpace(cfg.Xray.Sync.TemplatePath)
if templatePath == "" {
templatePath = filepath.Join("account", "config", "xray.config.template.json")
}
outputPath := strings.TrimSpace(cfg.Xray.Sync.OutputPath)
if outputPath == "" {
outputPath = "/usr/local/etc/xray/config.json"
}
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
Logger: logger.With("component", "xray-sync"),
Interval: syncInterval,
Source: gormSource,
Generator: xrayconfig.Generator{TemplatePath: templatePath, OutputPath: outputPath},
ValidateCommand: cfg.Xray.Sync.ValidateCommand,
RestartCommand: cfg.Xray.Sync.RestartCommand,
})
if err != nil {
return err
}
stop, err := syncer.Start(ctx)
if err != nil {
return err
}
logger.Info("xray periodic sync enabled", "interval", syncInterval, "output", outputPath)
stopXraySync = stop
}
if stopXraySync != nil {
defer func() {
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := stopXraySync(waitCtx); err != nil {
logger.Warn("xray syncer shutdown", "err", err)
}
}()
}
options := []api.Option{
api.WithStore(st),
api.WithSessionTTL(cfg.Session.TTL),
}
if emailSender != nil {
options = append(options, api.WithEmailSender(emailSender))
}
options = append(options, api.WithEmailVerification(emailVerificationEnabled))
if agentRegistry != nil {
options = append(options, api.WithAgentStatusReader(agentRegistry))
}
api.RegisterRoutes(r, options...)
if agentRegistry != nil {
registerAgentAPIRoutes(r, agentRegistry, gormSource, logger)
}
addr := strings.TrimSpace(cfg.Server.Addr)
if addr == "" {
addr = ":8080"
}
tlsSettings := cfg.Server.TLS
certFile := strings.TrimSpace(tlsSettings.CertFile)
keyFile := strings.TrimSpace(tlsSettings.KeyFile)
caFile := strings.TrimSpace(tlsSettings.CAFile)
clientCAFile := strings.TrimSpace(tlsSettings.ClientCAFile)
useTLS := tlsSettings.IsEnabled()
var tlsConfig *tls.Config
if useTLS {
if certFile == "" || keyFile == "" {
return fmt.Errorf("tls is enabled but certFile (%q) or keyFile (%q) is empty", certFile, keyFile)
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("failed to load tls certificate: %w", err)
}
if caFile != "" {
caPEM, err := os.ReadFile(caFile)
if err != nil {
return fmt.Errorf("failed to read ca file %q: %w", caFile, err)
}
var block *pem.Block
existing := make(map[string]struct{}, len(cert.Certificate))
for _, c := range cert.Certificate {
existing[string(c)] = struct{}{}
}
for len(caPEM) > 0 {
block, caPEM = pem.Decode(caPEM)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Bytes) == 0 {
continue
}
if _, ok := existing[string(block.Bytes)]; ok {
continue
}
cert.Certificate = append(cert.Certificate, block.Bytes)
}
if len(cert.Certificate) == 0 {
return fmt.Errorf("ca file %q did not contain any certificates", caFile)
}
}
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
if clientCAFile != "" {
caBytes, err := os.ReadFile(clientCAFile)
if err != nil {
return err
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caBytes) {
return errors.New("failed to parse client CA file")
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
} else {
if certFile != "" || keyFile != "" {
logger.Info("TLS disabled; certificate paths will be ignored", "certFile", certFile, "keyFile", keyFile)
}
if clientCAFile != "" {
logger.Warn("client CA configured but TLS is disabled; ignoring", "clientCAFile", clientCAFile)
}
}
srv := &http.Server{
Addr: addr,
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
}
if useTLS {
srv.TLSConfig = tlsConfig
}
logger.Info("starting account service", "addr", addr, "tls", useTLS)
var listenCertFile, listenKeyFile string
if useTLS {
if tlsSettings.RedirectHTTP {
go func() {
redirectAddr := deriveRedirectAddr(addr)
redirectSrv := &http.Server{
Addr: redirectAddr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := r.Host
if host == "" {
host = redirectAddr
}
target := "https://" + host + r.URL.RequestURI()
http.Redirect(w, r, target, http.StatusPermanentRedirect)
}),
}
if err := redirectSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("http redirect listener exited", "err", err)
}
}()
}
if tlsConfig != nil && len(tlsConfig.Certificates) > 0 {
listenCertFile = ""
listenKeyFile = ""
} else {
listenCertFile = certFile
listenKeyFile = keyFile
}
if err := srv.ListenAndServeTLS(listenCertFile, listenKeyFile); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
logger.Error("account service shutdown", "err", err)
return err
}
}
} else {
if err := srv.ListenAndServe(); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
logger.Error("account service shutdown", "err", err)
return err
}
}
}
return nil
}
func runServerAndAgent(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
if ctx == nil {
ctx = context.Background()
}
if cfg == nil {
return errors.New("config is nil")
}
agentCtx, cancel := context.WithCancel(ctx)
defer cancel()
agentErrCh := make(chan error, 1)
go func() {
agentErrCh <- runAgent(agentCtx, cfg, logger)
}()
agentPending := true
select {
case err := <-agentErrCh:
agentPending = false
if err == nil {
err = errors.New("agent exited unexpectedly")
}
return fmt.Errorf("agent startup failed: %w", err)
default:
}
serverErr := runServer(ctx, cfg, logger)
cancel()
var agentErr error
if agentPending {
agentErr = <-agentErrCh
}
if serverErr != nil {
return serverErr
}
if agentErr != nil {
return agentErr
}
return nil
}
func runAgent(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
if cfg == nil {
return errors.New("config is nil")
}
if logger == nil {
logger = slog.Default()
}
if !cfg.Xray.Sync.Enabled {
logger.Warn("xray sync is disabled in configuration; agent mode will still attempt to manage xray config")
}
options := agentmode.Options{
Logger: logger.With("component", "agent"),
Agent: cfg.Agent,
Xray: cfg.Xray,
}
return agentmode.Run(ctx, options)
}
const agentIdentityContextKey = "xcontrol-account-agent-identity"
func registerAgentAPIRoutes(r *gin.Engine, registry *agentserver.Registry, source xrayconfig.ClientSource, logger *slog.Logger) {
if registry == nil {
return
}
group := r.Group("/api/agent/v1")
group.Use(agentAuthMiddleware(registry))
group.GET("/users", agentListUsersHandler(source))
group.POST("/status", agentReportStatusHandler(registry, logger))
}
func agentAuthMiddleware(registry *agentserver.Registry) gin.HandlerFunc {
return func(c *gin.Context) {
if registry == nil {
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "agent_registry_unavailable", "message": "agent registry not configured"})
return
}
token := extractBearerToken(c.GetHeader("Authorization"))
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "agent_token_required", "message": "agent token is required"})
return
}
identity, ok := registry.Authenticate(token)
if !ok || identity == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid_agent_token", "message": "invalid agent token"})
return
}
c.Set(agentIdentityContextKey, *identity)
c.Next()
}
}
func agentListUsersHandler(source xrayconfig.ClientSource) gin.HandlerFunc {
return func(c *gin.Context) {
if source == nil {
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "client_source_unavailable", "message": "client source not configured"})
return
}
clients, err := source.ListClients(c.Request.Context())
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "list_clients_failed", "message": "failed to list clients"})
return
}
response := agentproto.ClientListResponse{
Clients: clients,
Total: len(clients),
GeneratedAt: time.Now().UTC(),
}
c.JSON(http.StatusOK, response)
}
}
func agentReportStatusHandler(registry *agentserver.Registry, logger *slog.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
value, exists := c.Get(agentIdentityContextKey)
if !exists {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "agent_identity_missing", "message": "agent identity missing"})
return
}
identity, ok := value.(agentserver.Identity)
if !ok {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "agent_identity_invalid", "message": "agent identity malformed"})
return
}
var report agentproto.StatusReport
if err := c.ShouldBindJSON(&report); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid_status_payload", "message": "invalid status payload"})
return
}
registry.ReportStatus(identity, report)
if logger != nil {
logger.Info("agent status updated", "agent", identity.ID, "healthy", report.Healthy, "clients", report.Xray.Clients)
}
c.Status(http.StatusNoContent)
}
}
func extractBearerToken(header string) string {
header = strings.TrimSpace(header)
if header == "" {
return ""
}
const prefix = "Bearer "
if strings.HasPrefix(header, prefix) {
header = header[len(prefix):]
}
return strings.TrimSpace(header)
}
var rootCmd = &cobra.Command{
Use: "xcontrol-account",
Short: "Start the xcontrol account service",
@ -79,284 +550,22 @@ var rootCmd = &cobra.Command{
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
slog.SetDefault(logger)
r := gin.New()
corsConfig := buildCORSConfig(logger, cfg.Server)
if corsConfig.AllowAllOrigins {
logger.Info("configured cors", "allowAllOrigins", true)
} else {
logger.Info("configured cors", "allowedOrigins", corsConfig.AllowOrigins)
}
r.Use(cors.New(corsConfig))
r.Use(gin.Recovery())
r.Use(func(c *gin.Context) {
start := time.Now()
c.Next()
logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start))
})
ctx := context.Background()
storeCfg := store.Config{
Driver: cfg.Store.Driver,
DSN: cfg.Store.DSN,
MaxOpenConns: cfg.Store.MaxOpenConns,
MaxIdleConns: cfg.Store.MaxIdleConns,
mode := strings.ToLower(strings.TrimSpace(cfg.Mode))
if mode == "" {
mode = "server"
}
st, cleanup, err := store.New(ctx, storeCfg)
if err != nil {
return err
switch mode {
case "server":
return runServer(ctx, cfg, logger)
case "agent":
return runAgent(ctx, cfg, logger)
case "server-agent", "all", "combined":
return runServerAndAgent(ctx, cfg, logger)
default:
return fmt.Errorf("unsupported mode %q", cfg.Mode)
}
defer func() {
if cleanup == nil {
return
}
if err := cleanup(context.Background()); err != nil {
logger.Error("failed to close store", "err", err)
}
}()
var emailSender api.EmailSender
emailVerificationEnabled := true
smtpHost := strings.TrimSpace(cfg.SMTP.Host)
if smtpHost == "" {
emailVerificationEnabled = false
}
if smtpHost != "" && isExampleDomain(smtpHost) {
emailVerificationEnabled = false
logger.Warn("smtp host is a placeholder; disabling email delivery", "host", smtpHost)
smtpHost = ""
}
if smtpHost != "" {
tlsMode := mailer.ParseTLSMode(cfg.SMTP.TLS.Mode)
sender, err := mailer.New(mailer.Config{
Host: smtpHost,
Port: cfg.SMTP.Port,
Username: cfg.SMTP.Username,
Password: cfg.SMTP.Password,
From: cfg.SMTP.From,
ReplyTo: cfg.SMTP.ReplyTo,
Timeout: cfg.SMTP.Timeout,
TLSMode: tlsMode,
InsecureSkipVerify: cfg.SMTP.TLS.InsecureSkipVerify,
})
if err != nil {
return err
}
emailSender = mailerAdapter{sender: sender}
}
if emailSender == nil {
emailVerificationEnabled = false
}
gormDB, gormCleanup, err := openAdminSettingsDB(cfg.Store)
if err != nil {
return err
}
defer func() {
if gormCleanup != nil {
if err := gormCleanup(context.Background()); err != nil {
logger.Error("failed to close admin settings db", "err", err)
}
}
}()
service.SetDB(gormDB)
var stopXraySync func(context.Context) error
if cfg.Xray.Sync.Enabled {
syncInterval := cfg.Xray.Sync.Interval
if syncInterval <= 0 {
syncInterval = 5 * time.Minute
}
templatePath := strings.TrimSpace(cfg.Xray.Sync.TemplatePath)
if templatePath == "" {
templatePath = filepath.Join("account", "config", "xray.config.template.json")
}
outputPath := strings.TrimSpace(cfg.Xray.Sync.OutputPath)
if outputPath == "" {
outputPath = "/usr/local/etc/xray/config.json"
}
source, err := xrayconfig.NewGormClientSource(gormDB)
if err != nil {
return err
}
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
Logger: logger.With("component", "xray-sync"),
Interval: syncInterval,
Source: source,
Generator: xrayconfig.Generator{TemplatePath: templatePath, OutputPath: outputPath},
ValidateCommand: cfg.Xray.Sync.ValidateCommand,
RestartCommand: cfg.Xray.Sync.RestartCommand,
})
if err != nil {
return err
}
stop, err := syncer.Start(ctx)
if err != nil {
return err
}
logger.Info("xray periodic sync enabled", "interval", syncInterval, "output", outputPath)
stopXraySync = stop
}
if stopXraySync != nil {
defer func() {
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := stopXraySync(waitCtx); err != nil {
logger.Warn("xray syncer shutdown", "err", err)
}
}()
}
options := []api.Option{
api.WithStore(st),
api.WithSessionTTL(cfg.Session.TTL),
}
if emailSender != nil {
options = append(options, api.WithEmailSender(emailSender))
}
options = append(options, api.WithEmailVerification(emailVerificationEnabled))
api.RegisterRoutes(r, options...)
addr := strings.TrimSpace(cfg.Server.Addr)
if addr == "" {
addr = ":8080"
}
tlsSettings := cfg.Server.TLS
certFile := strings.TrimSpace(tlsSettings.CertFile)
keyFile := strings.TrimSpace(tlsSettings.KeyFile)
caFile := strings.TrimSpace(tlsSettings.CAFile)
clientCAFile := strings.TrimSpace(tlsSettings.ClientCAFile)
useTLS := tlsSettings.IsEnabled()
var tlsConfig *tls.Config
if useTLS {
if certFile == "" || keyFile == "" {
return fmt.Errorf("tls is enabled but certFile (%q) or keyFile (%q) is empty", certFile, keyFile)
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("failed to load tls certificate: %w", err)
}
if caFile != "" {
caPEM, err := os.ReadFile(caFile)
if err != nil {
return fmt.Errorf("failed to read ca file %q: %w", caFile, err)
}
var block *pem.Block
existing := make(map[string]struct{}, len(cert.Certificate))
for _, c := range cert.Certificate {
existing[string(c)] = struct{}{}
}
for len(caPEM) > 0 {
block, caPEM = pem.Decode(caPEM)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Bytes) == 0 {
continue
}
if _, ok := existing[string(block.Bytes)]; ok {
continue
}
cert.Certificate = append(cert.Certificate, block.Bytes)
}
if len(cert.Certificate) == 0 {
return fmt.Errorf("ca file %q did not contain any certificates", caFile)
}
}
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
if clientCAFile != "" {
caBytes, err := os.ReadFile(clientCAFile)
if err != nil {
return err
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caBytes) {
return errors.New("failed to parse client CA file")
}
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
} else {
if certFile != "" || keyFile != "" {
logger.Info("TLS disabled; certificate paths will be ignored", "certFile", certFile, "keyFile", keyFile)
}
if clientCAFile != "" {
logger.Warn("client CA configured but TLS is disabled; ignoring", "clientCAFile", clientCAFile)
}
}
srv := &http.Server{
Addr: addr,
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
}
if useTLS {
srv.TLSConfig = tlsConfig
}
logger.Info("starting account service", "addr", addr, "tls", useTLS)
var listenCertFile, listenKeyFile string
if useTLS {
if tlsSettings.RedirectHTTP {
go func() {
redirectAddr := deriveRedirectAddr(addr)
redirectSrv := &http.Server{
Addr: redirectAddr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := r.Host
if host == "" {
host = redirectAddr
}
target := "https://" + host + r.URL.RequestURI()
http.Redirect(w, r, target, http.StatusPermanentRedirect)
}),
}
if err := redirectSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("http redirect listener exited", "err", err)
}
}()
}
if tlsConfig != nil && len(tlsConfig.Certificates) > 0 {
listenCertFile = ""
listenKeyFile = ""
} else {
listenCertFile = certFile
listenKeyFile = keyFile
}
if err := srv.ListenAndServeTLS(listenCertFile, listenKeyFile); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
logger.Error("account service shutdown", "err", err)
return err
}
}
} else {
if err := srv.ListenAndServe(); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
logger.Error("account service shutdown", "err", err)
return err
}
}
}
return nil
},
}

View File

@ -0,0 +1,26 @@
mode: "agent"
log:
level: info
agent:
id: "edge-node-1"
controllerUrl: "https://account.svc.plus"
apiToken: "replace-with-agent-token"
httpTimeout: 15s
statusInterval: 1m
syncInterval: 5m
tls:
insecureSkipVerify: false
xray:
sync:
enabled: true
interval: 5m
templatePath: "account/config/xray.config.template.json"
outputPath: "/usr/local/etc/xray/config.json"
validateCommand: []
restartCommand:
- "systemctl"
- "restart"
- "xray.service"

View File

@ -0,0 +1,85 @@
mode: "server-agent"
log:
level: info
server:
addr: ":8080"
readTimeout: 15s
writeTimeout: 15s
publicUrl: "http://localhost:8080"
allowedOrigins:
- "https://dev.svc.plus"
- "https://dev-homepage.svc.plus"
- "https://www.svc.plus"
- "https://global-homepage.svc.plus"
- "https://account.svc.plus"
- "https://localhost:8443"
- "http://localhost:8080"
- "http://127.0.0.1:8080"
- "http://localhost:3001"
- "http://127.0.0.1:3001"
- "http://localhost:3000"
- "http://127.0.0.1:3000"
tls:
enabled: false
certFile: ""
keyFile: ""
caFile: ""
clientCAFile: ""
redirectHttp: false
store:
driver: "postgres"
dsn: "postgres://shenlan:password@127.0.0.1:5432/account?sslmode=disable"
maxOpenConns: 30
maxIdleConns: 10
session:
ttl: 24h
cache: "redis"
redis:
addr: "127.0.0.1:6379"
password: ""
smtp:
host: "smtp.example.com"
port: 587
username: "apikey"
password: "YOUR_PASSWORD"
from: "XControl Account <no-reply@example.com>"
replyTo: ""
timeout: 10s
tls:
mode: "auto"
insecureSkipVerify: false
xray:
sync:
enabled: false
interval: 5m
templatePath: "account/config/xray.config.template.json"
outputPath: "/usr/local/etc/xray/config.json"
validateCommand: []
restartCommand:
- "systemctl"
- "restart"
- "xray.service"
agent:
id: "account-primary"
controllerUrl: "http://127.0.0.1:8080"
apiToken: "replace-with-agent-token"
httpTimeout: 15s
statusInterval: 1m
syncInterval: 5m
tls:
insecureSkipVerify: false
agents:
credentials:
- id: "account-primary"
name: "Account Server (local agent)"
token: "replace-with-agent-token"
groups:
- "default"

View File

@ -1,3 +1,5 @@
mode: "server-agent"
log:
level: info
@ -63,3 +65,21 @@ xray:
- "systemctl"
- "restart"
- "xray.service"
agent:
id: "account-primary"
controllerUrl: "http://127.0.0.1:8080"
apiToken: "replace-with-agent-token"
httpTimeout: 15s
statusInterval: 1m
syncInterval: 5m
tls:
insecureSkipVerify: false
agents:
credentials:
- id: "account-primary"
name: "Account Server (local agent)"
token: "replace-with-agent-token"
groups:
- "default"

View File

@ -19,12 +19,15 @@ type Log struct {
// Config holds configuration for the account service.
type Config struct {
Mode string `yaml:"mode"`
Log Log `yaml:"log"`
Server Server `yaml:"server"`
Store Store `yaml:"store"`
Session Session `yaml:"session"`
SMTP SMTP `yaml:"smtp"`
Xray Xray `yaml:"xray"`
Agent Agent `yaml:"agent"`
Agents Agents `yaml:"agents"`
}
// Server defines HTTP server configuration.
@ -106,6 +109,36 @@ type XraySync struct {
RestartCommand []string `yaml:"restartCommand"`
}
// Agent defines configuration for agent mode deployments.
type Agent struct {
ID string `yaml:"id"`
ControllerURL string `yaml:"controllerUrl"`
APIToken string `yaml:"apiToken"`
HTTPTimeout time.Duration `yaml:"httpTimeout"`
StatusInterval time.Duration `yaml:"statusInterval"`
SyncInterval time.Duration `yaml:"syncInterval"`
TLS AgentTLS `yaml:"tls"`
}
// AgentTLS configures TLS behaviour for the agent HTTP client.
type AgentTLS struct {
InsecureSkipVerify bool `yaml:"insecureSkipVerify"`
}
// Agents describes the controller-side agent configuration.
type Agents struct {
Credentials []AgentCredential `yaml:"credentials"`
}
// AgentCredential represents a single agent identity authorised to call the
// controller API.
type AgentCredential struct {
ID string `yaml:"id"`
Name string `yaml:"name"`
Token string `yaml:"token"`
Groups []string `yaml:"groups"`
}
// Load reads the configuration file at the provided path. When path is empty,
// it defaults to account/config/account.yaml. If the file does not exist an
// empty configuration is returned.

View File

@ -0,0 +1,151 @@
package agentmode
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"xcontrol/account/internal/agentproto"
)
// ClientOptions configures the HTTP client used to communicate with the
// controller.
type ClientOptions struct {
Timeout time.Duration
InsecureSkipVerify bool
UserAgent string
}
// Client issues authenticated requests against the controller.
type Client struct {
baseURL *url.URL
token string
http *http.Client
userAgent string
}
// NewClient constructs a client for the provided controller URL and token.
func NewClient(baseURL, token string, opts ClientOptions) (*Client, error) {
trimmedURL := strings.TrimSpace(baseURL)
if trimmedURL == "" {
return nil, errors.New("controller url is required")
}
parsed, err := url.Parse(trimmedURL)
if err != nil {
return nil, fmt.Errorf("parse controller url: %w", err)
}
token = strings.TrimSpace(token)
if token == "" {
return nil, errors.New("controller token is required")
}
timeout := opts.Timeout
if timeout <= 0 {
timeout = 15 * time.Second
}
transport := http.DefaultTransport
if t, ok := transport.(*http.Transport); ok {
clone := t.Clone()
if opts.InsecureSkipVerify {
if clone.TLSClientConfig == nil {
clone.TLSClientConfig = &tls.Config{}
}
clone.TLSClientConfig.InsecureSkipVerify = true
}
transport = clone
}
client := &http.Client{
Timeout: timeout,
Transport: transport,
}
userAgent := strings.TrimSpace(opts.UserAgent)
if userAgent == "" {
userAgent = "xcontrol-agent"
}
return &Client{
baseURL: parsed,
token: token,
http: client,
userAgent: userAgent,
}, nil
}
// ListClients fetches the current set of Xray clients from the controller.
func (c *Client) ListClients(ctx context.Context) (agentproto.ClientListResponse, error) {
endpoint, err := url.JoinPath(c.baseURL.String(), "/api/agent/v1/users")
if err != nil {
return agentproto.ClientListResponse{}, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return agentproto.ClientListResponse{}, err
}
c.applyHeaders(req)
req.Header.Set("Accept", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return agentproto.ClientListResponse{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<14))
return agentproto.ClientListResponse{}, fmt.Errorf("controller returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
}
var payload agentproto.ClientListResponse
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return agentproto.ClientListResponse{}, fmt.Errorf("decode client list: %w", err)
}
return payload, nil
}
// ReportStatus submits the agent status report to the controller.
func (c *Client) ReportStatus(ctx context.Context, report agentproto.StatusReport) error {
endpoint, err := url.JoinPath(c.baseURL.String(), "/api/agent/v1/status")
if err != nil {
return err
}
buf, err := json.Marshal(report)
if err != nil {
return fmt.Errorf("encode status report: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(buf))
if err != nil {
return err
}
c.applyHeaders(req)
req.Header.Set("Content-Type", "application/json")
resp, err := c.http.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<14))
return fmt.Errorf("controller returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
}
return nil
}
func (c *Client) applyHeaders(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+c.token)
req.Header.Set("User-Agent", c.userAgent)
}

View File

@ -0,0 +1,195 @@
package agentmode
import (
"context"
"errors"
"fmt"
"log/slog"
"path/filepath"
"strings"
"sync"
"time"
"xcontrol/account/config"
"xcontrol/account/internal/agentproto"
"xcontrol/account/internal/xrayconfig"
)
// Options configures the agent runtime.
type Options struct {
Logger *slog.Logger
Agent config.Agent
Xray config.Xray
}
// Run launches the agent mode control loop. It blocks until the context is
// cancelled or a fatal error occurs during setup.
func Run(ctx context.Context, opts Options) error {
if ctx == nil {
return errors.New("context is required")
}
logger := opts.Logger
if logger == nil {
logger = slog.Default()
}
controllerURL := strings.TrimSpace(opts.Agent.ControllerURL)
if controllerURL == "" {
return errors.New("agent.controllerUrl is required")
}
token := strings.TrimSpace(opts.Agent.APIToken)
if token == "" {
return errors.New("agent.apiToken is required")
}
syncInterval := opts.Agent.SyncInterval
if syncInterval <= 0 {
syncInterval = opts.Xray.Sync.Interval
}
if syncInterval <= 0 {
syncInterval = 5 * time.Minute
}
statusInterval := opts.Agent.StatusInterval
if statusInterval <= 0 {
statusInterval = time.Minute
}
httpTimeout := opts.Agent.HTTPTimeout
if httpTimeout <= 0 {
httpTimeout = 15 * time.Second
}
templatePath := strings.TrimSpace(opts.Xray.Sync.TemplatePath)
if templatePath == "" {
templatePath = filepath.Join("account", "config", "xray.config.template.json")
}
outputPath := strings.TrimSpace(opts.Xray.Sync.OutputPath)
if outputPath == "" {
outputPath = "/usr/local/etc/xray/config.json"
}
client, err := NewClient(controllerURL, token, ClientOptions{
Timeout: httpTimeout,
InsecureSkipVerify: opts.Agent.TLS.InsecureSkipVerify,
UserAgent: buildUserAgent(opts.Agent.ID),
})
if err != nil {
return err
}
tracker := newSyncTracker()
source := NewHTTPClientSource(client, tracker)
syncLogger := logger.With("component", "agent-xray-sync")
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
Logger: syncLogger,
Interval: syncInterval,
Source: source,
Generator: xrayconfig.Generator{TemplatePath: templatePath, OutputPath: outputPath},
ValidateCommand: opts.Xray.Sync.ValidateCommand,
RestartCommand: opts.Xray.Sync.RestartCommand,
OnSync: func(result xrayconfig.SyncResult) {
if result.Error != nil {
tracker.MarkError(result.Error, result.CompletedAt)
return
}
tracker.MarkSuccess(result.CompletedAt)
},
})
if err != nil {
return err
}
stopSync, err := syncer.Start(ctx)
if err != nil {
return err
}
defer func() {
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := stopSync(waitCtx); err != nil {
logger.Warn("xray syncer shutdown", "err", err)
}
}()
reporterCtx, reporterCancel := context.WithCancel(ctx)
defer reporterCancel()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
runStatusReporter(reporterCtx, client, tracker, statusInterval, syncInterval, logger)
}()
<-ctx.Done()
reporterCancel()
wg.Wait()
return nil
}
func buildUserAgent(id string) string {
id = strings.TrimSpace(id)
if id == "" {
return "xcontrol-agent"
}
return fmt.Sprintf("xcontrol-agent/%s", id)
}
func runStatusReporter(ctx context.Context, client *Client, tracker *syncTracker, interval, syncInterval time.Duration, logger *slog.Logger) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
send := func() {
snapshot := tracker.Snapshot()
report := buildStatusReport(snapshot, syncInterval)
if err := client.ReportStatus(ctx, report); err != nil {
logger.Warn("failed to report agent status", "err", err)
}
}
send()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
send()
}
}
}
func buildStatusReport(snapshot trackerSnapshot, syncInterval time.Duration) agentproto.StatusReport {
healthy := snapshot.LastError == "" && !snapshot.LastSuccess.IsZero()
running := false
var lastSyncPtr *time.Time
if !snapshot.LastSuccess.IsZero() {
running = time.Since(snapshot.LastSuccess) <= 3*syncInterval
last := snapshot.LastSuccess
lastSyncPtr = &last
}
report := agentproto.StatusReport{
Healthy: healthy,
Message: snapshot.LastError,
Users: snapshot.Clients,
SyncRevision: snapshot.Revision,
Xray: agentproto.XrayStatus{
Running: running,
Clients: snapshot.Clients,
LastSync: func() *time.Time {
if lastSyncPtr == nil {
return nil
}
copy := *lastSyncPtr
return &copy
}(),
},
}
return report
}

View File

@ -0,0 +1,33 @@
package agentmode
import (
"context"
"time"
"xcontrol/account/internal/xrayconfig"
)
// HTTPClientSource retrieves Xray clients from the controller over HTTP.
type HTTPClientSource struct {
client *Client
tracker *syncTracker
}
// NewHTTPClientSource constructs a source backed by the provided client and
// tracker.
func NewHTTPClientSource(client *Client, tracker *syncTracker) *HTTPClientSource {
return &HTTPClientSource{client: client, tracker: tracker}
}
// ListClients implements xrayconfig.ClientSource by fetching the latest client
// list via the controller API.
func (s *HTTPClientSource) ListClients(ctx context.Context) ([]xrayconfig.Client, error) {
resp, err := s.client.ListClients(ctx)
if err != nil {
return nil, err
}
if s.tracker != nil {
s.tracker.UpdateFetch(len(resp.Clients), resp.Revision, time.Now().UTC())
}
return resp.Clients, nil
}

View File

@ -0,0 +1,72 @@
package agentmode
import (
"sync"
"time"
)
type syncTracker struct {
mu sync.RWMutex
clients int
revision string
lastFetch time.Time
lastSuccess time.Time
lastError string
lastErrorAt time.Time
}
func newSyncTracker() *syncTracker {
return &syncTracker{}
}
func (t *syncTracker) UpdateFetch(count int, revision string, when time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
t.clients = count
t.revision = revision
t.lastFetch = when
}
func (t *syncTracker) MarkSuccess(at time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
t.lastSuccess = at
t.lastError = ""
t.lastErrorAt = time.Time{}
}
func (t *syncTracker) MarkError(err error, at time.Time) {
if err == nil {
return
}
t.mu.Lock()
defer t.mu.Unlock()
t.lastError = err.Error()
t.lastErrorAt = at
}
type trackerSnapshot struct {
Clients int
Revision string
LastFetch time.Time
LastSuccess time.Time
LastError string
LastErrorAt time.Time
}
func (t *syncTracker) Snapshot() trackerSnapshot {
t.mu.RLock()
defer t.mu.RUnlock()
return trackerSnapshot{
Clients: t.clients,
Revision: t.revision,
LastFetch: t.lastFetch,
LastSuccess: t.lastSuccess,
LastError: t.lastError,
LastErrorAt: t.lastErrorAt,
}
}

View File

@ -0,0 +1,34 @@
package agentproto
import (
"time"
"xcontrol/account/internal/xrayconfig"
)
// ClientListResponse represents the payload returned by the controller when an
// agent requests the latest set of Xray clients.
type ClientListResponse struct {
Clients []xrayconfig.Client `json:"clients"`
Total int `json:"total"`
GeneratedAt time.Time `json:"generatedAt"`
Revision string `json:"revision,omitempty"`
}
// StatusReport captures the runtime state of an agent and the managed Xray
// instance.
type StatusReport struct {
Healthy bool `json:"healthy"`
Message string `json:"message,omitempty"`
Users int `json:"users"`
SyncRevision string `json:"syncRevision,omitempty"`
Xray XrayStatus `json:"xray"`
}
// XrayStatus describes the synchronisation state of the managed Xray process.
type XrayStatus struct {
Running bool `json:"running"`
Clients int `json:"clients"`
LastSync *time.Time `json:"lastSync,omitempty"`
ConfigHash string `json:"configHash,omitempty"`
}

View File

@ -0,0 +1,179 @@
package agentserver
import (
"crypto/sha256"
"errors"
"sort"
"strings"
"sync"
"time"
"xcontrol/account/internal/agentproto"
)
// Credential defines the authentication material assigned to a managed agent.
type Credential struct {
ID string `yaml:"id"`
Name string `yaml:"name"`
Token string `yaml:"token"`
Groups []string `yaml:"groups"`
}
// Config groups the credential set exposed through configuration.
type Config struct {
Credentials []Credential `yaml:"credentials"`
}
// Identity represents an authenticated agent instance.
type Identity struct {
ID string
Name string
Groups []string
}
// StatusSnapshot captures the last reported status for an agent.
type StatusSnapshot struct {
Agent Identity
Report agentproto.StatusReport
UpdatedAt time.Time
}
// Registry manages agent credentials and status reports in-memory.
type Registry struct {
mu sync.RWMutex
credentials map[[32]byte]Identity
byID map[string]Identity
statuses map[string]StatusSnapshot
}
// NewRegistry constructs a registry from configuration, validating credentials
// and normalising their representation.
func NewRegistry(cfg Config) (*Registry, error) {
r := &Registry{
credentials: make(map[[32]byte]Identity),
byID: make(map[string]Identity),
statuses: make(map[string]StatusSnapshot),
}
for _, cred := range cfg.Credentials {
id := strings.TrimSpace(cred.ID)
token := strings.TrimSpace(cred.Token)
if id == "" {
return nil, errors.New("agent credential id is required")
}
if token == "" {
return nil, errors.New("agent credential token is required")
}
if _, exists := r.byID[id]; exists {
return nil, errors.New("duplicate agent credential id: " + id)
}
digest := sha256.Sum256([]byte(token))
if _, exists := r.credentials[digest]; exists {
return nil, errors.New("duplicate agent credential token")
}
identity := Identity{
ID: id,
Name: strings.TrimSpace(cred.Name),
Groups: normalizeStrings(cred.Groups),
}
r.credentials[digest] = identity
r.byID[id] = identity
}
return r, nil
}
// Authenticate validates the provided token and returns the associated agent
// identity when successful.
func (r *Registry) Authenticate(token string) (*Identity, bool) {
token = strings.TrimSpace(token)
if token == "" {
return nil, false
}
digest := sha256.Sum256([]byte(token))
r.mu.RLock()
identity, ok := r.credentials[digest]
r.mu.RUnlock()
if !ok {
return nil, false
}
copy := identity
return &copy, true
}
// ReportStatus records the status report for the provided agent identity.
func (r *Registry) ReportStatus(agent Identity, report agentproto.StatusReport) {
r.mu.Lock()
defer r.mu.Unlock()
r.statuses[agent.ID] = StatusSnapshot{
Agent: agent,
Report: report,
UpdatedAt: time.Now().UTC(),
}
}
// Statuses returns the latest status snapshot for all agents sorted by ID.
func (r *Registry) Statuses() []StatusSnapshot {
r.mu.RLock()
defer r.mu.RUnlock()
snapshots := make([]StatusSnapshot, 0, len(r.byID))
for id, identity := range r.byID {
snapshot, ok := r.statuses[id]
if !ok {
snapshot = StatusSnapshot{Agent: identity}
}
snapshots = append(snapshots, snapshot)
}
sort.Slice(snapshots, func(i, j int) bool {
return snapshots[i].Agent.ID < snapshots[j].Agent.ID
})
return snapshots
}
// Agents returns the configured agent identities in a deterministic order.
func (r *Registry) Agents() []Identity {
r.mu.RLock()
defer r.mu.RUnlock()
agents := make([]Identity, 0, len(r.byID))
for _, identity := range r.byID {
agents = append(agents, identity)
}
sort.Slice(agents, func(i, j int) bool {
return agents[i].ID < agents[j].ID
})
return agents
}
// normalizeStrings trims whitespace and removes duplicates from the provided
// slice while preserving the original order for unique entries.
func normalizeStrings(values []string) []string {
if len(values) == 0 {
return nil
}
result := make([]string, 0, len(values))
seen := make(map[string]struct{}, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
if _, exists := seen[trimmed]; exists {
continue
}
seen[trimmed] = struct{}{}
result = append(result, trimmed)
}
if len(result) == 0 {
return nil
}
return result
}

View File

@ -0,0 +1,104 @@
package agentserver
import (
"testing"
"time"
"xcontrol/account/internal/agentproto"
)
func TestNewRegistryValidation(t *testing.T) {
_, err := NewRegistry(Config{})
if err != nil {
t.Fatalf("unexpected error for empty config: %v", err)
}
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "", Token: "token"}}})
if err == nil {
t.Fatalf("expected error for empty id")
}
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "edge", Token: ""}}})
if err == nil {
t.Fatalf("expected error for empty token")
}
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "a", Token: "1"}, {ID: "a", Token: "2"}}})
if err == nil {
t.Fatalf("expected error for duplicate id")
}
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "a", Token: "dup"}, {ID: "b", Token: "dup"}}})
if err == nil {
t.Fatalf("expected error for duplicate token")
}
}
func TestRegistryAuthenticateAndStatus(t *testing.T) {
registry, err := NewRegistry(Config{Credentials: []Credential{{ID: "edge", Name: "Edge", Token: "secret", Groups: []string{"default"}}}})
if err != nil {
t.Fatalf("new registry: %v", err)
}
identity, ok := registry.Authenticate("secret")
if !ok || identity == nil {
t.Fatalf("expected authentication to succeed")
}
if identity.ID != "edge" {
t.Fatalf("unexpected identity id %q", identity.ID)
}
report := agentproto.StatusReport{
Healthy: true,
Users: 5,
Xray: agentproto.XrayStatus{
Running: true,
Clients: 5,
},
}
registry.ReportStatus(*identity, report)
snapshots := registry.Statuses()
if len(snapshots) != 1 {
t.Fatalf("expected 1 snapshot, got %d", len(snapshots))
}
snapshot := snapshots[0]
if snapshot.Agent.ID != "edge" {
t.Fatalf("unexpected snapshot agent id %q", snapshot.Agent.ID)
}
if !snapshot.Report.Healthy {
t.Fatalf("expected healthy report")
}
if snapshot.Report.Users != 5 {
t.Fatalf("unexpected users count %d", snapshot.Report.Users)
}
if snapshot.Report.Xray.Clients != 5 {
t.Fatalf("unexpected xray clients %d", snapshot.Report.Xray.Clients)
}
if snapshot.UpdatedAt.IsZero() {
t.Fatalf("expected updated timestamp")
}
// Ensure snapshots include configured agents without reports.
registry, err = NewRegistry(Config{Credentials: []Credential{{ID: "a", Token: "1"}, {ID: "b", Token: "2"}}})
if err != nil {
t.Fatalf("new registry: %v", err)
}
snapshots = registry.Statuses()
if len(snapshots) != 2 {
t.Fatalf("expected 2 snapshots, got %d", len(snapshots))
}
if snapshots[0].Agent.ID != "a" || snapshots[1].Agent.ID != "b" {
t.Fatalf("unexpected snapshot ordering: %+v", snapshots)
}
// Report status with timestamp and ensure Latest is retained.
now := time.Now().UTC()
registry.ReportStatus(snapshots[0].Agent, agentproto.StatusReport{Users: 1})
entries := registry.Statuses()
if len(entries) != 2 {
t.Fatalf("expected 2 entries, got %d", len(entries))
}
if entries[0].UpdatedAt.Before(now) {
t.Fatalf("expected updated timestamp to be after initial time")
}
}

View File

@ -26,6 +26,7 @@ type PeriodicOptions struct {
ValidateCommand []string
RestartCommand []string
Runner commandRunner
OnSync func(SyncResult)
}
// PeriodicSyncer periodically rebuilds the Xray configuration from the database.
@ -37,6 +38,14 @@ type PeriodicSyncer struct {
validateCommand []string
restartCommand []string
runner commandRunner
onSync func(SyncResult)
}
// SyncResult describes the outcome of a synchronization attempt.
type SyncResult struct {
Clients int
Error error
CompletedAt time.Time
}
// NewPeriodicSyncer constructs a new PeriodicSyncer from the provided options.
@ -69,6 +78,7 @@ func NewPeriodicSyncer(opts PeriodicOptions) (*PeriodicSyncer, error) {
validateCommand: append([]string(nil), opts.ValidateCommand...),
restartCommand: append([]string(nil), opts.RestartCommand...),
runner: runner,
onSync: opts.OnSync,
}, nil
}
@ -101,6 +111,7 @@ func (s *PeriodicSyncer) Start(ctx context.Context) (func(context.Context) error
func (s *PeriodicSyncer) run(ctx context.Context) {
if n, err := s.sync(ctx); err != nil {
s.notify(SyncResult{Clients: n, Error: err, CompletedAt: time.Now().UTC()})
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
s.logger.Error("xray config sync failed", "err", err)
}
@ -109,6 +120,7 @@ func (s *PeriodicSyncer) run(ctx context.Context) {
}
} else {
s.logger.Info("xray config synchronized", "clients", n)
s.notify(SyncResult{Clients: n, CompletedAt: time.Now().UTC()})
}
ticker := time.NewTicker(s.interval)
@ -121,6 +133,7 @@ func (s *PeriodicSyncer) run(ctx context.Context) {
case <-ticker.C:
n, err := s.sync(ctx)
if err != nil {
s.notify(SyncResult{Clients: n, Error: err, CompletedAt: time.Now().UTC()})
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
@ -128,6 +141,7 @@ func (s *PeriodicSyncer) run(ctx context.Context) {
continue
}
s.logger.Info("xray config synchronized", "clients", n)
s.notify(SyncResult{Clients: n, CompletedAt: time.Now().UTC()})
}
}
}
@ -153,6 +167,13 @@ func (s *PeriodicSyncer) sync(ctx context.Context) (int, error) {
return len(clients), nil
}
func (s *PeriodicSyncer) notify(result SyncResult) {
if s.onSync == nil {
return
}
s.onSync(result)
}
func (s *PeriodicSyncer) runCommand(ctx context.Context, cmd []string, action string) error {
output, err := s.runner(ctx, cmd)
if err != nil {

View File

@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -173,6 +174,47 @@ func TestPeriodicSyncerStartStop(t *testing.T) {
}
}
func TestPeriodicSyncerOnSyncCallback(t *testing.T) {
template, output := writeTemplate(t)
var results []SyncResult
var mu sync.Mutex
opts := PeriodicOptions{
Interval: 5 * time.Millisecond,
Source: staticSource{clients: []Client{{ID: "uuid-a"}}},
Generator: Generator{TemplatePath: template, OutputPath: output},
OnSync: func(res SyncResult) {
mu.Lock()
defer mu.Unlock()
results = append(results, res)
},
}
syncer, err := NewPeriodicSyncer(opts)
if err != nil {
t.Fatalf("new syncer: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
stop, err := syncer.Start(ctx)
if err != nil {
t.Fatalf("start: %v", err)
}
time.Sleep(20 * time.Millisecond)
cancel()
if err := stop(context.Background()); err != nil {
t.Fatalf("stop: %v", err)
}
mu.Lock()
defer mu.Unlock()
if len(results) == 0 {
t.Fatalf("expected at least one sync result")
}
if results[0].Clients != 1 {
t.Fatalf("expected 1 client, got %d", results[0].Clients)
}
if results[0].CompletedAt.IsZero() {
t.Fatalf("expected completion timestamp to be set")
}
}
type clientSourceFunc func(ctx context.Context) ([]Client, error)
func (f clientSourceFunc) ListClients(ctx context.Context) ([]Client, error) {