fix(auth): replace public token exchange with one-time code

This commit is contained in:
Haitao Pan 2026-03-17 08:51:01 +08:00
parent 794f86d17d
commit cc684f7c2a
5 changed files with 239 additions and 50 deletions

View File

@ -38,6 +38,7 @@ const defaultEmailVerificationTTL = 10 * time.Minute
const defaultPasswordResetTTL = 30 * time.Minute
const maxMFAVerificationAttempts = 5
const defaultMFALockoutDuration = 5 * time.Minute
const defaultOAuthExchangeCodeTTL = 5 * time.Minute
const sessionCookieName = "xc_session"
@ -46,6 +47,12 @@ type session struct {
expiresAt time.Time
}
type oauthExchangeCode struct {
sessionToken string
sessionExpiresAt time.Time
expiresAt time.Time
}
type handler struct {
store store.Store
mu sync.RWMutex
@ -64,6 +71,9 @@ type handler struct {
resetTTL time.Duration
passwordResets map[string]passwordReset
resetMu sync.RWMutex
oauthExchangeCodes map[string]oauthExchangeCode
oauthExchangeMu sync.RWMutex
oauthExchangeTTL time.Duration
metricsProvider service.UserMetricsProvider
agentStatusReader agentStatusReader
tokenService *auth.TokenService
@ -271,6 +281,8 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
registrationVerifications: make(map[string]registrationVerification),
resetTTL: defaultPasswordResetTTL,
passwordResets: make(map[string]passwordReset),
oauthExchangeCodes: make(map[string]oauthExchangeCode),
oauthExchangeTTL: defaultOAuthExchangeCodeTTL,
}
for _, opt := range opts {
@ -294,7 +306,7 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
authGroup.POST("/login", h.login)
authGroup.POST("/mfa/verify", h.verifyMFALogin)
// Token exchange endpoint - converts public token to access/refresh tokens
// Token exchange endpoint - converts one-time OAuth exchange code to a real session token.
authGroup.POST("/token/exchange", h.exchangeToken)
// OAuth2 routes
@ -1302,55 +1314,46 @@ func (h *handler) refreshToken(c *gin.Context) {
}
type tokenExchangeRequest struct {
PublicToken string `json:"public_token"`
UserID string `json:"user_id"`
Email string `json:"email"`
Roles string `json:"roles"`
ExchangeCode string `json:"exchange_code"`
}
func (h *handler) exchangeToken(c *gin.Context) {
if h.tokenService == nil {
respondError(c, http.StatusServiceUnavailable, "token_service_unavailable", "token service is not configured")
return
}
var req tokenExchangeRequest
if err := c.ShouldBindJSON(&req); err != nil {
respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload")
return
}
// Validate public token
if !h.tokenService.ValidatePublicToken(req.PublicToken) {
respondError(c, http.StatusUnauthorized, "invalid_public_token", "invalid public token")
sessionToken, _, ok := h.consumeOAuthExchangeCode(req.ExchangeCode)
if !ok {
respondError(c, http.StatusUnauthorized, "invalid_exchange_code", "invalid or expired exchange code")
return
}
// Parse roles
var roles []string
if req.Roles != "" {
roles = strings.Split(req.Roles, ",")
for i := range roles {
roles[i] = strings.TrimSpace(roles[i])
}
} else {
roles = []string{"user"}
sess, ok := h.lookupSession(sessionToken)
if !ok {
respondError(c, http.StatusUnauthorized, "invalid_exchange_code", "exchange session is invalid or expired")
return
}
// Generate token pair
tokenPair, err := h.tokenService.GenerateTokenPair(req.UserID, req.Email, roles)
user, err := h.store.GetUserByID(c.Request.Context(), sess.userID)
if err != nil {
slog.Error("failed to generate token pair", "err", err)
respondError(c, http.StatusInternalServerError, "token_generation_failed", "failed to generate tokens")
respondError(c, http.StatusInternalServerError, "session_user_lookup_failed", "failed to load session user")
return
}
expiresIn := int64(time.Until(sess.expiresAt).Seconds())
if expiresIn < 0 {
expiresIn = 0
}
c.JSON(http.StatusOK, gin.H{
"public_token": tokenPair.PublicToken,
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"token_type": tokenPair.TokenType,
"expires_in": tokenPair.ExpiresIn,
"token": sessionToken,
"access_token": sessionToken,
"token_type": "Bearer",
"expiresAt": sess.expiresAt.UTC(),
"expires_in": expiresIn,
"user": sanitizeUser(user, nil),
})
}
@ -1525,6 +1528,51 @@ func (h *handler) removeSession(token string) {
h.store.DeleteSession(context.Background(), token)
}
func (h *handler) issueOAuthExchangeCode(sessionToken string, sessionExpiresAt time.Time) (string, time.Time, error) {
code, err := h.newRandomToken()
if err != nil {
return "", time.Time{}, err
}
expiresAt := time.Now().Add(h.oauthExchangeTTL)
if h.oauthExchangeTTL <= 0 {
expiresAt = time.Now().Add(defaultOAuthExchangeCodeTTL)
}
if !sessionExpiresAt.IsZero() && sessionExpiresAt.Before(expiresAt) {
expiresAt = sessionExpiresAt
}
h.oauthExchangeMu.Lock()
defer h.oauthExchangeMu.Unlock()
h.oauthExchangeCodes[code] = oauthExchangeCode{
sessionToken: sessionToken,
sessionExpiresAt: sessionExpiresAt,
expiresAt: expiresAt,
}
return code, expiresAt, nil
}
func (h *handler) consumeOAuthExchangeCode(code string) (string, time.Time, bool) {
normalized := strings.TrimSpace(code)
if normalized == "" {
return "", time.Time{}, false
}
h.oauthExchangeMu.Lock()
defer h.oauthExchangeMu.Unlock()
record, ok := h.oauthExchangeCodes[normalized]
if !ok {
return "", time.Time{}, false
}
delete(h.oauthExchangeCodes, normalized)
if time.Now().After(record.expiresAt) {
return "", time.Time{}, false
}
return record.sessionToken, record.sessionExpiresAt, true
}
func (h *handler) newRandomToken() (string, error) {
buffer := make([]byte, 32)
if _, err := rand.Read(buffer); err != nil {
@ -2660,15 +2708,13 @@ func (h *handler) oauthCallback(c *gin.Context) {
return
}
profile, err := provider.FetchProfile(c.Request.Context(), nil) // Exchange is handled inside if we want, or here.
// Let's refine the interface to handle token exchange too.
token, err := provider.Exchange(c.Request.Context(), code)
if err != nil {
respondError(c, http.StatusInternalServerError, "oauth_exchange_failed", "failed to exchange oauth code")
return
}
profile, err = provider.FetchProfile(c.Request.Context(), token)
profile, err := provider.FetchProfile(c.Request.Context(), token)
if err != nil {
respondError(c, http.StatusInternalServerError, "fetch_profile_failed", "failed to fetch user profile")
return
@ -2745,25 +2791,25 @@ func (h *handler) oauthCallback(c *gin.Context) {
}
}
// Create session or generate public token for frontend redirect
if h.tokenService == nil {
respondError(c, http.StatusServiceUnavailable, "token_service_unavailable", "token service not configured")
sessionToken, sessionExpiresAt, err := h.createSession(user.ID)
if err != nil {
respondError(c, http.StatusInternalServerError, "session_creation_failed", "failed to create session")
return
}
publicToken := h.tokenService.GeneratePublicToken(user.ID, user.Email, []string{user.Role})
exchangeCode, _, err := h.issueOAuthExchangeCode(sessionToken, sessionExpiresAt)
if err != nil {
respondError(c, http.StatusInternalServerError, "exchange_code_creation_failed", "failed to issue exchange code")
return
}
// Redirect back to frontend with public token
frontendURL := h.oauthFrontendURL
if frontendURL == "" {
frontendURL = "http://localhost:3000"
}
targetURL := fmt.Sprintf("%s/login?public_token=%s&userId=%s&email=%s&role=%s",
targetURL := fmt.Sprintf("%s/login?exchange_code=%s",
strings.TrimSuffix(frontendURL, "/"),
publicToken,
user.ID,
url.QueryEscape(user.Email),
user.Role)
url.QueryEscape(exchangeCode))
c.Redirect(http.StatusTemporaryRedirect, targetURL)
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
@ -17,8 +18,10 @@ import (
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
"account/internal/agentserver"
"account/internal/auth"
"account/internal/service"
"account/internal/store"
)
@ -58,6 +61,38 @@ func (s *stubMetricsProvider) Compute(context.Context) (service.UserMetrics, err
return s.metrics, nil
}
type stubOAuthProvider struct {
profile *auth.OAuthUserProfile
exchangeErr error
profileErr error
}
func (s *stubOAuthProvider) AuthCodeURL(state string) string {
return "https://oauth.example.test/authorize?state=" + state
}
func (s *stubOAuthProvider) Exchange(context.Context, string) (*oauth2.Token, error) {
if s.exchangeErr != nil {
return nil, s.exchangeErr
}
return &oauth2.Token{AccessToken: "oauth-token", TokenType: "Bearer"}, nil
}
func (s *stubOAuthProvider) FetchProfile(context.Context, *oauth2.Token) (*auth.OAuthUserProfile, error) {
if s.profileErr != nil {
return nil, s.profileErr
}
if s.profile == nil {
return nil, errors.New("missing oauth profile")
}
cloned := *s.profile
return &cloned, nil
}
func (s *stubOAuthProvider) Name() string {
return "github"
}
type testEmailSender struct {
mu sync.Mutex
messages []capturedEmail
@ -332,6 +367,113 @@ func TestRegisterEndpoint(t *testing.T) {
}
}
func TestOAuthCallbackIssuesOneTimeExchangeCode(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
profile := &auth.OAuthUserProfile{
ID: "oauth-user-1",
Email: "oauth-user@example.com",
Name: "OAuth User",
Verified: true,
}
RegisterRoutes(
router,
WithStore(store.NewMemoryStore()),
WithOAuthProviders(map[string]auth.OAuthProvider{
"github": &stubOAuthProvider{profile: profile},
}),
WithOAuthFrontendURL("https://console.svc.plus"),
)
callbackReq := httptest.NewRequest(http.MethodGet, "/api/auth/oauth/callback/github?code=test-oauth-code", nil)
callbackRec := httptest.NewRecorder()
router.ServeHTTP(callbackRec, callbackReq)
if callbackRec.Code != http.StatusTemporaryRedirect {
t.Fatalf("expected oauth callback redirect, got %d: %s", callbackRec.Code, callbackRec.Body.String())
}
location := callbackRec.Header().Get("Location")
if location == "" {
t.Fatalf("expected oauth callback to set redirect location")
}
redirectURL, err := url.Parse(location)
if err != nil {
t.Fatalf("parse redirect url: %v", err)
}
if redirectURL.Query().Get("public_token") != "" {
t.Fatalf("expected public_token to be removed from oauth redirect, got %q", location)
}
if redirectURL.Query().Get("userId") != "" || redirectURL.Query().Get("role") != "" {
t.Fatalf("expected redirect to avoid caller-asserted identity fields, got %q", location)
}
exchangeCode := redirectURL.Query().Get("exchange_code")
if exchangeCode == "" {
t.Fatalf("expected oauth redirect to include exchange_code, got %q", location)
}
exchangeBody, err := json.Marshal(map[string]string{"exchange_code": exchangeCode})
if err != nil {
t.Fatalf("marshal exchange payload: %v", err)
}
exchangeReq := httptest.NewRequest(http.MethodPost, "/api/auth/token/exchange", bytes.NewReader(exchangeBody))
exchangeReq.Header.Set("Content-Type", "application/json")
exchangeRec := httptest.NewRecorder()
router.ServeHTTP(exchangeRec, exchangeReq)
if exchangeRec.Code != http.StatusOK {
t.Fatalf("expected successful token exchange, got %d: %s", exchangeRec.Code, exchangeRec.Body.String())
}
var exchangeResp struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
User map[string]interface{} `json:"user"`
}
if err := json.Unmarshal(exchangeRec.Body.Bytes(), &exchangeResp); err != nil {
t.Fatalf("decode exchange response: %v", err)
}
if exchangeResp.Token == "" {
t.Fatalf("expected exchanged session token")
}
if exchangeResp.AccessToken != exchangeResp.Token {
t.Fatalf("expected access_token alias to match session token")
}
if exchangeResp.User == nil {
t.Fatalf("expected exchange response user payload")
}
if got := exchangeResp.User["email"]; got != profile.Email {
t.Fatalf("expected exchange response email %q, got %#v", profile.Email, got)
}
sessionReq := httptest.NewRequest(http.MethodGet, "/api/auth/session", nil)
sessionReq.Header.Set("Authorization", "Bearer "+exchangeResp.Token)
sessionRec := httptest.NewRecorder()
router.ServeHTTP(sessionRec, sessionReq)
if sessionRec.Code != http.StatusOK {
t.Fatalf("expected exchanged session token to resolve session, got %d: %s", sessionRec.Code, sessionRec.Body.String())
}
replayReq := httptest.NewRequest(http.MethodPost, "/api/auth/token/exchange", bytes.NewReader(exchangeBody))
replayReq.Header.Set("Content-Type", "application/json")
replayRec := httptest.NewRecorder()
router.ServeHTTP(replayRec, replayReq)
if replayRec.Code != http.StatusUnauthorized {
t.Fatalf("expected single-use exchange code replay to fail, got %d: %s", replayRec.Code, replayRec.Body.String())
}
var replayResp apiResponse
if err := json.Unmarshal(replayRec.Body.Bytes(), &replayResp); err != nil {
t.Fatalf("decode replay response: %v", err)
}
if replayResp.Error != "invalid_exchange_code" {
t.Fatalf("expected invalid_exchange_code on replay, got %#v", replayResp.Error)
}
}
func TestResendVerificationEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -29,12 +29,13 @@
## JWT 令牌服务(可选)
启用 `auth.enable: true` 后提供:
- `POST /api/auth/token/exchange`:使用 `public_token` 换取 access/refresh
- `POST /api/auth/token/exchange`:使用 OAuth 回调签发的一次性 `exchange_code` 换取真实会话 token
- `POST /api/auth/token/refresh`:刷新 access token
注意事项:
- `token/exchange` 需要调用方提供 `user_id/email/roles`
- 当前版本多数保护路由仍使用会话 tokenJWT 仅作为中间件校验存在
- 若开启 JWT中间件要求 `Authorization: Bearer <access-token>`,但业务逻辑仍可能需要会话 token
- `token/exchange` 只接受后端签发的一次性 `exchange_code`,不再接受调用方自报 `user_id/email/roles`
- `token/exchange` 返回的 `token`/`access_token` 是同一个真实会话 token供前端 BFF 写入 `xc_session`
- 当前版本多数保护路由仍使用会话 tokenJWT refresh 仅保留给 `token/refresh`
- 若开启 JWT 中间件,业务逻辑仍可能需要会话 token因此控制面应优先走会话模型
建议:若主要使用会话认证,请将 `auth.enable` 设为 `false`

View File

@ -10,7 +10,7 @@
- `POST /api/auth/register/send`:发送邮箱验证码
- `POST /api/auth/register/verify`:验证邮箱验证码
- `POST /api/auth/login`:登录
- `POST /api/auth/token/exchange`public token 换取 access/refresh
- `POST /api/auth/token/exchange`一次性 OAuth `exchange_code` 换取真实会话 token
- `POST /api/auth/token/refresh`:刷新 access token
### 需要会话(或受保护)

View File

@ -16,7 +16,7 @@
- `invalid_session` / `session_token_required`
- `mfa_code_required` / `invalid_mfa_code`
- `token_service_unavailable`
- `invalid_public_token` / `invalid_refresh_token`
- `invalid_exchange_code` / `invalid_refresh_token`
- `subscription_not_found`
- `agent_status_unavailable`