fix(auth): replace public token exchange with one-time code
This commit is contained in:
parent
794f86d17d
commit
cc684f7c2a
134
api/api.go
134
api/api.go
@ -38,6 +38,7 @@ const defaultEmailVerificationTTL = 10 * time.Minute
|
||||
const defaultPasswordResetTTL = 30 * time.Minute
|
||||
const maxMFAVerificationAttempts = 5
|
||||
const defaultMFALockoutDuration = 5 * time.Minute
|
||||
const defaultOAuthExchangeCodeTTL = 5 * time.Minute
|
||||
|
||||
const sessionCookieName = "xc_session"
|
||||
|
||||
@ -46,6 +47,12 @@ type session struct {
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthExchangeCode struct {
|
||||
sessionToken string
|
||||
sessionExpiresAt time.Time
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
store store.Store
|
||||
mu sync.RWMutex
|
||||
@ -64,6 +71,9 @@ type handler struct {
|
||||
resetTTL time.Duration
|
||||
passwordResets map[string]passwordReset
|
||||
resetMu sync.RWMutex
|
||||
oauthExchangeCodes map[string]oauthExchangeCode
|
||||
oauthExchangeMu sync.RWMutex
|
||||
oauthExchangeTTL time.Duration
|
||||
metricsProvider service.UserMetricsProvider
|
||||
agentStatusReader agentStatusReader
|
||||
tokenService *auth.TokenService
|
||||
@ -271,6 +281,8 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
|
||||
registrationVerifications: make(map[string]registrationVerification),
|
||||
resetTTL: defaultPasswordResetTTL,
|
||||
passwordResets: make(map[string]passwordReset),
|
||||
oauthExchangeCodes: make(map[string]oauthExchangeCode),
|
||||
oauthExchangeTTL: defaultOAuthExchangeCodeTTL,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@ -294,7 +306,7 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
|
||||
authGroup.POST("/login", h.login)
|
||||
authGroup.POST("/mfa/verify", h.verifyMFALogin)
|
||||
|
||||
// Token exchange endpoint - converts public token to access/refresh tokens
|
||||
// Token exchange endpoint - converts one-time OAuth exchange code to a real session token.
|
||||
authGroup.POST("/token/exchange", h.exchangeToken)
|
||||
|
||||
// OAuth2 routes
|
||||
@ -1302,55 +1314,46 @@ func (h *handler) refreshToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
type tokenExchangeRequest struct {
|
||||
PublicToken string `json:"public_token"`
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Roles string `json:"roles"`
|
||||
ExchangeCode string `json:"exchange_code"`
|
||||
}
|
||||
|
||||
func (h *handler) exchangeToken(c *gin.Context) {
|
||||
if h.tokenService == nil {
|
||||
respondError(c, http.StatusServiceUnavailable, "token_service_unavailable", "token service is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
var req tokenExchangeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate public token
|
||||
if !h.tokenService.ValidatePublicToken(req.PublicToken) {
|
||||
respondError(c, http.StatusUnauthorized, "invalid_public_token", "invalid public token")
|
||||
sessionToken, _, ok := h.consumeOAuthExchangeCode(req.ExchangeCode)
|
||||
if !ok {
|
||||
respondError(c, http.StatusUnauthorized, "invalid_exchange_code", "invalid or expired exchange code")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse roles
|
||||
var roles []string
|
||||
if req.Roles != "" {
|
||||
roles = strings.Split(req.Roles, ",")
|
||||
for i := range roles {
|
||||
roles[i] = strings.TrimSpace(roles[i])
|
||||
}
|
||||
} else {
|
||||
roles = []string{"user"}
|
||||
sess, ok := h.lookupSession(sessionToken)
|
||||
if !ok {
|
||||
respondError(c, http.StatusUnauthorized, "invalid_exchange_code", "exchange session is invalid or expired")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate token pair
|
||||
tokenPair, err := h.tokenService.GenerateTokenPair(req.UserID, req.Email, roles)
|
||||
user, err := h.store.GetUserByID(c.Request.Context(), sess.userID)
|
||||
if err != nil {
|
||||
slog.Error("failed to generate token pair", "err", err)
|
||||
respondError(c, http.StatusInternalServerError, "token_generation_failed", "failed to generate tokens")
|
||||
respondError(c, http.StatusInternalServerError, "session_user_lookup_failed", "failed to load session user")
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := int64(time.Until(sess.expiresAt).Seconds())
|
||||
if expiresIn < 0 {
|
||||
expiresIn = 0
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"public_token": tokenPair.PublicToken,
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"token_type": tokenPair.TokenType,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token": sessionToken,
|
||||
"access_token": sessionToken,
|
||||
"token_type": "Bearer",
|
||||
"expiresAt": sess.expiresAt.UTC(),
|
||||
"expires_in": expiresIn,
|
||||
"user": sanitizeUser(user, nil),
|
||||
})
|
||||
}
|
||||
|
||||
@ -1525,6 +1528,51 @@ func (h *handler) removeSession(token string) {
|
||||
h.store.DeleteSession(context.Background(), token)
|
||||
}
|
||||
|
||||
func (h *handler) issueOAuthExchangeCode(sessionToken string, sessionExpiresAt time.Time) (string, time.Time, error) {
|
||||
code, err := h.newRandomToken()
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(h.oauthExchangeTTL)
|
||||
if h.oauthExchangeTTL <= 0 {
|
||||
expiresAt = time.Now().Add(defaultOAuthExchangeCodeTTL)
|
||||
}
|
||||
if !sessionExpiresAt.IsZero() && sessionExpiresAt.Before(expiresAt) {
|
||||
expiresAt = sessionExpiresAt
|
||||
}
|
||||
|
||||
h.oauthExchangeMu.Lock()
|
||||
defer h.oauthExchangeMu.Unlock()
|
||||
h.oauthExchangeCodes[code] = oauthExchangeCode{
|
||||
sessionToken: sessionToken,
|
||||
sessionExpiresAt: sessionExpiresAt,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
return code, expiresAt, nil
|
||||
}
|
||||
|
||||
func (h *handler) consumeOAuthExchangeCode(code string) (string, time.Time, bool) {
|
||||
normalized := strings.TrimSpace(code)
|
||||
if normalized == "" {
|
||||
return "", time.Time{}, false
|
||||
}
|
||||
|
||||
h.oauthExchangeMu.Lock()
|
||||
defer h.oauthExchangeMu.Unlock()
|
||||
|
||||
record, ok := h.oauthExchangeCodes[normalized]
|
||||
if !ok {
|
||||
return "", time.Time{}, false
|
||||
}
|
||||
delete(h.oauthExchangeCodes, normalized)
|
||||
|
||||
if time.Now().After(record.expiresAt) {
|
||||
return "", time.Time{}, false
|
||||
}
|
||||
return record.sessionToken, record.sessionExpiresAt, true
|
||||
}
|
||||
|
||||
func (h *handler) newRandomToken() (string, error) {
|
||||
buffer := make([]byte, 32)
|
||||
if _, err := rand.Read(buffer); err != nil {
|
||||
@ -2660,15 +2708,13 @@ func (h *handler) oauthCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := provider.FetchProfile(c.Request.Context(), nil) // Exchange is handled inside if we want, or here.
|
||||
// Let's refine the interface to handle token exchange too.
|
||||
token, err := provider.Exchange(c.Request.Context(), code)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "oauth_exchange_failed", "failed to exchange oauth code")
|
||||
return
|
||||
}
|
||||
|
||||
profile, err = provider.FetchProfile(c.Request.Context(), token)
|
||||
profile, err := provider.FetchProfile(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "fetch_profile_failed", "failed to fetch user profile")
|
||||
return
|
||||
@ -2745,25 +2791,25 @@ func (h *handler) oauthCallback(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create session or generate public token for frontend redirect
|
||||
if h.tokenService == nil {
|
||||
respondError(c, http.StatusServiceUnavailable, "token_service_unavailable", "token service not configured")
|
||||
sessionToken, sessionExpiresAt, err := h.createSession(user.ID)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "session_creation_failed", "failed to create session")
|
||||
return
|
||||
}
|
||||
|
||||
publicToken := h.tokenService.GeneratePublicToken(user.ID, user.Email, []string{user.Role})
|
||||
exchangeCode, _, err := h.issueOAuthExchangeCode(sessionToken, sessionExpiresAt)
|
||||
if err != nil {
|
||||
respondError(c, http.StatusInternalServerError, "exchange_code_creation_failed", "failed to issue exchange code")
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect back to frontend with public token
|
||||
frontendURL := h.oauthFrontendURL
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:3000"
|
||||
}
|
||||
targetURL := fmt.Sprintf("%s/login?public_token=%s&userId=%s&email=%s&role=%s",
|
||||
targetURL := fmt.Sprintf("%s/login?exchange_code=%s",
|
||||
strings.TrimSuffix(frontendURL, "/"),
|
||||
publicToken,
|
||||
user.ID,
|
||||
url.QueryEscape(user.Email),
|
||||
user.Role)
|
||||
url.QueryEscape(exchangeCode))
|
||||
c.Redirect(http.StatusTemporaryRedirect, targetURL)
|
||||
}
|
||||
|
||||
|
||||
142
api/api_test.go
142
api/api_test.go
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -17,8 +18,10 @@ import (
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"account/internal/agentserver"
|
||||
"account/internal/auth"
|
||||
"account/internal/service"
|
||||
"account/internal/store"
|
||||
)
|
||||
@ -58,6 +61,38 @@ func (s *stubMetricsProvider) Compute(context.Context) (service.UserMetrics, err
|
||||
return s.metrics, nil
|
||||
}
|
||||
|
||||
type stubOAuthProvider struct {
|
||||
profile *auth.OAuthUserProfile
|
||||
exchangeErr error
|
||||
profileErr error
|
||||
}
|
||||
|
||||
func (s *stubOAuthProvider) AuthCodeURL(state string) string {
|
||||
return "https://oauth.example.test/authorize?state=" + state
|
||||
}
|
||||
|
||||
func (s *stubOAuthProvider) Exchange(context.Context, string) (*oauth2.Token, error) {
|
||||
if s.exchangeErr != nil {
|
||||
return nil, s.exchangeErr
|
||||
}
|
||||
return &oauth2.Token{AccessToken: "oauth-token", TokenType: "Bearer"}, nil
|
||||
}
|
||||
|
||||
func (s *stubOAuthProvider) FetchProfile(context.Context, *oauth2.Token) (*auth.OAuthUserProfile, error) {
|
||||
if s.profileErr != nil {
|
||||
return nil, s.profileErr
|
||||
}
|
||||
if s.profile == nil {
|
||||
return nil, errors.New("missing oauth profile")
|
||||
}
|
||||
cloned := *s.profile
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func (s *stubOAuthProvider) Name() string {
|
||||
return "github"
|
||||
}
|
||||
|
||||
type testEmailSender struct {
|
||||
mu sync.Mutex
|
||||
messages []capturedEmail
|
||||
@ -332,6 +367,113 @@ func TestRegisterEndpoint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthCallbackIssuesOneTimeExchangeCode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
profile := &auth.OAuthUserProfile{
|
||||
ID: "oauth-user-1",
|
||||
Email: "oauth-user@example.com",
|
||||
Name: "OAuth User",
|
||||
Verified: true,
|
||||
}
|
||||
RegisterRoutes(
|
||||
router,
|
||||
WithStore(store.NewMemoryStore()),
|
||||
WithOAuthProviders(map[string]auth.OAuthProvider{
|
||||
"github": &stubOAuthProvider{profile: profile},
|
||||
}),
|
||||
WithOAuthFrontendURL("https://console.svc.plus"),
|
||||
)
|
||||
|
||||
callbackReq := httptest.NewRequest(http.MethodGet, "/api/auth/oauth/callback/github?code=test-oauth-code", nil)
|
||||
callbackRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(callbackRec, callbackReq)
|
||||
|
||||
if callbackRec.Code != http.StatusTemporaryRedirect {
|
||||
t.Fatalf("expected oauth callback redirect, got %d: %s", callbackRec.Code, callbackRec.Body.String())
|
||||
}
|
||||
|
||||
location := callbackRec.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Fatalf("expected oauth callback to set redirect location")
|
||||
}
|
||||
redirectURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
t.Fatalf("parse redirect url: %v", err)
|
||||
}
|
||||
if redirectURL.Query().Get("public_token") != "" {
|
||||
t.Fatalf("expected public_token to be removed from oauth redirect, got %q", location)
|
||||
}
|
||||
if redirectURL.Query().Get("userId") != "" || redirectURL.Query().Get("role") != "" {
|
||||
t.Fatalf("expected redirect to avoid caller-asserted identity fields, got %q", location)
|
||||
}
|
||||
|
||||
exchangeCode := redirectURL.Query().Get("exchange_code")
|
||||
if exchangeCode == "" {
|
||||
t.Fatalf("expected oauth redirect to include exchange_code, got %q", location)
|
||||
}
|
||||
|
||||
exchangeBody, err := json.Marshal(map[string]string{"exchange_code": exchangeCode})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal exchange payload: %v", err)
|
||||
}
|
||||
|
||||
exchangeReq := httptest.NewRequest(http.MethodPost, "/api/auth/token/exchange", bytes.NewReader(exchangeBody))
|
||||
exchangeReq.Header.Set("Content-Type", "application/json")
|
||||
exchangeRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(exchangeRec, exchangeReq)
|
||||
|
||||
if exchangeRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected successful token exchange, got %d: %s", exchangeRec.Code, exchangeRec.Body.String())
|
||||
}
|
||||
|
||||
var exchangeResp struct {
|
||||
Token string `json:"token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
User map[string]interface{} `json:"user"`
|
||||
}
|
||||
if err := json.Unmarshal(exchangeRec.Body.Bytes(), &exchangeResp); err != nil {
|
||||
t.Fatalf("decode exchange response: %v", err)
|
||||
}
|
||||
if exchangeResp.Token == "" {
|
||||
t.Fatalf("expected exchanged session token")
|
||||
}
|
||||
if exchangeResp.AccessToken != exchangeResp.Token {
|
||||
t.Fatalf("expected access_token alias to match session token")
|
||||
}
|
||||
if exchangeResp.User == nil {
|
||||
t.Fatalf("expected exchange response user payload")
|
||||
}
|
||||
if got := exchangeResp.User["email"]; got != profile.Email {
|
||||
t.Fatalf("expected exchange response email %q, got %#v", profile.Email, got)
|
||||
}
|
||||
|
||||
sessionReq := httptest.NewRequest(http.MethodGet, "/api/auth/session", nil)
|
||||
sessionReq.Header.Set("Authorization", "Bearer "+exchangeResp.Token)
|
||||
sessionRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(sessionRec, sessionReq)
|
||||
if sessionRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected exchanged session token to resolve session, got %d: %s", sessionRec.Code, sessionRec.Body.String())
|
||||
}
|
||||
|
||||
replayReq := httptest.NewRequest(http.MethodPost, "/api/auth/token/exchange", bytes.NewReader(exchangeBody))
|
||||
replayReq.Header.Set("Content-Type", "application/json")
|
||||
replayRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(replayRec, replayReq)
|
||||
if replayRec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected single-use exchange code replay to fail, got %d: %s", replayRec.Code, replayRec.Body.String())
|
||||
}
|
||||
|
||||
var replayResp apiResponse
|
||||
if err := json.Unmarshal(replayRec.Body.Bytes(), &replayResp); err != nil {
|
||||
t.Fatalf("decode replay response: %v", err)
|
||||
}
|
||||
if replayResp.Error != "invalid_exchange_code" {
|
||||
t.Fatalf("expected invalid_exchange_code on replay, got %#v", replayResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResendVerificationEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -29,12 +29,13 @@
|
||||
## JWT 令牌服务(可选)
|
||||
|
||||
启用 `auth.enable: true` 后提供:
|
||||
- `POST /api/auth/token/exchange`:使用 `public_token` 换取 access/refresh
|
||||
- `POST /api/auth/token/exchange`:使用 OAuth 回调签发的一次性 `exchange_code` 换取真实会话 token
|
||||
- `POST /api/auth/token/refresh`:刷新 access token
|
||||
|
||||
注意事项:
|
||||
- `token/exchange` 需要调用方提供 `user_id/email/roles`
|
||||
- 当前版本多数保护路由仍使用会话 token,JWT 仅作为中间件校验存在
|
||||
- 若开启 JWT,中间件要求 `Authorization: Bearer <access-token>`,但业务逻辑仍可能需要会话 token
|
||||
- `token/exchange` 只接受后端签发的一次性 `exchange_code`,不再接受调用方自报 `user_id/email/roles`
|
||||
- `token/exchange` 返回的 `token`/`access_token` 是同一个真实会话 token,供前端 BFF 写入 `xc_session`
|
||||
- 当前版本多数保护路由仍使用会话 token,JWT refresh 仅保留给 `token/refresh`
|
||||
- 若开启 JWT 中间件,业务逻辑仍可能需要会话 token;因此控制面应优先走会话模型
|
||||
|
||||
建议:若主要使用会话认证,请将 `auth.enable` 设为 `false`。
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
- `POST /api/auth/register/send`:发送邮箱验证码
|
||||
- `POST /api/auth/register/verify`:验证邮箱验证码
|
||||
- `POST /api/auth/login`:登录
|
||||
- `POST /api/auth/token/exchange`:public token 换取 access/refresh
|
||||
- `POST /api/auth/token/exchange`:一次性 OAuth `exchange_code` 换取真实会话 token
|
||||
- `POST /api/auth/token/refresh`:刷新 access token
|
||||
|
||||
### 需要会话(或受保护)
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
- `invalid_session` / `session_token_required`
|
||||
- `mfa_code_required` / `invalid_mfa_code`
|
||||
- `token_service_unavailable`
|
||||
- `invalid_public_token` / `invalid_refresh_token`
|
||||
- `invalid_exchange_code` / `invalid_refresh_token`
|
||||
- `subscription_not_found`
|
||||
- `agent_status_unavailable`
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user