feat: Add multi-factor authentication login flow and config synchronization endpoints.

This commit is contained in:
Haitao Pan 2026-02-17 11:59:18 +08:00
parent 5532fc7e52
commit 1c32d2f01b
3 changed files with 340 additions and 35 deletions

View File

@ -267,6 +267,7 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
authGroup.POST("/register/send", h.sendEmailVerification)
authGroup.POST("/login", h.login)
authGroup.POST("/mfa/verify", h.verifyMFALogin)
// Token exchange endpoint - converts public token to access/refresh tokens
authGroup.POST("/token/exchange", h.exchangeToken)
@ -277,8 +278,11 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
// Token refresh endpoint - generates new access token using refresh token
authGroup.POST("/token/refresh", h.refreshToken)
authGroup.POST("/refresh", h.refreshToken)
authGroup.GET("/mfa/status", h.mfaStatus)
authGroup.GET("/sync/config", h.syncConfigSnapshot)
authGroup.POST("/sync/ack", h.syncConfigAck)
// Sandbox binding read endpoint.
// Used by the Console Guest/Demo experience. Must be readable either via a
@ -1057,7 +1061,22 @@ func (h *handler) login(c *gin.Context) {
if user.MFAEnabled {
if totpCode == "" {
respondError(c, http.StatusBadRequest, "mfa_code_required", "totp code is required")
mfaTicket, err := h.createMFAChallenge(user.ID)
if err != nil {
respondError(c, http.StatusInternalServerError, "mfa_challenge_creation_failed", "failed to create mfa challenge")
return
}
c.JSON(http.StatusOK, gin.H{
"message": "mfa required",
"mfaRequired": true,
"mfa_required": true,
"mfaMethod": "totp",
"mfa_method": "totp",
"mfaTicket": mfaTicket,
"mfa_ticket": mfaTicket,
// Kept for backward compatibility with existing clients.
"mfaToken": mfaTicket,
})
return
}
@ -1087,7 +1106,11 @@ func (h *handler) login(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "login successful",
"token": token,
"access_token": token,
"expiresAt": expiresAt.UTC(),
"expires_in": int64(time.Until(expiresAt).Seconds()),
"mfaRequired": false,
"mfa_required": false,
"user": sanitizeUser(user, nil),
})
return
@ -1104,7 +1127,11 @@ func (h *handler) login(c *gin.Context) {
response := gin.H{
"message": "login successful",
"token": token,
"access_token": token,
"expiresAt": expiresAt.UTC(),
"expires_in": int64(time.Until(expiresAt).Seconds()),
"mfaRequired": false,
"mfa_required": false,
"user": sanitizeUser(user, nil),
}
@ -1119,6 +1146,98 @@ func (h *handler) login(c *gin.Context) {
c.JSON(http.StatusOK, response)
}
func (h *handler) verifyMFALogin(c *gin.Context) {
var req struct {
MFATicket string `json:"mfa_ticket"`
MFAToken string `json:"mfaToken"`
Code string `json:"code"`
TOTPCode string `json:"totpCode"`
Method string `json:"method"`
}
if err := c.ShouldBindJSON(&req); err != nil {
respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload")
return
}
mfaTicket := strings.TrimSpace(req.MFATicket)
if mfaTicket == "" {
mfaTicket = strings.TrimSpace(req.MFAToken)
}
if mfaTicket == "" {
respondError(c, http.StatusBadRequest, "mfa_ticket_required", "mfa ticket is required")
return
}
code := strings.TrimSpace(req.Code)
if code == "" {
code = strings.TrimSpace(req.TOTPCode)
}
if code == "" {
respondError(c, http.StatusBadRequest, "mfa_code_required", "totp code is required")
return
}
method := strings.ToLower(strings.TrimSpace(req.Method))
if method == "" {
method = "totp"
}
if method != "totp" {
respondError(c, http.StatusBadRequest, "unsupported_mfa_method", "unsupported mfa method")
return
}
challenge, ok := h.lookupMFAChallenge(mfaTicket)
if !ok {
respondError(c, http.StatusUnauthorized, "invalid_mfa_ticket", "mfa ticket is invalid or expired")
return
}
user, err := h.store.GetUserByID(c.Request.Context(), challenge.userID)
if err != nil {
respondError(c, http.StatusInternalServerError, "authentication_failed", "failed to authenticate user")
return
}
if !user.MFAEnabled {
respondError(c, http.StatusBadRequest, "mfa_not_enabled", "multi-factor authentication is not enabled")
return
}
valid, err := totp.ValidateCustom(code, user.MFATOTPSecret, time.Now().UTC(), totp.ValidateOpts{
Period: 30,
Skew: 1,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
})
if err != nil {
respondError(c, http.StatusInternalServerError, "invalid_mfa_code", "invalid totp code")
return
}
if !valid {
respondError(c, http.StatusUnauthorized, "invalid_mfa_code", "invalid totp code")
return
}
h.removeMFAChallenge(mfaTicket)
token, expiresAt, err := h.createSession(user.ID)
if err != nil {
respondError(c, http.StatusInternalServerError, "session_creation_failed", "failed to create session")
return
}
h.setSessionCookie(c, token, expiresAt)
c.JSON(http.StatusOK, gin.H{
"message": "login successful",
"token": token,
"access_token": token,
"expiresAt": expiresAt.UTC(),
"expires_in": int64(time.Until(expiresAt).Seconds()),
"mfaRequired": false,
"mfa_required": false,
"user": sanitizeUser(user, nil),
})
}
type tokenRefreshRequest struct {
RefreshToken string `json:"refresh_token"`
}

View File

@ -1,45 +1,196 @@
package api
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"account/internal/store"
"account/internal/xrayconfig"
)
// syncConfig handles POST /api/config/sync requests. The endpoint currently
// verifies that the caller has a valid authenticated session (using the
// xc_session cookie or Authorization header) and returns a placeholder
// response indicating that the desktop sync feature is not yet implemented.
//
// The full implementation is outlined in docs/account-xstream-desktop-integration.md
// and will be wired in subsequent iterations.
type syncConfigAckRequest struct {
Version int64 `json:"version"`
DeviceID string `json:"device_id"`
AppliedAt string `json:"applied_at"`
}
func (h *handler) syncConfigSnapshot(c *gin.Context) {
h.respondSyncConfigSnapshot(c)
}
func (h *handler) syncConfig(c *gin.Context) {
token := extractToken(c.GetHeader("Authorization"))
if token == "" {
if cookie, err := c.Cookie(sessionCookieName); err == nil {
token = strings.TrimSpace(cookie)
}
}
if token == "" {
respondError(c, http.StatusUnauthorized, "session_token_required", "session token is required")
return
}
if _, ok := h.lookupSession(token); !ok {
respondError(c, http.StatusUnauthorized, "invalid_session", "session token is invalid or expired")
return
}
// Backward-compatible endpoint: old clients call POST /api/auth/config/sync.
h.respondSyncConfigSnapshot(c)
}
func (h *handler) respondSyncConfigSnapshot(c *gin.Context) {
user, ok := h.requireAuthenticatedUser(c)
if !ok {
return
}
if h.isReadOnlyAccount(user) {
respondError(c, http.StatusForbidden, "read_only_account", "demo account is read-only")
sinceVersion := int64(0)
if raw := strings.TrimSpace(c.Query("since_version")); raw != "" {
v, err := strconv.ParseInt(raw, 10, 64)
if err != nil || v < 0 {
respondError(c, http.StatusBadRequest, "invalid_since_version", "since_version must be a non-negative integer")
return
}
sinceVersion = v
}
version := deriveSyncVersion(user)
updatedAt := time.Now().UTC()
if !user.UpdatedAt.IsZero() {
updatedAt = user.UpdatedAt.UTC()
}
renderedJSON, digest, warnings, err := h.renderUserXrayConfig(user)
if err != nil {
respondError(c, http.StatusInternalServerError, "config_render_failed", "failed to render xray config")
return
}
respondError(c, http.StatusNotImplemented, "desktop_sync_unavailable", "desktop configuration sync is not yet available")
changed := sinceVersion < version
profiles := []gin.H{}
nodes := []gin.H{}
if changed {
profiles = append(profiles, gin.H{
"id": strings.TrimSpace(user.ID),
"remark": strings.TrimSpace(user.Name),
"address": extractHostFromPublicURL(h.publicURL),
"port": 1443,
"uuid": strings.TrimSpace(user.ProxyUUID),
"flow": xrayconfig.DefaultFlow,
"source": "server",
})
nodes = append(nodes, gin.H{
"id": strings.TrimSpace(user.ID),
"name": strings.TrimSpace(user.Name),
"protocol": "vless",
"transport": "tcp",
"security": "tls",
"address": extractHostFromPublicURL(h.publicURL),
"port": 1443,
"uuid": strings.TrimSpace(user.ProxyUUID),
"flow": xrayconfig.DefaultFlow,
"source": "server",
"updated_at": updatedAt,
})
}
c.JSON(http.StatusOK, gin.H{
"schema_version": 1,
"changed": changed,
"version": version,
"updated_at": updatedAt,
"profiles": profiles,
"nodes": nodes,
"routes": []gin.H{},
"dns": gin.H{
"mode": "secure_tunnel",
"servers": []string{},
},
"meta": gin.H{
"digest": digest,
"warnings": warnings,
},
"rendered_json": renderedJSON,
"digest": digest,
"warnings": warnings,
})
}
func (h *handler) syncConfigAck(c *gin.Context) {
user, ok := h.requireAuthenticatedUser(c)
if !ok {
return
}
var req syncConfigAckRequest
if err := c.ShouldBindJSON(&req); err != nil {
respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload")
return
}
if req.Version <= 0 {
respondError(c, http.StatusBadRequest, "invalid_version", "version must be positive")
return
}
if strings.TrimSpace(req.DeviceID) == "" {
respondError(c, http.StatusBadRequest, "device_id_required", "device_id is required")
return
}
if strings.TrimSpace(req.AppliedAt) == "" {
respondError(c, http.StatusBadRequest, "applied_at_required", "applied_at is required")
return
}
c.JSON(http.StatusOK, gin.H{
"acked": true,
"version": req.Version,
"device_id": strings.TrimSpace(req.DeviceID),
"user_id": strings.TrimSpace(user.ID),
"received_at": time.Now().UTC(),
})
}
func deriveSyncVersion(user *store.User) int64 {
if user == nil {
return time.Now().UTC().Unix()
}
if !user.UpdatedAt.IsZero() {
return user.UpdatedAt.UTC().Unix()
}
if !user.CreatedAt.IsZero() {
return user.CreatedAt.UTC().Unix()
}
return time.Now().UTC().Unix()
}
func (h *handler) renderUserXrayConfig(user *store.User) (string, string, []string, error) {
domain := extractHostFromPublicURL(h.publicURL)
if domain == "" {
domain = "accounts.svc.plus"
}
clientID := strings.TrimSpace(user.ProxyUUID)
if clientID == "" {
clientID = strings.TrimSpace(user.ID)
}
clients := []xrayconfig.Client{{
ID: clientID,
Email: strings.TrimSpace(user.Email),
Flow: xrayconfig.DefaultFlow,
}}
gen := xrayconfig.Generator{
Definition: xrayconfig.TCPDefinition(),
Domain: domain,
}
buf, err := gen.Render(clients)
if err != nil {
return "", "", nil, err
}
sum := sha256.Sum256(buf)
return string(buf), hex.EncodeToString(sum[:]), []string{}, nil
}
func extractHostFromPublicURL(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return ""
}
u, err := url.Parse(trimmed)
if err != nil {
return ""
}
return strings.TrimSpace(u.Hostname())
}

35
internal/utils/crypto.go Normal file
View File

@ -0,0 +1,35 @@
package utils
import (
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
// EncryptSyncData encrypts data using XChaCha20-Poly1305.
// The result is CipherText + MAC. Nonce is 24 bytes.
func EncryptSyncData(secret, nonce, plaintext []byte) ([]byte, error) {
aead, err := chacha20poly1305.NewX(secret)
if err != nil {
return nil, fmt.Errorf("create aead: %w", err)
}
// Encrypt appends the MAC to the ciphertext.
return aead.Seal(nil, nonce, plaintext, nil), nil
}
// DecryptSyncData decrypts data using XChaCha20-Poly1305.
// The data should be CipherText + MAC. Nonce is 24 bytes.
func DecryptSyncData(secret, nonce, data []byte) ([]byte, error) {
aead, err := chacha20poly1305.NewX(secret)
if err != nil {
return nil, fmt.Errorf("create aead: %w", err)
}
plaintext, err := aead.Open(nil, nonce, data, nil)
if err != nil {
return nil, fmt.Errorf("decrypt: %w", err)
}
return plaintext, nil
}