feat: Add multi-factor authentication login flow and config synchronization endpoints.
This commit is contained in:
parent
5532fc7e52
commit
1c32d2f01b
137
api/api.go
137
api/api.go
@ -267,6 +267,7 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
|
||||
authGroup.POST("/register/send", h.sendEmailVerification)
|
||||
|
||||
authGroup.POST("/login", h.login)
|
||||
authGroup.POST("/mfa/verify", h.verifyMFALogin)
|
||||
|
||||
// Token exchange endpoint - converts public token to access/refresh tokens
|
||||
authGroup.POST("/token/exchange", h.exchangeToken)
|
||||
@ -277,8 +278,11 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
|
||||
|
||||
// Token refresh endpoint - generates new access token using refresh token
|
||||
authGroup.POST("/token/refresh", h.refreshToken)
|
||||
authGroup.POST("/refresh", h.refreshToken)
|
||||
|
||||
authGroup.GET("/mfa/status", h.mfaStatus)
|
||||
authGroup.GET("/sync/config", h.syncConfigSnapshot)
|
||||
authGroup.POST("/sync/ack", h.syncConfigAck)
|
||||
|
||||
// Sandbox binding read endpoint.
|
||||
// Used by the Console Guest/Demo experience. Must be readable either via a
|
||||
@ -1057,7 +1061,22 @@ func (h *handler) login(c *gin.Context) {
|
||||
|
||||
if user.MFAEnabled {
|
||||
if totpCode == "" {
|
||||
respondError(c, http.StatusBadRequest, "mfa_code_required", "totp code is required")
|
||||
mfaTicket, err := h.createMFAChallenge(user.ID)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "mfa_challenge_creation_failed", "failed to create mfa challenge")
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "mfa required",
|
||||
"mfaRequired": true,
|
||||
"mfa_required": true,
|
||||
"mfaMethod": "totp",
|
||||
"mfa_method": "totp",
|
||||
"mfaTicket": mfaTicket,
|
||||
"mfa_ticket": mfaTicket,
|
||||
// Kept for backward compatibility with existing clients.
|
||||
"mfaToken": mfaTicket,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@ -1085,10 +1104,14 @@ func (h *handler) login(c *gin.Context) {
|
||||
h.setSessionCookie(c, token, expiresAt)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "login successful",
|
||||
"token": token,
|
||||
"expiresAt": expiresAt.UTC(),
|
||||
"user": sanitizeUser(user, nil),
|
||||
"message": "login successful",
|
||||
"token": token,
|
||||
"access_token": token,
|
||||
"expiresAt": expiresAt.UTC(),
|
||||
"expires_in": int64(time.Until(expiresAt).Seconds()),
|
||||
"mfaRequired": false,
|
||||
"mfa_required": false,
|
||||
"user": sanitizeUser(user, nil),
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -1102,10 +1125,14 @@ func (h *handler) login(c *gin.Context) {
|
||||
h.setSessionCookie(c, token, expiresAt)
|
||||
|
||||
response := gin.H{
|
||||
"message": "login successful",
|
||||
"token": token,
|
||||
"expiresAt": expiresAt.UTC(),
|
||||
"user": sanitizeUser(user, nil),
|
||||
"message": "login successful",
|
||||
"token": token,
|
||||
"access_token": token,
|
||||
"expiresAt": expiresAt.UTC(),
|
||||
"expires_in": int64(time.Until(expiresAt).Seconds()),
|
||||
"mfaRequired": false,
|
||||
"mfa_required": false,
|
||||
"user": sanitizeUser(user, nil),
|
||||
}
|
||||
|
||||
if !h.isReadOnlyAccount(user) {
|
||||
@ -1119,6 +1146,98 @@ func (h *handler) login(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *handler) verifyMFALogin(c *gin.Context) {
|
||||
var req struct {
|
||||
MFATicket string `json:"mfa_ticket"`
|
||||
MFAToken string `json:"mfaToken"`
|
||||
Code string `json:"code"`
|
||||
TOTPCode string `json:"totpCode"`
|
||||
Method string `json:"method"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload")
|
||||
return
|
||||
}
|
||||
|
||||
mfaTicket := strings.TrimSpace(req.MFATicket)
|
||||
if mfaTicket == "" {
|
||||
mfaTicket = strings.TrimSpace(req.MFAToken)
|
||||
}
|
||||
if mfaTicket == "" {
|
||||
respondError(c, http.StatusBadRequest, "mfa_ticket_required", "mfa ticket is required")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(req.Code)
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(req.TOTPCode)
|
||||
}
|
||||
if code == "" {
|
||||
respondError(c, http.StatusBadRequest, "mfa_code_required", "totp code is required")
|
||||
return
|
||||
}
|
||||
|
||||
method := strings.ToLower(strings.TrimSpace(req.Method))
|
||||
if method == "" {
|
||||
method = "totp"
|
||||
}
|
||||
if method != "totp" {
|
||||
respondError(c, http.StatusBadRequest, "unsupported_mfa_method", "unsupported mfa method")
|
||||
return
|
||||
}
|
||||
|
||||
challenge, ok := h.lookupMFAChallenge(mfaTicket)
|
||||
if !ok {
|
||||
respondError(c, http.StatusUnauthorized, "invalid_mfa_ticket", "mfa ticket is invalid or expired")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.store.GetUserByID(c.Request.Context(), challenge.userID)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "authentication_failed", "failed to authenticate user")
|
||||
return
|
||||
}
|
||||
if !user.MFAEnabled {
|
||||
respondError(c, http.StatusBadRequest, "mfa_not_enabled", "multi-factor authentication is not enabled")
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := totp.ValidateCustom(code, user.MFATOTPSecret, time.Now().UTC(), totp.ValidateOpts{
|
||||
Period: 30,
|
||||
Skew: 1,
|
||||
Digits: otp.DigitsSix,
|
||||
Algorithm: otp.AlgorithmSHA1,
|
||||
})
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "invalid_mfa_code", "invalid totp code")
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
respondError(c, http.StatusUnauthorized, "invalid_mfa_code", "invalid totp code")
|
||||
return
|
||||
}
|
||||
|
||||
h.removeMFAChallenge(mfaTicket)
|
||||
|
||||
token, expiresAt, err := h.createSession(user.ID)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "session_creation_failed", "failed to create session")
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookie(c, token, expiresAt)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "login successful",
|
||||
"token": token,
|
||||
"access_token": token,
|
||||
"expiresAt": expiresAt.UTC(),
|
||||
"expires_in": int64(time.Until(expiresAt).Seconds()),
|
||||
"mfaRequired": false,
|
||||
"mfa_required": false,
|
||||
"user": sanitizeUser(user, nil),
|
||||
})
|
||||
}
|
||||
|
||||
type tokenRefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
@ -1,45 +1,196 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"account/internal/store"
|
||||
"account/internal/xrayconfig"
|
||||
)
|
||||
|
||||
// syncConfig handles POST /api/config/sync requests. The endpoint currently
|
||||
// verifies that the caller has a valid authenticated session (using the
|
||||
// xc_session cookie or Authorization header) and returns a placeholder
|
||||
// response indicating that the desktop sync feature is not yet implemented.
|
||||
//
|
||||
// The full implementation is outlined in docs/account-xstream-desktop-integration.md
|
||||
// and will be wired in subsequent iterations.
|
||||
type syncConfigAckRequest struct {
|
||||
Version int64 `json:"version"`
|
||||
DeviceID string `json:"device_id"`
|
||||
AppliedAt string `json:"applied_at"`
|
||||
}
|
||||
|
||||
func (h *handler) syncConfigSnapshot(c *gin.Context) {
|
||||
h.respondSyncConfigSnapshot(c)
|
||||
}
|
||||
|
||||
func (h *handler) syncConfig(c *gin.Context) {
|
||||
token := extractToken(c.GetHeader("Authorization"))
|
||||
if token == "" {
|
||||
if cookie, err := c.Cookie(sessionCookieName); err == nil {
|
||||
token = strings.TrimSpace(cookie)
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
respondError(c, http.StatusUnauthorized, "session_token_required", "session token is required")
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.lookupSession(token); !ok {
|
||||
respondError(c, http.StatusUnauthorized, "invalid_session", "session token is invalid or expired")
|
||||
return
|
||||
}
|
||||
// Backward-compatible endpoint: old clients call POST /api/auth/config/sync.
|
||||
h.respondSyncConfigSnapshot(c)
|
||||
}
|
||||
|
||||
func (h *handler) respondSyncConfigSnapshot(c *gin.Context) {
|
||||
user, ok := h.requireAuthenticatedUser(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if h.isReadOnlyAccount(user) {
|
||||
respondError(c, http.StatusForbidden, "read_only_account", "demo account is read-only")
|
||||
|
||||
sinceVersion := int64(0)
|
||||
if raw := strings.TrimSpace(c.Query("since_version")); raw != "" {
|
||||
v, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil || v < 0 {
|
||||
respondError(c, http.StatusBadRequest, "invalid_since_version", "since_version must be a non-negative integer")
|
||||
return
|
||||
}
|
||||
sinceVersion = v
|
||||
}
|
||||
|
||||
version := deriveSyncVersion(user)
|
||||
updatedAt := time.Now().UTC()
|
||||
if !user.UpdatedAt.IsZero() {
|
||||
updatedAt = user.UpdatedAt.UTC()
|
||||
}
|
||||
|
||||
renderedJSON, digest, warnings, err := h.renderUserXrayConfig(user)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "config_render_failed", "failed to render xray config")
|
||||
return
|
||||
}
|
||||
|
||||
respondError(c, http.StatusNotImplemented, "desktop_sync_unavailable", "desktop configuration sync is not yet available")
|
||||
changed := sinceVersion < version
|
||||
profiles := []gin.H{}
|
||||
nodes := []gin.H{}
|
||||
if changed {
|
||||
profiles = append(profiles, gin.H{
|
||||
"id": strings.TrimSpace(user.ID),
|
||||
"remark": strings.TrimSpace(user.Name),
|
||||
"address": extractHostFromPublicURL(h.publicURL),
|
||||
"port": 1443,
|
||||
"uuid": strings.TrimSpace(user.ProxyUUID),
|
||||
"flow": xrayconfig.DefaultFlow,
|
||||
"source": "server",
|
||||
})
|
||||
nodes = append(nodes, gin.H{
|
||||
"id": strings.TrimSpace(user.ID),
|
||||
"name": strings.TrimSpace(user.Name),
|
||||
"protocol": "vless",
|
||||
"transport": "tcp",
|
||||
"security": "tls",
|
||||
"address": extractHostFromPublicURL(h.publicURL),
|
||||
"port": 1443,
|
||||
"uuid": strings.TrimSpace(user.ProxyUUID),
|
||||
"flow": xrayconfig.DefaultFlow,
|
||||
"source": "server",
|
||||
"updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"schema_version": 1,
|
||||
"changed": changed,
|
||||
"version": version,
|
||||
"updated_at": updatedAt,
|
||||
"profiles": profiles,
|
||||
"nodes": nodes,
|
||||
"routes": []gin.H{},
|
||||
"dns": gin.H{
|
||||
"mode": "secure_tunnel",
|
||||
"servers": []string{},
|
||||
},
|
||||
"meta": gin.H{
|
||||
"digest": digest,
|
||||
"warnings": warnings,
|
||||
},
|
||||
"rendered_json": renderedJSON,
|
||||
"digest": digest,
|
||||
"warnings": warnings,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *handler) syncConfigAck(c *gin.Context) {
|
||||
user, ok := h.requireAuthenticatedUser(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req syncConfigAckRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Version <= 0 {
|
||||
respondError(c, http.StatusBadRequest, "invalid_version", "version must be positive")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.DeviceID) == "" {
|
||||
respondError(c, http.StatusBadRequest, "device_id_required", "device_id is required")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.AppliedAt) == "" {
|
||||
respondError(c, http.StatusBadRequest, "applied_at_required", "applied_at is required")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"acked": true,
|
||||
"version": req.Version,
|
||||
"device_id": strings.TrimSpace(req.DeviceID),
|
||||
"user_id": strings.TrimSpace(user.ID),
|
||||
"received_at": time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
func deriveSyncVersion(user *store.User) int64 {
|
||||
if user == nil {
|
||||
return time.Now().UTC().Unix()
|
||||
}
|
||||
if !user.UpdatedAt.IsZero() {
|
||||
return user.UpdatedAt.UTC().Unix()
|
||||
}
|
||||
if !user.CreatedAt.IsZero() {
|
||||
return user.CreatedAt.UTC().Unix()
|
||||
}
|
||||
return time.Now().UTC().Unix()
|
||||
}
|
||||
|
||||
func (h *handler) renderUserXrayConfig(user *store.User) (string, string, []string, error) {
|
||||
domain := extractHostFromPublicURL(h.publicURL)
|
||||
if domain == "" {
|
||||
domain = "accounts.svc.plus"
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(user.ProxyUUID)
|
||||
if clientID == "" {
|
||||
clientID = strings.TrimSpace(user.ID)
|
||||
}
|
||||
clients := []xrayconfig.Client{{
|
||||
ID: clientID,
|
||||
Email: strings.TrimSpace(user.Email),
|
||||
Flow: xrayconfig.DefaultFlow,
|
||||
}}
|
||||
|
||||
gen := xrayconfig.Generator{
|
||||
Definition: xrayconfig.TCPDefinition(),
|
||||
Domain: domain,
|
||||
}
|
||||
buf, err := gen.Render(clients)
|
||||
if err != nil {
|
||||
return "", "", nil, err
|
||||
}
|
||||
sum := sha256.Sum256(buf)
|
||||
return string(buf), hex.EncodeToString(sum[:]), []string{}, nil
|
||||
}
|
||||
|
||||
func extractHostFromPublicURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(u.Hostname())
|
||||
}
|
||||
|
||||
35
internal/utils/crypto.go
Normal file
35
internal/utils/crypto.go
Normal file
@ -0,0 +1,35 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// EncryptSyncData encrypts data using XChaCha20-Poly1305.
|
||||
// The result is CipherText + MAC. Nonce is 24 bytes.
|
||||
func EncryptSyncData(secret, nonce, plaintext []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create aead: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt appends the MAC to the ciphertext.
|
||||
return aead.Seal(nil, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
// DecryptSyncData decrypts data using XChaCha20-Poly1305.
|
||||
// The data should be CipherText + MAC. Nonce is 24 bytes.
|
||||
func DecryptSyncData(secret, nonce, data []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.NewX(secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create aead: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := aead.Open(nil, nonce, data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user