diff --git a/api/api.go b/api/api.go index be20151..08f464a 100644 --- a/api/api.go +++ b/api/api.go @@ -38,6 +38,7 @@ const defaultEmailVerificationTTL = 10 * time.Minute const defaultPasswordResetTTL = 30 * time.Minute const maxMFAVerificationAttempts = 5 const defaultMFALockoutDuration = 5 * time.Minute +const defaultOAuthExchangeCodeTTL = 5 * time.Minute const sessionCookieName = "xc_session" @@ -46,6 +47,12 @@ type session struct { expiresAt time.Time } +type oauthExchangeCode struct { + sessionToken string + sessionExpiresAt time.Time + expiresAt time.Time +} + type handler struct { store store.Store mu sync.RWMutex @@ -64,6 +71,9 @@ type handler struct { resetTTL time.Duration passwordResets map[string]passwordReset resetMu sync.RWMutex + oauthExchangeCodes map[string]oauthExchangeCode + oauthExchangeMu sync.RWMutex + oauthExchangeTTL time.Duration metricsProvider service.UserMetricsProvider agentStatusReader agentStatusReader tokenService *auth.TokenService @@ -271,6 +281,8 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { registrationVerifications: make(map[string]registrationVerification), resetTTL: defaultPasswordResetTTL, passwordResets: make(map[string]passwordReset), + oauthExchangeCodes: make(map[string]oauthExchangeCode), + oauthExchangeTTL: defaultOAuthExchangeCodeTTL, } for _, opt := range opts { @@ -294,7 +306,7 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { authGroup.POST("/login", h.login) authGroup.POST("/mfa/verify", h.verifyMFALogin) - // Token exchange endpoint - converts public token to access/refresh tokens + // Token exchange endpoint - converts one-time OAuth exchange code to a real session token. authGroup.POST("/token/exchange", h.exchangeToken) // OAuth2 routes @@ -1302,55 +1314,46 @@ func (h *handler) refreshToken(c *gin.Context) { } type tokenExchangeRequest struct { - PublicToken string `json:"public_token"` - UserID string `json:"user_id"` - Email string `json:"email"` - Roles string `json:"roles"` + ExchangeCode string `json:"exchange_code"` } func (h *handler) exchangeToken(c *gin.Context) { - if h.tokenService == nil { - respondError(c, http.StatusServiceUnavailable, "token_service_unavailable", "token service is not configured") - return - } - var req tokenExchangeRequest if err := c.ShouldBindJSON(&req); err != nil { respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload") return } - // Validate public token - if !h.tokenService.ValidatePublicToken(req.PublicToken) { - respondError(c, http.StatusUnauthorized, "invalid_public_token", "invalid public token") + sessionToken, _, ok := h.consumeOAuthExchangeCode(req.ExchangeCode) + if !ok { + respondError(c, http.StatusUnauthorized, "invalid_exchange_code", "invalid or expired exchange code") return } - // Parse roles - var roles []string - if req.Roles != "" { - roles = strings.Split(req.Roles, ",") - for i := range roles { - roles[i] = strings.TrimSpace(roles[i]) - } - } else { - roles = []string{"user"} + sess, ok := h.lookupSession(sessionToken) + if !ok { + respondError(c, http.StatusUnauthorized, "invalid_exchange_code", "exchange session is invalid or expired") + return } - // Generate token pair - tokenPair, err := h.tokenService.GenerateTokenPair(req.UserID, req.Email, roles) + user, err := h.store.GetUserByID(c.Request.Context(), sess.userID) if err != nil { - slog.Error("failed to generate token pair", "err", err) - respondError(c, http.StatusInternalServerError, "token_generation_failed", "failed to generate tokens") + respondError(c, http.StatusInternalServerError, "session_user_lookup_failed", "failed to load session user") return } + expiresIn := int64(time.Until(sess.expiresAt).Seconds()) + if expiresIn < 0 { + expiresIn = 0 + } + c.JSON(http.StatusOK, gin.H{ - "public_token": tokenPair.PublicToken, - "access_token": tokenPair.AccessToken, - "refresh_token": tokenPair.RefreshToken, - "token_type": tokenPair.TokenType, - "expires_in": tokenPair.ExpiresIn, + "token": sessionToken, + "access_token": sessionToken, + "token_type": "Bearer", + "expiresAt": sess.expiresAt.UTC(), + "expires_in": expiresIn, + "user": sanitizeUser(user, nil), }) } @@ -1525,6 +1528,51 @@ func (h *handler) removeSession(token string) { h.store.DeleteSession(context.Background(), token) } +func (h *handler) issueOAuthExchangeCode(sessionToken string, sessionExpiresAt time.Time) (string, time.Time, error) { + code, err := h.newRandomToken() + if err != nil { + return "", time.Time{}, err + } + + expiresAt := time.Now().Add(h.oauthExchangeTTL) + if h.oauthExchangeTTL <= 0 { + expiresAt = time.Now().Add(defaultOAuthExchangeCodeTTL) + } + if !sessionExpiresAt.IsZero() && sessionExpiresAt.Before(expiresAt) { + expiresAt = sessionExpiresAt + } + + h.oauthExchangeMu.Lock() + defer h.oauthExchangeMu.Unlock() + h.oauthExchangeCodes[code] = oauthExchangeCode{ + sessionToken: sessionToken, + sessionExpiresAt: sessionExpiresAt, + expiresAt: expiresAt, + } + return code, expiresAt, nil +} + +func (h *handler) consumeOAuthExchangeCode(code string) (string, time.Time, bool) { + normalized := strings.TrimSpace(code) + if normalized == "" { + return "", time.Time{}, false + } + + h.oauthExchangeMu.Lock() + defer h.oauthExchangeMu.Unlock() + + record, ok := h.oauthExchangeCodes[normalized] + if !ok { + return "", time.Time{}, false + } + delete(h.oauthExchangeCodes, normalized) + + if time.Now().After(record.expiresAt) { + return "", time.Time{}, false + } + return record.sessionToken, record.sessionExpiresAt, true +} + func (h *handler) newRandomToken() (string, error) { buffer := make([]byte, 32) if _, err := rand.Read(buffer); err != nil { @@ -2660,15 +2708,13 @@ func (h *handler) oauthCallback(c *gin.Context) { return } - profile, err := provider.FetchProfile(c.Request.Context(), nil) // Exchange is handled inside if we want, or here. - // Let's refine the interface to handle token exchange too. token, err := provider.Exchange(c.Request.Context(), code) if err != nil { respondError(c, http.StatusInternalServerError, "oauth_exchange_failed", "failed to exchange oauth code") return } - profile, err = provider.FetchProfile(c.Request.Context(), token) + profile, err := provider.FetchProfile(c.Request.Context(), token) if err != nil { respondError(c, http.StatusInternalServerError, "fetch_profile_failed", "failed to fetch user profile") return @@ -2745,25 +2791,25 @@ func (h *handler) oauthCallback(c *gin.Context) { } } - // Create session or generate public token for frontend redirect - if h.tokenService == nil { - respondError(c, http.StatusServiceUnavailable, "token_service_unavailable", "token service not configured") + sessionToken, sessionExpiresAt, err := h.createSession(user.ID) + if err != nil { + respondError(c, http.StatusInternalServerError, "session_creation_failed", "failed to create session") return } - publicToken := h.tokenService.GeneratePublicToken(user.ID, user.Email, []string{user.Role}) + exchangeCode, _, err := h.issueOAuthExchangeCode(sessionToken, sessionExpiresAt) + if err != nil { + respondError(c, http.StatusInternalServerError, "exchange_code_creation_failed", "failed to issue exchange code") + return + } - // Redirect back to frontend with public token frontendURL := h.oauthFrontendURL if frontendURL == "" { frontendURL = "http://localhost:3000" } - targetURL := fmt.Sprintf("%s/login?public_token=%s&userId=%s&email=%s&role=%s", + targetURL := fmt.Sprintf("%s/login?exchange_code=%s", strings.TrimSuffix(frontendURL, "/"), - publicToken, - user.ID, - url.QueryEscape(user.Email), - user.Role) + url.QueryEscape(exchangeCode)) c.Redirect(http.StatusTemporaryRedirect, targetURL) } diff --git a/api/api_test.go b/api/api_test.go index b09ea9f..6f724e0 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "net/url" @@ -17,8 +18,10 @@ import ( "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "golang.org/x/crypto/bcrypt" + "golang.org/x/oauth2" "account/internal/agentserver" + "account/internal/auth" "account/internal/service" "account/internal/store" ) @@ -58,6 +61,38 @@ func (s *stubMetricsProvider) Compute(context.Context) (service.UserMetrics, err return s.metrics, nil } +type stubOAuthProvider struct { + profile *auth.OAuthUserProfile + exchangeErr error + profileErr error +} + +func (s *stubOAuthProvider) AuthCodeURL(state string) string { + return "https://oauth.example.test/authorize?state=" + state +} + +func (s *stubOAuthProvider) Exchange(context.Context, string) (*oauth2.Token, error) { + if s.exchangeErr != nil { + return nil, s.exchangeErr + } + return &oauth2.Token{AccessToken: "oauth-token", TokenType: "Bearer"}, nil +} + +func (s *stubOAuthProvider) FetchProfile(context.Context, *oauth2.Token) (*auth.OAuthUserProfile, error) { + if s.profileErr != nil { + return nil, s.profileErr + } + if s.profile == nil { + return nil, errors.New("missing oauth profile") + } + cloned := *s.profile + return &cloned, nil +} + +func (s *stubOAuthProvider) Name() string { + return "github" +} + type testEmailSender struct { mu sync.Mutex messages []capturedEmail @@ -332,6 +367,113 @@ func TestRegisterEndpoint(t *testing.T) { } } +func TestOAuthCallbackIssuesOneTimeExchangeCode(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + profile := &auth.OAuthUserProfile{ + ID: "oauth-user-1", + Email: "oauth-user@example.com", + Name: "OAuth User", + Verified: true, + } + RegisterRoutes( + router, + WithStore(store.NewMemoryStore()), + WithOAuthProviders(map[string]auth.OAuthProvider{ + "github": &stubOAuthProvider{profile: profile}, + }), + WithOAuthFrontendURL("https://console.svc.plus"), + ) + + callbackReq := httptest.NewRequest(http.MethodGet, "/api/auth/oauth/callback/github?code=test-oauth-code", nil) + callbackRec := httptest.NewRecorder() + router.ServeHTTP(callbackRec, callbackReq) + + if callbackRec.Code != http.StatusTemporaryRedirect { + t.Fatalf("expected oauth callback redirect, got %d: %s", callbackRec.Code, callbackRec.Body.String()) + } + + location := callbackRec.Header().Get("Location") + if location == "" { + t.Fatalf("expected oauth callback to set redirect location") + } + redirectURL, err := url.Parse(location) + if err != nil { + t.Fatalf("parse redirect url: %v", err) + } + if redirectURL.Query().Get("public_token") != "" { + t.Fatalf("expected public_token to be removed from oauth redirect, got %q", location) + } + if redirectURL.Query().Get("userId") != "" || redirectURL.Query().Get("role") != "" { + t.Fatalf("expected redirect to avoid caller-asserted identity fields, got %q", location) + } + + exchangeCode := redirectURL.Query().Get("exchange_code") + if exchangeCode == "" { + t.Fatalf("expected oauth redirect to include exchange_code, got %q", location) + } + + exchangeBody, err := json.Marshal(map[string]string{"exchange_code": exchangeCode}) + if err != nil { + t.Fatalf("marshal exchange payload: %v", err) + } + + exchangeReq := httptest.NewRequest(http.MethodPost, "/api/auth/token/exchange", bytes.NewReader(exchangeBody)) + exchangeReq.Header.Set("Content-Type", "application/json") + exchangeRec := httptest.NewRecorder() + router.ServeHTTP(exchangeRec, exchangeReq) + + if exchangeRec.Code != http.StatusOK { + t.Fatalf("expected successful token exchange, got %d: %s", exchangeRec.Code, exchangeRec.Body.String()) + } + + var exchangeResp struct { + Token string `json:"token"` + AccessToken string `json:"access_token"` + User map[string]interface{} `json:"user"` + } + if err := json.Unmarshal(exchangeRec.Body.Bytes(), &exchangeResp); err != nil { + t.Fatalf("decode exchange response: %v", err) + } + if exchangeResp.Token == "" { + t.Fatalf("expected exchanged session token") + } + if exchangeResp.AccessToken != exchangeResp.Token { + t.Fatalf("expected access_token alias to match session token") + } + if exchangeResp.User == nil { + t.Fatalf("expected exchange response user payload") + } + if got := exchangeResp.User["email"]; got != profile.Email { + t.Fatalf("expected exchange response email %q, got %#v", profile.Email, got) + } + + sessionReq := httptest.NewRequest(http.MethodGet, "/api/auth/session", nil) + sessionReq.Header.Set("Authorization", "Bearer "+exchangeResp.Token) + sessionRec := httptest.NewRecorder() + router.ServeHTTP(sessionRec, sessionReq) + if sessionRec.Code != http.StatusOK { + t.Fatalf("expected exchanged session token to resolve session, got %d: %s", sessionRec.Code, sessionRec.Body.String()) + } + + replayReq := httptest.NewRequest(http.MethodPost, "/api/auth/token/exchange", bytes.NewReader(exchangeBody)) + replayReq.Header.Set("Content-Type", "application/json") + replayRec := httptest.NewRecorder() + router.ServeHTTP(replayRec, replayReq) + if replayRec.Code != http.StatusUnauthorized { + t.Fatalf("expected single-use exchange code replay to fail, got %d: %s", replayRec.Code, replayRec.Body.String()) + } + + var replayResp apiResponse + if err := json.Unmarshal(replayRec.Body.Bytes(), &replayResp); err != nil { + t.Fatalf("decode replay response: %v", err) + } + if replayResp.Error != "invalid_exchange_code" { + t.Fatalf("expected invalid_exchange_code on replay, got %#v", replayResp.Error) + } +} + func TestResendVerificationEndpoint(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/docs/api/auth.md b/docs/api/auth.md index 6deac7b..3d81ef2 100644 --- a/docs/api/auth.md +++ b/docs/api/auth.md @@ -29,12 +29,13 @@ ## JWT 令牌服务(可选) 启用 `auth.enable: true` 后提供: -- `POST /api/auth/token/exchange`:使用 `public_token` 换取 access/refresh +- `POST /api/auth/token/exchange`:使用 OAuth 回调签发的一次性 `exchange_code` 换取真实会话 token - `POST /api/auth/token/refresh`:刷新 access token 注意事项: -- `token/exchange` 需要调用方提供 `user_id/email/roles` -- 当前版本多数保护路由仍使用会话 token,JWT 仅作为中间件校验存在 -- 若开启 JWT,中间件要求 `Authorization: Bearer `,但业务逻辑仍可能需要会话 token +- `token/exchange` 只接受后端签发的一次性 `exchange_code`,不再接受调用方自报 `user_id/email/roles` +- `token/exchange` 返回的 `token`/`access_token` 是同一个真实会话 token,供前端 BFF 写入 `xc_session` +- 当前版本多数保护路由仍使用会话 token,JWT refresh 仅保留给 `token/refresh` +- 若开启 JWT 中间件,业务逻辑仍可能需要会话 token;因此控制面应优先走会话模型 建议:若主要使用会话认证,请将 `auth.enable` 设为 `false`。 diff --git a/docs/api/endpoints.md b/docs/api/endpoints.md index bb3ed50..8cb410c 100644 --- a/docs/api/endpoints.md +++ b/docs/api/endpoints.md @@ -10,7 +10,7 @@ - `POST /api/auth/register/send`:发送邮箱验证码 - `POST /api/auth/register/verify`:验证邮箱验证码 - `POST /api/auth/login`:登录 -- `POST /api/auth/token/exchange`:public token 换取 access/refresh +- `POST /api/auth/token/exchange`:一次性 OAuth `exchange_code` 换取真实会话 token - `POST /api/auth/token/refresh`:刷新 access token ### 需要会话(或受保护) diff --git a/docs/api/errors.md b/docs/api/errors.md index e9d6b80..f257cea 100644 --- a/docs/api/errors.md +++ b/docs/api/errors.md @@ -16,7 +16,7 @@ - `invalid_session` / `session_token_required` - `mfa_code_required` / `invalid_mfa_code` - `token_service_unavailable` -- `invalid_public_token` / `invalid_refresh_token` +- `invalid_exchange_code` / `invalid_refresh_token` - `subscription_not_found` - `agent_status_unavailable`