Allow combined server and agent mode (#595)
This commit is contained in:
parent
41ba7c3cc0
commit
7a4b2f0e00
69
account/api/admin_agents.go
Normal file
69
account/api/admin_agents.go
Normal file
@ -0,0 +1,69 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"xcontrol/account/internal/agentserver"
|
||||
)
|
||||
|
||||
type agentStatusReader interface {
|
||||
Statuses() []agentserver.StatusSnapshot
|
||||
}
|
||||
|
||||
type agentStatusEntry struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
Healthy bool `json:"healthy"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Users int `json:"users"`
|
||||
SyncRevision string `json:"syncRevision,omitempty"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
Xray agentXraySummary `json:"xray"`
|
||||
}
|
||||
|
||||
type agentXraySummary struct {
|
||||
Running bool `json:"running"`
|
||||
Clients int `json:"clients"`
|
||||
LastSync *time.Time `json:"lastSync,omitempty"`
|
||||
}
|
||||
|
||||
func (h *handler) adminAgentStatus(c *gin.Context) {
|
||||
if h.agentStatusReader == nil {
|
||||
respondError(c, http.StatusServiceUnavailable, "agent_status_unavailable", "agent registry is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.requireAdminOrOperator(c); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
snapshots := h.agentStatusReader.Statuses()
|
||||
entries := make([]agentStatusEntry, 0, len(snapshots))
|
||||
for _, snapshot := range snapshots {
|
||||
entry := agentStatusEntry{
|
||||
ID: snapshot.Agent.ID,
|
||||
Name: snapshot.Agent.Name,
|
||||
Groups: append([]string(nil), snapshot.Agent.Groups...),
|
||||
Healthy: snapshot.Report.Healthy,
|
||||
Message: snapshot.Report.Message,
|
||||
Users: snapshot.Report.Users,
|
||||
SyncRevision: snapshot.Report.SyncRevision,
|
||||
UpdatedAt: snapshot.UpdatedAt,
|
||||
Xray: agentXraySummary{
|
||||
Running: snapshot.Report.Xray.Running,
|
||||
Clients: snapshot.Report.Xray.Clients,
|
||||
},
|
||||
}
|
||||
if snapshot.Report.Xray.LastSync != nil {
|
||||
last := *snapshot.Report.Xray.LastSync
|
||||
entry.Xray.LastSync = &last
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"agents": entries})
|
||||
}
|
||||
@ -85,4 +85,5 @@ func (h *handler) resolveSessionToken(c *gin.Context) string {
|
||||
func registerAdminRoutes(group *gin.RouterGroup, h *handler) {
|
||||
admin := group.Group("/admin")
|
||||
admin.GET("/users/metrics", h.adminUsersMetrics)
|
||||
admin.GET("/agents/status", h.adminAgentStatus)
|
||||
}
|
||||
|
||||
@ -57,6 +57,7 @@ type handler struct {
|
||||
passwordResets map[string]passwordReset
|
||||
resetMu sync.RWMutex
|
||||
metricsProvider service.UserMetricsProvider
|
||||
agentStatusReader agentStatusReader
|
||||
}
|
||||
|
||||
type mfaChallenge struct {
|
||||
@ -137,6 +138,15 @@ func WithUserMetricsProvider(provider service.UserMetricsProvider) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithAgentStatusReader wires the agent status reader used by admin endpoints.
|
||||
func WithAgentStatusReader(reader agentStatusReader) Option {
|
||||
return func(h *handler) {
|
||||
if reader != nil {
|
||||
h.agentStatusReader = reader
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithPasswordResetTTL overrides the default TTL for password reset tokens.
|
||||
func WithPasswordResetTTL(ttl time.Duration) Option {
|
||||
return func(h *handler) {
|
||||
|
||||
@ -25,6 +25,9 @@ import (
|
||||
|
||||
"xcontrol/account/api"
|
||||
"xcontrol/account/config"
|
||||
"xcontrol/account/internal/agentmode"
|
||||
"xcontrol/account/internal/agentproto"
|
||||
"xcontrol/account/internal/agentserver"
|
||||
"xcontrol/account/internal/mailer"
|
||||
"xcontrol/account/internal/model"
|
||||
"xcontrol/account/internal/service"
|
||||
@ -54,6 +57,474 @@ func (m mailerAdapter) Send(ctx context.Context, msg api.EmailMessage) error {
|
||||
return m.sender.Send(ctx, mail)
|
||||
}
|
||||
|
||||
func runServer(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("config is nil")
|
||||
}
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
corsConfig := buildCORSConfig(logger, cfg.Server)
|
||||
if corsConfig.AllowAllOrigins {
|
||||
logger.Info("configured cors", "allowAllOrigins", true)
|
||||
} else {
|
||||
logger.Info("configured cors", "allowedOrigins", corsConfig.AllowOrigins)
|
||||
}
|
||||
r.Use(cors.New(corsConfig))
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
c.Next()
|
||||
logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start))
|
||||
})
|
||||
|
||||
storeCfg := store.Config{
|
||||
Driver: cfg.Store.Driver,
|
||||
DSN: cfg.Store.DSN,
|
||||
MaxOpenConns: cfg.Store.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Store.MaxIdleConns,
|
||||
}
|
||||
|
||||
st, cleanup, err := store.New(ctx, storeCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if cleanup == nil {
|
||||
return
|
||||
}
|
||||
if err := cleanup(context.Background()); err != nil {
|
||||
logger.Error("failed to close store", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var emailSender api.EmailSender
|
||||
emailVerificationEnabled := true
|
||||
smtpHost := strings.TrimSpace(cfg.SMTP.Host)
|
||||
if smtpHost == "" {
|
||||
emailVerificationEnabled = false
|
||||
}
|
||||
if smtpHost != "" && isExampleDomain(smtpHost) {
|
||||
emailVerificationEnabled = false
|
||||
logger.Warn("smtp host is a placeholder; disabling email delivery", "host", smtpHost)
|
||||
smtpHost = ""
|
||||
}
|
||||
if smtpHost != "" {
|
||||
tlsMode := mailer.ParseTLSMode(cfg.SMTP.TLS.Mode)
|
||||
sender, err := mailer.New(mailer.Config{
|
||||
Host: smtpHost,
|
||||
Port: cfg.SMTP.Port,
|
||||
Username: cfg.SMTP.Username,
|
||||
Password: cfg.SMTP.Password,
|
||||
From: cfg.SMTP.From,
|
||||
ReplyTo: cfg.SMTP.ReplyTo,
|
||||
Timeout: cfg.SMTP.Timeout,
|
||||
TLSMode: tlsMode,
|
||||
InsecureSkipVerify: cfg.SMTP.TLS.InsecureSkipVerify,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
emailSender = mailerAdapter{sender: sender}
|
||||
}
|
||||
if emailSender == nil {
|
||||
emailVerificationEnabled = false
|
||||
}
|
||||
|
||||
gormDB, gormCleanup, err := openAdminSettingsDB(cfg.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if gormCleanup != nil {
|
||||
if err := gormCleanup(context.Background()); err != nil {
|
||||
logger.Error("failed to close admin settings db", "err", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
service.SetDB(gormDB)
|
||||
|
||||
gormSource, err := xrayconfig.NewGormClientSource(gormDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var agentRegistry *agentserver.Registry
|
||||
if len(cfg.Agents.Credentials) > 0 {
|
||||
creds := make([]agentserver.Credential, 0, len(cfg.Agents.Credentials))
|
||||
for _, c := range cfg.Agents.Credentials {
|
||||
creds = append(creds, agentserver.Credential{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Token: c.Token,
|
||||
Groups: append([]string(nil), c.Groups...),
|
||||
})
|
||||
}
|
||||
agentRegistry, err = agentserver.NewRegistry(agentserver.Config{Credentials: creds})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var stopXraySync func(context.Context) error
|
||||
if cfg.Xray.Sync.Enabled {
|
||||
syncInterval := cfg.Xray.Sync.Interval
|
||||
if syncInterval <= 0 {
|
||||
syncInterval = 5 * time.Minute
|
||||
}
|
||||
templatePath := strings.TrimSpace(cfg.Xray.Sync.TemplatePath)
|
||||
if templatePath == "" {
|
||||
templatePath = filepath.Join("account", "config", "xray.config.template.json")
|
||||
}
|
||||
outputPath := strings.TrimSpace(cfg.Xray.Sync.OutputPath)
|
||||
if outputPath == "" {
|
||||
outputPath = "/usr/local/etc/xray/config.json"
|
||||
}
|
||||
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
|
||||
Logger: logger.With("component", "xray-sync"),
|
||||
Interval: syncInterval,
|
||||
Source: gormSource,
|
||||
Generator: xrayconfig.Generator{TemplatePath: templatePath, OutputPath: outputPath},
|
||||
ValidateCommand: cfg.Xray.Sync.ValidateCommand,
|
||||
RestartCommand: cfg.Xray.Sync.RestartCommand,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stop, err := syncer.Start(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("xray periodic sync enabled", "interval", syncInterval, "output", outputPath)
|
||||
stopXraySync = stop
|
||||
}
|
||||
|
||||
if stopXraySync != nil {
|
||||
defer func() {
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := stopXraySync(waitCtx); err != nil {
|
||||
logger.Warn("xray syncer shutdown", "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
options := []api.Option{
|
||||
api.WithStore(st),
|
||||
api.WithSessionTTL(cfg.Session.TTL),
|
||||
}
|
||||
if emailSender != nil {
|
||||
options = append(options, api.WithEmailSender(emailSender))
|
||||
}
|
||||
options = append(options, api.WithEmailVerification(emailVerificationEnabled))
|
||||
if agentRegistry != nil {
|
||||
options = append(options, api.WithAgentStatusReader(agentRegistry))
|
||||
}
|
||||
api.RegisterRoutes(r, options...)
|
||||
|
||||
if agentRegistry != nil {
|
||||
registerAgentAPIRoutes(r, agentRegistry, gormSource, logger)
|
||||
}
|
||||
|
||||
addr := strings.TrimSpace(cfg.Server.Addr)
|
||||
if addr == "" {
|
||||
addr = ":8080"
|
||||
}
|
||||
|
||||
tlsSettings := cfg.Server.TLS
|
||||
certFile := strings.TrimSpace(tlsSettings.CertFile)
|
||||
keyFile := strings.TrimSpace(tlsSettings.KeyFile)
|
||||
caFile := strings.TrimSpace(tlsSettings.CAFile)
|
||||
clientCAFile := strings.TrimSpace(tlsSettings.ClientCAFile)
|
||||
|
||||
useTLS := tlsSettings.IsEnabled()
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if useTLS {
|
||||
if certFile == "" || keyFile == "" {
|
||||
return fmt.Errorf("tls is enabled but certFile (%q) or keyFile (%q) is empty", certFile, keyFile)
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load tls certificate: %w", err)
|
||||
}
|
||||
|
||||
if caFile != "" {
|
||||
caPEM, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ca file %q: %w", caFile, err)
|
||||
}
|
||||
|
||||
var block *pem.Block
|
||||
existing := make(map[string]struct{}, len(cert.Certificate))
|
||||
for _, c := range cert.Certificate {
|
||||
existing[string(c)] = struct{}{}
|
||||
}
|
||||
|
||||
for len(caPEM) > 0 {
|
||||
block, caPEM = pem.Decode(caPEM)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" || len(block.Bytes) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := existing[string(block.Bytes)]; ok {
|
||||
continue
|
||||
}
|
||||
cert.Certificate = append(cert.Certificate, block.Bytes)
|
||||
}
|
||||
|
||||
if len(cert.Certificate) == 0 {
|
||||
return fmt.Errorf("ca file %q did not contain any certificates", caFile)
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
if clientCAFile != "" {
|
||||
caBytes, err := os.ReadFile(clientCAFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caBytes) {
|
||||
return errors.New("failed to parse client CA file")
|
||||
}
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
} else {
|
||||
if certFile != "" || keyFile != "" {
|
||||
logger.Info("TLS disabled; certificate paths will be ignored", "certFile", certFile, "keyFile", keyFile)
|
||||
}
|
||||
if clientCAFile != "" {
|
||||
logger.Warn("client CA configured but TLS is disabled; ignoring", "clientCAFile", clientCAFile)
|
||||
}
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
if useTLS {
|
||||
srv.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
logger.Info("starting account service", "addr", addr, "tls", useTLS)
|
||||
|
||||
var listenCertFile, listenKeyFile string
|
||||
if useTLS {
|
||||
if tlsSettings.RedirectHTTP {
|
||||
go func() {
|
||||
redirectAddr := deriveRedirectAddr(addr)
|
||||
redirectSrv := &http.Server{
|
||||
Addr: redirectAddr,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
if host == "" {
|
||||
host = redirectAddr
|
||||
}
|
||||
target := "https://" + host + r.URL.RequestURI()
|
||||
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||
}),
|
||||
}
|
||||
if err := redirectSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("http redirect listener exited", "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if tlsConfig != nil && len(tlsConfig.Certificates) > 0 {
|
||||
listenCertFile = ""
|
||||
listenKeyFile = ""
|
||||
} else {
|
||||
listenCertFile = certFile
|
||||
listenKeyFile = keyFile
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServeTLS(listenCertFile, listenKeyFile); err != nil {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("account service shutdown", "err", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("account service shutdown", "err", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runServerAndAgent(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("config is nil")
|
||||
}
|
||||
|
||||
agentCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
agentErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
agentErrCh <- runAgent(agentCtx, cfg, logger)
|
||||
}()
|
||||
|
||||
agentPending := true
|
||||
|
||||
select {
|
||||
case err := <-agentErrCh:
|
||||
agentPending = false
|
||||
if err == nil {
|
||||
err = errors.New("agent exited unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("agent startup failed: %w", err)
|
||||
default:
|
||||
}
|
||||
|
||||
serverErr := runServer(ctx, cfg, logger)
|
||||
cancel()
|
||||
|
||||
var agentErr error
|
||||
if agentPending {
|
||||
agentErr = <-agentErrCh
|
||||
}
|
||||
|
||||
if serverErr != nil {
|
||||
return serverErr
|
||||
}
|
||||
if agentErr != nil {
|
||||
return agentErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runAgent(ctx context.Context, cfg *config.Config, logger *slog.Logger) error {
|
||||
if cfg == nil {
|
||||
return errors.New("config is nil")
|
||||
}
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
if !cfg.Xray.Sync.Enabled {
|
||||
logger.Warn("xray sync is disabled in configuration; agent mode will still attempt to manage xray config")
|
||||
}
|
||||
options := agentmode.Options{
|
||||
Logger: logger.With("component", "agent"),
|
||||
Agent: cfg.Agent,
|
||||
Xray: cfg.Xray,
|
||||
}
|
||||
return agentmode.Run(ctx, options)
|
||||
}
|
||||
|
||||
const agentIdentityContextKey = "xcontrol-account-agent-identity"
|
||||
|
||||
func registerAgentAPIRoutes(r *gin.Engine, registry *agentserver.Registry, source xrayconfig.ClientSource, logger *slog.Logger) {
|
||||
if registry == nil {
|
||||
return
|
||||
}
|
||||
group := r.Group("/api/agent/v1")
|
||||
group.Use(agentAuthMiddleware(registry))
|
||||
group.GET("/users", agentListUsersHandler(source))
|
||||
group.POST("/status", agentReportStatusHandler(registry, logger))
|
||||
}
|
||||
|
||||
func agentAuthMiddleware(registry *agentserver.Registry) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if registry == nil {
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "agent_registry_unavailable", "message": "agent registry not configured"})
|
||||
return
|
||||
}
|
||||
token := extractBearerToken(c.GetHeader("Authorization"))
|
||||
if token == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "agent_token_required", "message": "agent token is required"})
|
||||
return
|
||||
}
|
||||
identity, ok := registry.Authenticate(token)
|
||||
if !ok || identity == nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid_agent_token", "message": "invalid agent token"})
|
||||
return
|
||||
}
|
||||
c.Set(agentIdentityContextKey, *identity)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func agentListUsersHandler(source xrayconfig.ClientSource) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if source == nil {
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "client_source_unavailable", "message": "client source not configured"})
|
||||
return
|
||||
}
|
||||
clients, err := source.ListClients(c.Request.Context())
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "list_clients_failed", "message": "failed to list clients"})
|
||||
return
|
||||
}
|
||||
response := agentproto.ClientListResponse{
|
||||
Clients: clients,
|
||||
Total: len(clients),
|
||||
GeneratedAt: time.Now().UTC(),
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
func agentReportStatusHandler(registry *agentserver.Registry, logger *slog.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
value, exists := c.Get(agentIdentityContextKey)
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "agent_identity_missing", "message": "agent identity missing"})
|
||||
return
|
||||
}
|
||||
identity, ok := value.(agentserver.Identity)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "agent_identity_invalid", "message": "agent identity malformed"})
|
||||
return
|
||||
}
|
||||
var report agentproto.StatusReport
|
||||
if err := c.ShouldBindJSON(&report); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid_status_payload", "message": "invalid status payload"})
|
||||
return
|
||||
}
|
||||
registry.ReportStatus(identity, report)
|
||||
if logger != nil {
|
||||
logger.Info("agent status updated", "agent", identity.ID, "healthy", report.Healthy, "clients", report.Xray.Clients)
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
header = strings.TrimSpace(header)
|
||||
if header == "" {
|
||||
return ""
|
||||
}
|
||||
const prefix = "Bearer "
|
||||
if strings.HasPrefix(header, prefix) {
|
||||
header = header[len(prefix):]
|
||||
}
|
||||
return strings.TrimSpace(header)
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "xcontrol-account",
|
||||
Short: "Start the xcontrol account service",
|
||||
@ -79,284 +550,22 @@ var rootCmd = &cobra.Command{
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
|
||||
slog.SetDefault(logger)
|
||||
|
||||
r := gin.New()
|
||||
corsConfig := buildCORSConfig(logger, cfg.Server)
|
||||
if corsConfig.AllowAllOrigins {
|
||||
logger.Info("configured cors", "allowAllOrigins", true)
|
||||
} else {
|
||||
logger.Info("configured cors", "allowedOrigins", corsConfig.AllowOrigins)
|
||||
}
|
||||
r.Use(cors.New(corsConfig))
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
c.Next()
|
||||
logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start))
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
storeCfg := store.Config{
|
||||
Driver: cfg.Store.Driver,
|
||||
DSN: cfg.Store.DSN,
|
||||
MaxOpenConns: cfg.Store.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Store.MaxIdleConns,
|
||||
mode := strings.ToLower(strings.TrimSpace(cfg.Mode))
|
||||
if mode == "" {
|
||||
mode = "server"
|
||||
}
|
||||
|
||||
st, cleanup, err := store.New(ctx, storeCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
switch mode {
|
||||
case "server":
|
||||
return runServer(ctx, cfg, logger)
|
||||
case "agent":
|
||||
return runAgent(ctx, cfg, logger)
|
||||
case "server-agent", "all", "combined":
|
||||
return runServerAndAgent(ctx, cfg, logger)
|
||||
default:
|
||||
return fmt.Errorf("unsupported mode %q", cfg.Mode)
|
||||
}
|
||||
defer func() {
|
||||
if cleanup == nil {
|
||||
return
|
||||
}
|
||||
if err := cleanup(context.Background()); err != nil {
|
||||
logger.Error("failed to close store", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var emailSender api.EmailSender
|
||||
emailVerificationEnabled := true
|
||||
smtpHost := strings.TrimSpace(cfg.SMTP.Host)
|
||||
if smtpHost == "" {
|
||||
emailVerificationEnabled = false
|
||||
}
|
||||
if smtpHost != "" && isExampleDomain(smtpHost) {
|
||||
emailVerificationEnabled = false
|
||||
logger.Warn("smtp host is a placeholder; disabling email delivery", "host", smtpHost)
|
||||
smtpHost = ""
|
||||
}
|
||||
if smtpHost != "" {
|
||||
tlsMode := mailer.ParseTLSMode(cfg.SMTP.TLS.Mode)
|
||||
sender, err := mailer.New(mailer.Config{
|
||||
Host: smtpHost,
|
||||
Port: cfg.SMTP.Port,
|
||||
Username: cfg.SMTP.Username,
|
||||
Password: cfg.SMTP.Password,
|
||||
From: cfg.SMTP.From,
|
||||
ReplyTo: cfg.SMTP.ReplyTo,
|
||||
Timeout: cfg.SMTP.Timeout,
|
||||
TLSMode: tlsMode,
|
||||
InsecureSkipVerify: cfg.SMTP.TLS.InsecureSkipVerify,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
emailSender = mailerAdapter{sender: sender}
|
||||
}
|
||||
if emailSender == nil {
|
||||
emailVerificationEnabled = false
|
||||
}
|
||||
|
||||
gormDB, gormCleanup, err := openAdminSettingsDB(cfg.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if gormCleanup != nil {
|
||||
if err := gormCleanup(context.Background()); err != nil {
|
||||
logger.Error("failed to close admin settings db", "err", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
service.SetDB(gormDB)
|
||||
|
||||
var stopXraySync func(context.Context) error
|
||||
if cfg.Xray.Sync.Enabled {
|
||||
syncInterval := cfg.Xray.Sync.Interval
|
||||
if syncInterval <= 0 {
|
||||
syncInterval = 5 * time.Minute
|
||||
}
|
||||
templatePath := strings.TrimSpace(cfg.Xray.Sync.TemplatePath)
|
||||
if templatePath == "" {
|
||||
templatePath = filepath.Join("account", "config", "xray.config.template.json")
|
||||
}
|
||||
outputPath := strings.TrimSpace(cfg.Xray.Sync.OutputPath)
|
||||
if outputPath == "" {
|
||||
outputPath = "/usr/local/etc/xray/config.json"
|
||||
}
|
||||
source, err := xrayconfig.NewGormClientSource(gormDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
|
||||
Logger: logger.With("component", "xray-sync"),
|
||||
Interval: syncInterval,
|
||||
Source: source,
|
||||
Generator: xrayconfig.Generator{TemplatePath: templatePath, OutputPath: outputPath},
|
||||
ValidateCommand: cfg.Xray.Sync.ValidateCommand,
|
||||
RestartCommand: cfg.Xray.Sync.RestartCommand,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stop, err := syncer.Start(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("xray periodic sync enabled", "interval", syncInterval, "output", outputPath)
|
||||
stopXraySync = stop
|
||||
}
|
||||
|
||||
if stopXraySync != nil {
|
||||
defer func() {
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := stopXraySync(waitCtx); err != nil {
|
||||
logger.Warn("xray syncer shutdown", "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
options := []api.Option{
|
||||
api.WithStore(st),
|
||||
api.WithSessionTTL(cfg.Session.TTL),
|
||||
}
|
||||
if emailSender != nil {
|
||||
options = append(options, api.WithEmailSender(emailSender))
|
||||
}
|
||||
options = append(options, api.WithEmailVerification(emailVerificationEnabled))
|
||||
api.RegisterRoutes(r, options...)
|
||||
|
||||
addr := strings.TrimSpace(cfg.Server.Addr)
|
||||
if addr == "" {
|
||||
addr = ":8080"
|
||||
}
|
||||
|
||||
tlsSettings := cfg.Server.TLS
|
||||
certFile := strings.TrimSpace(tlsSettings.CertFile)
|
||||
keyFile := strings.TrimSpace(tlsSettings.KeyFile)
|
||||
caFile := strings.TrimSpace(tlsSettings.CAFile)
|
||||
clientCAFile := strings.TrimSpace(tlsSettings.ClientCAFile)
|
||||
|
||||
useTLS := tlsSettings.IsEnabled()
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if useTLS {
|
||||
if certFile == "" || keyFile == "" {
|
||||
return fmt.Errorf("tls is enabled but certFile (%q) or keyFile (%q) is empty", certFile, keyFile)
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load tls certificate: %w", err)
|
||||
}
|
||||
|
||||
if caFile != "" {
|
||||
caPEM, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ca file %q: %w", caFile, err)
|
||||
}
|
||||
|
||||
var block *pem.Block
|
||||
existing := make(map[string]struct{}, len(cert.Certificate))
|
||||
for _, c := range cert.Certificate {
|
||||
existing[string(c)] = struct{}{}
|
||||
}
|
||||
|
||||
for len(caPEM) > 0 {
|
||||
block, caPEM = pem.Decode(caPEM)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" || len(block.Bytes) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := existing[string(block.Bytes)]; ok {
|
||||
continue
|
||||
}
|
||||
cert.Certificate = append(cert.Certificate, block.Bytes)
|
||||
}
|
||||
|
||||
if len(cert.Certificate) == 0 {
|
||||
return fmt.Errorf("ca file %q did not contain any certificates", caFile)
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
if clientCAFile != "" {
|
||||
caBytes, err := os.ReadFile(clientCAFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caBytes) {
|
||||
return errors.New("failed to parse client CA file")
|
||||
}
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
} else {
|
||||
if certFile != "" || keyFile != "" {
|
||||
logger.Info("TLS disabled; certificate paths will be ignored", "certFile", certFile, "keyFile", keyFile)
|
||||
}
|
||||
if clientCAFile != "" {
|
||||
logger.Warn("client CA configured but TLS is disabled; ignoring", "clientCAFile", clientCAFile)
|
||||
}
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
if useTLS {
|
||||
srv.TLSConfig = tlsConfig
|
||||
}
|
||||
|
||||
logger.Info("starting account service", "addr", addr, "tls", useTLS)
|
||||
|
||||
var listenCertFile, listenKeyFile string
|
||||
if useTLS {
|
||||
if tlsSettings.RedirectHTTP {
|
||||
go func() {
|
||||
redirectAddr := deriveRedirectAddr(addr)
|
||||
redirectSrv := &http.Server{
|
||||
Addr: redirectAddr,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
if host == "" {
|
||||
host = redirectAddr
|
||||
}
|
||||
target := "https://" + host + r.URL.RequestURI()
|
||||
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||
}),
|
||||
}
|
||||
if err := redirectSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("http redirect listener exited", "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if tlsConfig != nil && len(tlsConfig.Certificates) > 0 {
|
||||
listenCertFile = ""
|
||||
listenKeyFile = ""
|
||||
} else {
|
||||
listenCertFile = certFile
|
||||
listenKeyFile = keyFile
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServeTLS(listenCertFile, listenKeyFile); err != nil {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("account service shutdown", "err", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("account service shutdown", "err", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
26
account/config/account-agent.yaml
Normal file
26
account/config/account-agent.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
mode: "agent"
|
||||
|
||||
log:
|
||||
level: info
|
||||
|
||||
agent:
|
||||
id: "edge-node-1"
|
||||
controllerUrl: "https://account.svc.plus"
|
||||
apiToken: "replace-with-agent-token"
|
||||
httpTimeout: 15s
|
||||
statusInterval: 1m
|
||||
syncInterval: 5m
|
||||
tls:
|
||||
insecureSkipVerify: false
|
||||
|
||||
xray:
|
||||
sync:
|
||||
enabled: true
|
||||
interval: 5m
|
||||
templatePath: "account/config/xray.config.template.json"
|
||||
outputPath: "/usr/local/etc/xray/config.json"
|
||||
validateCommand: []
|
||||
restartCommand:
|
||||
- "systemctl"
|
||||
- "restart"
|
||||
- "xray.service"
|
||||
85
account/config/account-server.yaml
Normal file
85
account/config/account-server.yaml
Normal file
@ -0,0 +1,85 @@
|
||||
mode: "server-agent"
|
||||
|
||||
log:
|
||||
level: info
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
readTimeout: 15s
|
||||
writeTimeout: 15s
|
||||
publicUrl: "http://localhost:8080"
|
||||
allowedOrigins:
|
||||
- "https://dev.svc.plus"
|
||||
- "https://dev-homepage.svc.plus"
|
||||
- "https://www.svc.plus"
|
||||
- "https://global-homepage.svc.plus"
|
||||
- "https://account.svc.plus"
|
||||
- "https://localhost:8443"
|
||||
- "http://localhost:8080"
|
||||
- "http://127.0.0.1:8080"
|
||||
- "http://localhost:3001"
|
||||
- "http://127.0.0.1:3001"
|
||||
- "http://localhost:3000"
|
||||
- "http://127.0.0.1:3000"
|
||||
tls:
|
||||
enabled: false
|
||||
certFile: ""
|
||||
keyFile: ""
|
||||
caFile: ""
|
||||
clientCAFile: ""
|
||||
redirectHttp: false
|
||||
|
||||
store:
|
||||
driver: "postgres"
|
||||
dsn: "postgres://shenlan:password@127.0.0.1:5432/account?sslmode=disable"
|
||||
maxOpenConns: 30
|
||||
maxIdleConns: 10
|
||||
|
||||
session:
|
||||
ttl: 24h
|
||||
cache: "redis"
|
||||
redis:
|
||||
addr: "127.0.0.1:6379"
|
||||
password: ""
|
||||
|
||||
smtp:
|
||||
host: "smtp.example.com"
|
||||
port: 587
|
||||
username: "apikey"
|
||||
password: "YOUR_PASSWORD"
|
||||
from: "XControl Account <no-reply@example.com>"
|
||||
replyTo: ""
|
||||
timeout: 10s
|
||||
tls:
|
||||
mode: "auto"
|
||||
insecureSkipVerify: false
|
||||
|
||||
xray:
|
||||
sync:
|
||||
enabled: false
|
||||
interval: 5m
|
||||
templatePath: "account/config/xray.config.template.json"
|
||||
outputPath: "/usr/local/etc/xray/config.json"
|
||||
validateCommand: []
|
||||
restartCommand:
|
||||
- "systemctl"
|
||||
- "restart"
|
||||
- "xray.service"
|
||||
|
||||
agent:
|
||||
id: "account-primary"
|
||||
controllerUrl: "http://127.0.0.1:8080"
|
||||
apiToken: "replace-with-agent-token"
|
||||
httpTimeout: 15s
|
||||
statusInterval: 1m
|
||||
syncInterval: 5m
|
||||
tls:
|
||||
insecureSkipVerify: false
|
||||
|
||||
agents:
|
||||
credentials:
|
||||
- id: "account-primary"
|
||||
name: "Account Server (local agent)"
|
||||
token: "replace-with-agent-token"
|
||||
groups:
|
||||
- "default"
|
||||
@ -1,3 +1,5 @@
|
||||
mode: "server-agent"
|
||||
|
||||
log:
|
||||
level: info
|
||||
|
||||
@ -63,3 +65,21 @@ xray:
|
||||
- "systemctl"
|
||||
- "restart"
|
||||
- "xray.service"
|
||||
|
||||
agent:
|
||||
id: "account-primary"
|
||||
controllerUrl: "http://127.0.0.1:8080"
|
||||
apiToken: "replace-with-agent-token"
|
||||
httpTimeout: 15s
|
||||
statusInterval: 1m
|
||||
syncInterval: 5m
|
||||
tls:
|
||||
insecureSkipVerify: false
|
||||
|
||||
agents:
|
||||
credentials:
|
||||
- id: "account-primary"
|
||||
name: "Account Server (local agent)"
|
||||
token: "replace-with-agent-token"
|
||||
groups:
|
||||
- "default"
|
||||
|
||||
@ -19,12 +19,15 @@ type Log struct {
|
||||
|
||||
// Config holds configuration for the account service.
|
||||
type Config struct {
|
||||
Mode string `yaml:"mode"`
|
||||
Log Log `yaml:"log"`
|
||||
Server Server `yaml:"server"`
|
||||
Store Store `yaml:"store"`
|
||||
Session Session `yaml:"session"`
|
||||
SMTP SMTP `yaml:"smtp"`
|
||||
Xray Xray `yaml:"xray"`
|
||||
Agent Agent `yaml:"agent"`
|
||||
Agents Agents `yaml:"agents"`
|
||||
}
|
||||
|
||||
// Server defines HTTP server configuration.
|
||||
@ -106,6 +109,36 @@ type XraySync struct {
|
||||
RestartCommand []string `yaml:"restartCommand"`
|
||||
}
|
||||
|
||||
// Agent defines configuration for agent mode deployments.
|
||||
type Agent struct {
|
||||
ID string `yaml:"id"`
|
||||
ControllerURL string `yaml:"controllerUrl"`
|
||||
APIToken string `yaml:"apiToken"`
|
||||
HTTPTimeout time.Duration `yaml:"httpTimeout"`
|
||||
StatusInterval time.Duration `yaml:"statusInterval"`
|
||||
SyncInterval time.Duration `yaml:"syncInterval"`
|
||||
TLS AgentTLS `yaml:"tls"`
|
||||
}
|
||||
|
||||
// AgentTLS configures TLS behaviour for the agent HTTP client.
|
||||
type AgentTLS struct {
|
||||
InsecureSkipVerify bool `yaml:"insecureSkipVerify"`
|
||||
}
|
||||
|
||||
// Agents describes the controller-side agent configuration.
|
||||
type Agents struct {
|
||||
Credentials []AgentCredential `yaml:"credentials"`
|
||||
}
|
||||
|
||||
// AgentCredential represents a single agent identity authorised to call the
|
||||
// controller API.
|
||||
type AgentCredential struct {
|
||||
ID string `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
Token string `yaml:"token"`
|
||||
Groups []string `yaml:"groups"`
|
||||
}
|
||||
|
||||
// Load reads the configuration file at the provided path. When path is empty,
|
||||
// it defaults to account/config/account.yaml. If the file does not exist an
|
||||
// empty configuration is returned.
|
||||
|
||||
151
account/internal/agentmode/client.go
Normal file
151
account/internal/agentmode/client.go
Normal file
@ -0,0 +1,151 @@
|
||||
package agentmode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xcontrol/account/internal/agentproto"
|
||||
)
|
||||
|
||||
// ClientOptions configures the HTTP client used to communicate with the
|
||||
// controller.
|
||||
type ClientOptions struct {
|
||||
Timeout time.Duration
|
||||
InsecureSkipVerify bool
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// Client issues authenticated requests against the controller.
|
||||
type Client struct {
|
||||
baseURL *url.URL
|
||||
token string
|
||||
http *http.Client
|
||||
userAgent string
|
||||
}
|
||||
|
||||
// NewClient constructs a client for the provided controller URL and token.
|
||||
func NewClient(baseURL, token string, opts ClientOptions) (*Client, error) {
|
||||
trimmedURL := strings.TrimSpace(baseURL)
|
||||
if trimmedURL == "" {
|
||||
return nil, errors.New("controller url is required")
|
||||
}
|
||||
parsed, err := url.Parse(trimmedURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse controller url: %w", err)
|
||||
}
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return nil, errors.New("controller token is required")
|
||||
}
|
||||
|
||||
timeout := opts.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 15 * time.Second
|
||||
}
|
||||
|
||||
transport := http.DefaultTransport
|
||||
if t, ok := transport.(*http.Transport); ok {
|
||||
clone := t.Clone()
|
||||
if opts.InsecureSkipVerify {
|
||||
if clone.TLSClientConfig == nil {
|
||||
clone.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
clone.TLSClientConfig.InsecureSkipVerify = true
|
||||
}
|
||||
transport = clone
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
userAgent := strings.TrimSpace(opts.UserAgent)
|
||||
if userAgent == "" {
|
||||
userAgent = "xcontrol-agent"
|
||||
}
|
||||
|
||||
return &Client{
|
||||
baseURL: parsed,
|
||||
token: token,
|
||||
http: client,
|
||||
userAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListClients fetches the current set of Xray clients from the controller.
|
||||
func (c *Client) ListClients(ctx context.Context) (agentproto.ClientListResponse, error) {
|
||||
endpoint, err := url.JoinPath(c.baseURL.String(), "/api/agent/v1/users")
|
||||
if err != nil {
|
||||
return agentproto.ClientListResponse{}, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return agentproto.ClientListResponse{}, err
|
||||
}
|
||||
c.applyHeaders(req)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return agentproto.ClientListResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<14))
|
||||
return agentproto.ClientListResponse{}, fmt.Errorf("controller returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var payload agentproto.ClientListResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return agentproto.ClientListResponse{}, fmt.Errorf("decode client list: %w", err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// ReportStatus submits the agent status report to the controller.
|
||||
func (c *Client) ReportStatus(ctx context.Context, report agentproto.StatusReport) error {
|
||||
endpoint, err := url.JoinPath(c.baseURL.String(), "/api/agent/v1/status")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf, err := json.Marshal(report)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode status report: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(buf))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.applyHeaders(req)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<14))
|
||||
return fmt.Errorf("controller returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) applyHeaders(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||
req.Header.Set("User-Agent", c.userAgent)
|
||||
}
|
||||
195
account/internal/agentmode/runner.go
Normal file
195
account/internal/agentmode/runner.go
Normal file
@ -0,0 +1,195 @@
|
||||
package agentmode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"xcontrol/account/config"
|
||||
"xcontrol/account/internal/agentproto"
|
||||
"xcontrol/account/internal/xrayconfig"
|
||||
)
|
||||
|
||||
// Options configures the agent runtime.
|
||||
type Options struct {
|
||||
Logger *slog.Logger
|
||||
Agent config.Agent
|
||||
Xray config.Xray
|
||||
}
|
||||
|
||||
// Run launches the agent mode control loop. It blocks until the context is
|
||||
// cancelled or a fatal error occurs during setup.
|
||||
func Run(ctx context.Context, opts Options) error {
|
||||
if ctx == nil {
|
||||
return errors.New("context is required")
|
||||
}
|
||||
|
||||
logger := opts.Logger
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
|
||||
controllerURL := strings.TrimSpace(opts.Agent.ControllerURL)
|
||||
if controllerURL == "" {
|
||||
return errors.New("agent.controllerUrl is required")
|
||||
}
|
||||
token := strings.TrimSpace(opts.Agent.APIToken)
|
||||
if token == "" {
|
||||
return errors.New("agent.apiToken is required")
|
||||
}
|
||||
|
||||
syncInterval := opts.Agent.SyncInterval
|
||||
if syncInterval <= 0 {
|
||||
syncInterval = opts.Xray.Sync.Interval
|
||||
}
|
||||
if syncInterval <= 0 {
|
||||
syncInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
statusInterval := opts.Agent.StatusInterval
|
||||
if statusInterval <= 0 {
|
||||
statusInterval = time.Minute
|
||||
}
|
||||
|
||||
httpTimeout := opts.Agent.HTTPTimeout
|
||||
if httpTimeout <= 0 {
|
||||
httpTimeout = 15 * time.Second
|
||||
}
|
||||
|
||||
templatePath := strings.TrimSpace(opts.Xray.Sync.TemplatePath)
|
||||
if templatePath == "" {
|
||||
templatePath = filepath.Join("account", "config", "xray.config.template.json")
|
||||
}
|
||||
outputPath := strings.TrimSpace(opts.Xray.Sync.OutputPath)
|
||||
if outputPath == "" {
|
||||
outputPath = "/usr/local/etc/xray/config.json"
|
||||
}
|
||||
|
||||
client, err := NewClient(controllerURL, token, ClientOptions{
|
||||
Timeout: httpTimeout,
|
||||
InsecureSkipVerify: opts.Agent.TLS.InsecureSkipVerify,
|
||||
UserAgent: buildUserAgent(opts.Agent.ID),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tracker := newSyncTracker()
|
||||
source := NewHTTPClientSource(client, tracker)
|
||||
|
||||
syncLogger := logger.With("component", "agent-xray-sync")
|
||||
syncer, err := xrayconfig.NewPeriodicSyncer(xrayconfig.PeriodicOptions{
|
||||
Logger: syncLogger,
|
||||
Interval: syncInterval,
|
||||
Source: source,
|
||||
Generator: xrayconfig.Generator{TemplatePath: templatePath, OutputPath: outputPath},
|
||||
ValidateCommand: opts.Xray.Sync.ValidateCommand,
|
||||
RestartCommand: opts.Xray.Sync.RestartCommand,
|
||||
OnSync: func(result xrayconfig.SyncResult) {
|
||||
if result.Error != nil {
|
||||
tracker.MarkError(result.Error, result.CompletedAt)
|
||||
return
|
||||
}
|
||||
tracker.MarkSuccess(result.CompletedAt)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stopSync, err := syncer.Start(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := stopSync(waitCtx); err != nil {
|
||||
logger.Warn("xray syncer shutdown", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
reporterCtx, reporterCancel := context.WithCancel(ctx)
|
||||
defer reporterCancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
runStatusReporter(reporterCtx, client, tracker, statusInterval, syncInterval, logger)
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
reporterCancel()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildUserAgent(id string) string {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return "xcontrol-agent"
|
||||
}
|
||||
return fmt.Sprintf("xcontrol-agent/%s", id)
|
||||
}
|
||||
|
||||
func runStatusReporter(ctx context.Context, client *Client, tracker *syncTracker, interval, syncInterval time.Duration, logger *slog.Logger) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
send := func() {
|
||||
snapshot := tracker.Snapshot()
|
||||
report := buildStatusReport(snapshot, syncInterval)
|
||||
if err := client.ReportStatus(ctx, report); err != nil {
|
||||
logger.Warn("failed to report agent status", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
send()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
send()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildStatusReport(snapshot trackerSnapshot, syncInterval time.Duration) agentproto.StatusReport {
|
||||
healthy := snapshot.LastError == "" && !snapshot.LastSuccess.IsZero()
|
||||
|
||||
running := false
|
||||
var lastSyncPtr *time.Time
|
||||
if !snapshot.LastSuccess.IsZero() {
|
||||
running = time.Since(snapshot.LastSuccess) <= 3*syncInterval
|
||||
last := snapshot.LastSuccess
|
||||
lastSyncPtr = &last
|
||||
}
|
||||
|
||||
report := agentproto.StatusReport{
|
||||
Healthy: healthy,
|
||||
Message: snapshot.LastError,
|
||||
Users: snapshot.Clients,
|
||||
SyncRevision: snapshot.Revision,
|
||||
Xray: agentproto.XrayStatus{
|
||||
Running: running,
|
||||
Clients: snapshot.Clients,
|
||||
LastSync: func() *time.Time {
|
||||
if lastSyncPtr == nil {
|
||||
return nil
|
||||
}
|
||||
copy := *lastSyncPtr
|
||||
return ©
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
return report
|
||||
}
|
||||
33
account/internal/agentmode/source_http.go
Normal file
33
account/internal/agentmode/source_http.go
Normal file
@ -0,0 +1,33 @@
|
||||
package agentmode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"xcontrol/account/internal/xrayconfig"
|
||||
)
|
||||
|
||||
// HTTPClientSource retrieves Xray clients from the controller over HTTP.
|
||||
type HTTPClientSource struct {
|
||||
client *Client
|
||||
tracker *syncTracker
|
||||
}
|
||||
|
||||
// NewHTTPClientSource constructs a source backed by the provided client and
|
||||
// tracker.
|
||||
func NewHTTPClientSource(client *Client, tracker *syncTracker) *HTTPClientSource {
|
||||
return &HTTPClientSource{client: client, tracker: tracker}
|
||||
}
|
||||
|
||||
// ListClients implements xrayconfig.ClientSource by fetching the latest client
|
||||
// list via the controller API.
|
||||
func (s *HTTPClientSource) ListClients(ctx context.Context) ([]xrayconfig.Client, error) {
|
||||
resp, err := s.client.ListClients(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.tracker != nil {
|
||||
s.tracker.UpdateFetch(len(resp.Clients), resp.Revision, time.Now().UTC())
|
||||
}
|
||||
return resp.Clients, nil
|
||||
}
|
||||
72
account/internal/agentmode/tracker.go
Normal file
72
account/internal/agentmode/tracker.go
Normal file
@ -0,0 +1,72 @@
|
||||
package agentmode
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type syncTracker struct {
|
||||
mu sync.RWMutex
|
||||
clients int
|
||||
revision string
|
||||
lastFetch time.Time
|
||||
lastSuccess time.Time
|
||||
lastError string
|
||||
lastErrorAt time.Time
|
||||
}
|
||||
|
||||
func newSyncTracker() *syncTracker {
|
||||
return &syncTracker{}
|
||||
}
|
||||
|
||||
func (t *syncTracker) UpdateFetch(count int, revision string, when time.Time) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.clients = count
|
||||
t.revision = revision
|
||||
t.lastFetch = when
|
||||
}
|
||||
|
||||
func (t *syncTracker) MarkSuccess(at time.Time) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.lastSuccess = at
|
||||
t.lastError = ""
|
||||
t.lastErrorAt = time.Time{}
|
||||
}
|
||||
|
||||
func (t *syncTracker) MarkError(err error, at time.Time) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.lastError = err.Error()
|
||||
t.lastErrorAt = at
|
||||
}
|
||||
|
||||
type trackerSnapshot struct {
|
||||
Clients int
|
||||
Revision string
|
||||
LastFetch time.Time
|
||||
LastSuccess time.Time
|
||||
LastError string
|
||||
LastErrorAt time.Time
|
||||
}
|
||||
|
||||
func (t *syncTracker) Snapshot() trackerSnapshot {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
return trackerSnapshot{
|
||||
Clients: t.clients,
|
||||
Revision: t.revision,
|
||||
LastFetch: t.lastFetch,
|
||||
LastSuccess: t.lastSuccess,
|
||||
LastError: t.lastError,
|
||||
LastErrorAt: t.lastErrorAt,
|
||||
}
|
||||
}
|
||||
34
account/internal/agentproto/types.go
Normal file
34
account/internal/agentproto/types.go
Normal file
@ -0,0 +1,34 @@
|
||||
package agentproto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"xcontrol/account/internal/xrayconfig"
|
||||
)
|
||||
|
||||
// ClientListResponse represents the payload returned by the controller when an
|
||||
// agent requests the latest set of Xray clients.
|
||||
type ClientListResponse struct {
|
||||
Clients []xrayconfig.Client `json:"clients"`
|
||||
Total int `json:"total"`
|
||||
GeneratedAt time.Time `json:"generatedAt"`
|
||||
Revision string `json:"revision,omitempty"`
|
||||
}
|
||||
|
||||
// StatusReport captures the runtime state of an agent and the managed Xray
|
||||
// instance.
|
||||
type StatusReport struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Users int `json:"users"`
|
||||
SyncRevision string `json:"syncRevision,omitempty"`
|
||||
Xray XrayStatus `json:"xray"`
|
||||
}
|
||||
|
||||
// XrayStatus describes the synchronisation state of the managed Xray process.
|
||||
type XrayStatus struct {
|
||||
Running bool `json:"running"`
|
||||
Clients int `json:"clients"`
|
||||
LastSync *time.Time `json:"lastSync,omitempty"`
|
||||
ConfigHash string `json:"configHash,omitempty"`
|
||||
}
|
||||
179
account/internal/agentserver/registry.go
Normal file
179
account/internal/agentserver/registry.go
Normal file
@ -0,0 +1,179 @@
|
||||
package agentserver
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"xcontrol/account/internal/agentproto"
|
||||
)
|
||||
|
||||
// Credential defines the authentication material assigned to a managed agent.
|
||||
type Credential struct {
|
||||
ID string `yaml:"id"`
|
||||
Name string `yaml:"name"`
|
||||
Token string `yaml:"token"`
|
||||
Groups []string `yaml:"groups"`
|
||||
}
|
||||
|
||||
// Config groups the credential set exposed through configuration.
|
||||
type Config struct {
|
||||
Credentials []Credential `yaml:"credentials"`
|
||||
}
|
||||
|
||||
// Identity represents an authenticated agent instance.
|
||||
type Identity struct {
|
||||
ID string
|
||||
Name string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
// StatusSnapshot captures the last reported status for an agent.
|
||||
type StatusSnapshot struct {
|
||||
Agent Identity
|
||||
Report agentproto.StatusReport
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// Registry manages agent credentials and status reports in-memory.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
credentials map[[32]byte]Identity
|
||||
byID map[string]Identity
|
||||
statuses map[string]StatusSnapshot
|
||||
}
|
||||
|
||||
// NewRegistry constructs a registry from configuration, validating credentials
|
||||
// and normalising their representation.
|
||||
func NewRegistry(cfg Config) (*Registry, error) {
|
||||
r := &Registry{
|
||||
credentials: make(map[[32]byte]Identity),
|
||||
byID: make(map[string]Identity),
|
||||
statuses: make(map[string]StatusSnapshot),
|
||||
}
|
||||
|
||||
for _, cred := range cfg.Credentials {
|
||||
id := strings.TrimSpace(cred.ID)
|
||||
token := strings.TrimSpace(cred.Token)
|
||||
if id == "" {
|
||||
return nil, errors.New("agent credential id is required")
|
||||
}
|
||||
if token == "" {
|
||||
return nil, errors.New("agent credential token is required")
|
||||
}
|
||||
if _, exists := r.byID[id]; exists {
|
||||
return nil, errors.New("duplicate agent credential id: " + id)
|
||||
}
|
||||
|
||||
digest := sha256.Sum256([]byte(token))
|
||||
if _, exists := r.credentials[digest]; exists {
|
||||
return nil, errors.New("duplicate agent credential token")
|
||||
}
|
||||
|
||||
identity := Identity{
|
||||
ID: id,
|
||||
Name: strings.TrimSpace(cred.Name),
|
||||
Groups: normalizeStrings(cred.Groups),
|
||||
}
|
||||
r.credentials[digest] = identity
|
||||
r.byID[id] = identity
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Authenticate validates the provided token and returns the associated agent
|
||||
// identity when successful.
|
||||
func (r *Registry) Authenticate(token string) (*Identity, bool) {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return nil, false
|
||||
}
|
||||
digest := sha256.Sum256([]byte(token))
|
||||
|
||||
r.mu.RLock()
|
||||
identity, ok := r.credentials[digest]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
copy := identity
|
||||
return ©, true
|
||||
}
|
||||
|
||||
// ReportStatus records the status report for the provided agent identity.
|
||||
func (r *Registry) ReportStatus(agent Identity, report agentproto.StatusReport) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.statuses[agent.ID] = StatusSnapshot{
|
||||
Agent: agent,
|
||||
Report: report,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// Statuses returns the latest status snapshot for all agents sorted by ID.
|
||||
func (r *Registry) Statuses() []StatusSnapshot {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
snapshots := make([]StatusSnapshot, 0, len(r.byID))
|
||||
for id, identity := range r.byID {
|
||||
snapshot, ok := r.statuses[id]
|
||||
if !ok {
|
||||
snapshot = StatusSnapshot{Agent: identity}
|
||||
}
|
||||
snapshots = append(snapshots, snapshot)
|
||||
}
|
||||
|
||||
sort.Slice(snapshots, func(i, j int) bool {
|
||||
return snapshots[i].Agent.ID < snapshots[j].Agent.ID
|
||||
})
|
||||
|
||||
return snapshots
|
||||
}
|
||||
|
||||
// Agents returns the configured agent identities in a deterministic order.
|
||||
func (r *Registry) Agents() []Identity {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
agents := make([]Identity, 0, len(r.byID))
|
||||
for _, identity := range r.byID {
|
||||
agents = append(agents, identity)
|
||||
}
|
||||
sort.Slice(agents, func(i, j int) bool {
|
||||
return agents[i].ID < agents[j].ID
|
||||
})
|
||||
return agents
|
||||
}
|
||||
|
||||
// normalizeStrings trims whitespace and removes duplicates from the provided
|
||||
// slice while preserving the original order for unique entries.
|
||||
func normalizeStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(values))
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
104
account/internal/agentserver/registry_test.go
Normal file
104
account/internal/agentserver/registry_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package agentserver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"xcontrol/account/internal/agentproto"
|
||||
)
|
||||
|
||||
func TestNewRegistryValidation(t *testing.T) {
|
||||
_, err := NewRegistry(Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty config: %v", err)
|
||||
}
|
||||
|
||||
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "", Token: "token"}}})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for empty id")
|
||||
}
|
||||
|
||||
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "edge", Token: ""}}})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for empty token")
|
||||
}
|
||||
|
||||
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "a", Token: "1"}, {ID: "a", Token: "2"}}})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for duplicate id")
|
||||
}
|
||||
|
||||
_, err = NewRegistry(Config{Credentials: []Credential{{ID: "a", Token: "dup"}, {ID: "b", Token: "dup"}}})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for duplicate token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryAuthenticateAndStatus(t *testing.T) {
|
||||
registry, err := NewRegistry(Config{Credentials: []Credential{{ID: "edge", Name: "Edge", Token: "secret", Groups: []string{"default"}}}})
|
||||
if err != nil {
|
||||
t.Fatalf("new registry: %v", err)
|
||||
}
|
||||
identity, ok := registry.Authenticate("secret")
|
||||
if !ok || identity == nil {
|
||||
t.Fatalf("expected authentication to succeed")
|
||||
}
|
||||
if identity.ID != "edge" {
|
||||
t.Fatalf("unexpected identity id %q", identity.ID)
|
||||
}
|
||||
|
||||
report := agentproto.StatusReport{
|
||||
Healthy: true,
|
||||
Users: 5,
|
||||
Xray: agentproto.XrayStatus{
|
||||
Running: true,
|
||||
Clients: 5,
|
||||
},
|
||||
}
|
||||
registry.ReportStatus(*identity, report)
|
||||
|
||||
snapshots := registry.Statuses()
|
||||
if len(snapshots) != 1 {
|
||||
t.Fatalf("expected 1 snapshot, got %d", len(snapshots))
|
||||
}
|
||||
snapshot := snapshots[0]
|
||||
if snapshot.Agent.ID != "edge" {
|
||||
t.Fatalf("unexpected snapshot agent id %q", snapshot.Agent.ID)
|
||||
}
|
||||
if !snapshot.Report.Healthy {
|
||||
t.Fatalf("expected healthy report")
|
||||
}
|
||||
if snapshot.Report.Users != 5 {
|
||||
t.Fatalf("unexpected users count %d", snapshot.Report.Users)
|
||||
}
|
||||
if snapshot.Report.Xray.Clients != 5 {
|
||||
t.Fatalf("unexpected xray clients %d", snapshot.Report.Xray.Clients)
|
||||
}
|
||||
if snapshot.UpdatedAt.IsZero() {
|
||||
t.Fatalf("expected updated timestamp")
|
||||
}
|
||||
|
||||
// Ensure snapshots include configured agents without reports.
|
||||
registry, err = NewRegistry(Config{Credentials: []Credential{{ID: "a", Token: "1"}, {ID: "b", Token: "2"}}})
|
||||
if err != nil {
|
||||
t.Fatalf("new registry: %v", err)
|
||||
}
|
||||
snapshots = registry.Statuses()
|
||||
if len(snapshots) != 2 {
|
||||
t.Fatalf("expected 2 snapshots, got %d", len(snapshots))
|
||||
}
|
||||
if snapshots[0].Agent.ID != "a" || snapshots[1].Agent.ID != "b" {
|
||||
t.Fatalf("unexpected snapshot ordering: %+v", snapshots)
|
||||
}
|
||||
|
||||
// Report status with timestamp and ensure Latest is retained.
|
||||
now := time.Now().UTC()
|
||||
registry.ReportStatus(snapshots[0].Agent, agentproto.StatusReport{Users: 1})
|
||||
entries := registry.Statuses()
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", len(entries))
|
||||
}
|
||||
if entries[0].UpdatedAt.Before(now) {
|
||||
t.Fatalf("expected updated timestamp to be after initial time")
|
||||
}
|
||||
}
|
||||
@ -26,6 +26,7 @@ type PeriodicOptions struct {
|
||||
ValidateCommand []string
|
||||
RestartCommand []string
|
||||
Runner commandRunner
|
||||
OnSync func(SyncResult)
|
||||
}
|
||||
|
||||
// PeriodicSyncer periodically rebuilds the Xray configuration from the database.
|
||||
@ -37,6 +38,14 @@ type PeriodicSyncer struct {
|
||||
validateCommand []string
|
||||
restartCommand []string
|
||||
runner commandRunner
|
||||
onSync func(SyncResult)
|
||||
}
|
||||
|
||||
// SyncResult describes the outcome of a synchronization attempt.
|
||||
type SyncResult struct {
|
||||
Clients int
|
||||
Error error
|
||||
CompletedAt time.Time
|
||||
}
|
||||
|
||||
// NewPeriodicSyncer constructs a new PeriodicSyncer from the provided options.
|
||||
@ -69,6 +78,7 @@ func NewPeriodicSyncer(opts PeriodicOptions) (*PeriodicSyncer, error) {
|
||||
validateCommand: append([]string(nil), opts.ValidateCommand...),
|
||||
restartCommand: append([]string(nil), opts.RestartCommand...),
|
||||
runner: runner,
|
||||
onSync: opts.OnSync,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -101,6 +111,7 @@ func (s *PeriodicSyncer) Start(ctx context.Context) (func(context.Context) error
|
||||
|
||||
func (s *PeriodicSyncer) run(ctx context.Context) {
|
||||
if n, err := s.sync(ctx); err != nil {
|
||||
s.notify(SyncResult{Clients: n, Error: err, CompletedAt: time.Now().UTC()})
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
s.logger.Error("xray config sync failed", "err", err)
|
||||
}
|
||||
@ -109,6 +120,7 @@ func (s *PeriodicSyncer) run(ctx context.Context) {
|
||||
}
|
||||
} else {
|
||||
s.logger.Info("xray config synchronized", "clients", n)
|
||||
s.notify(SyncResult{Clients: n, CompletedAt: time.Now().UTC()})
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(s.interval)
|
||||
@ -121,6 +133,7 @@ func (s *PeriodicSyncer) run(ctx context.Context) {
|
||||
case <-ticker.C:
|
||||
n, err := s.sync(ctx)
|
||||
if err != nil {
|
||||
s.notify(SyncResult{Clients: n, Error: err, CompletedAt: time.Now().UTC()})
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
@ -128,6 +141,7 @@ func (s *PeriodicSyncer) run(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
s.logger.Info("xray config synchronized", "clients", n)
|
||||
s.notify(SyncResult{Clients: n, CompletedAt: time.Now().UTC()})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -153,6 +167,13 @@ func (s *PeriodicSyncer) sync(ctx context.Context) (int, error) {
|
||||
return len(clients), nil
|
||||
}
|
||||
|
||||
func (s *PeriodicSyncer) notify(result SyncResult) {
|
||||
if s.onSync == nil {
|
||||
return
|
||||
}
|
||||
s.onSync(result)
|
||||
}
|
||||
|
||||
func (s *PeriodicSyncer) runCommand(ctx context.Context, cmd []string, action string) error {
|
||||
output, err := s.runner(ctx, cmd)
|
||||
if err != nil {
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@ -173,6 +174,47 @@ func TestPeriodicSyncerStartStop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeriodicSyncerOnSyncCallback(t *testing.T) {
|
||||
template, output := writeTemplate(t)
|
||||
var results []SyncResult
|
||||
var mu sync.Mutex
|
||||
opts := PeriodicOptions{
|
||||
Interval: 5 * time.Millisecond,
|
||||
Source: staticSource{clients: []Client{{ID: "uuid-a"}}},
|
||||
Generator: Generator{TemplatePath: template, OutputPath: output},
|
||||
OnSync: func(res SyncResult) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
results = append(results, res)
|
||||
},
|
||||
}
|
||||
syncer, err := NewPeriodicSyncer(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("new syncer: %v", err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stop, err := syncer.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
if err := stop(context.Background()); err != nil {
|
||||
t.Fatalf("stop: %v", err)
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(results) == 0 {
|
||||
t.Fatalf("expected at least one sync result")
|
||||
}
|
||||
if results[0].Clients != 1 {
|
||||
t.Fatalf("expected 1 client, got %d", results[0].Clients)
|
||||
}
|
||||
if results[0].CompletedAt.IsZero() {
|
||||
t.Fatalf("expected completion timestamp to be set")
|
||||
}
|
||||
}
|
||||
|
||||
type clientSourceFunc func(ctx context.Context) ([]Client, error)
|
||||
|
||||
func (f clientSourceFunc) ListClients(ctx context.Context) ([]Client, error) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user