diff --git a/account/api/api.go b/account/api/api.go index a11a3cf..5034606 100644 --- a/account/api/api.go +++ b/account/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -10,12 +11,16 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" "golang.org/x/crypto/bcrypt" "xcontrol/account/internal/store" ) const defaultSessionTTL = 24 * time.Hour +const defaultMFAChallengeTTL = 10 * time.Minute +const defaultTOTPIssuer = "XControl Account" type session struct { userID string @@ -23,10 +28,19 @@ type session struct { } type handler struct { - store store.Store - sessions map[string]session - mu sync.RWMutex - sessionTTL time.Duration + store store.Store + sessions map[string]session + mu sync.RWMutex + sessionTTL time.Duration + mfaChallenges map[string]mfaChallenge + mfaMu sync.RWMutex + mfaChallengeTTL time.Duration + totpIssuer string +} + +type mfaChallenge struct { + userID string + expiresAt time.Time } // Option configures handler behaviour when registering routes. @@ -53,9 +67,12 @@ func WithSessionTTL(ttl time.Duration) Option { // RegisterRoutes attaches account service endpoints to the router. func RegisterRoutes(r *gin.Engine, opts ...Option) { h := &handler{ - store: store.NewMemoryStore(), - sessions: make(map[string]session), - sessionTTL: defaultSessionTTL, + store: store.NewMemoryStore(), + sessions: make(map[string]session), + sessionTTL: defaultSessionTTL, + mfaChallenges: make(map[string]mfaChallenge), + mfaChallengeTTL: defaultMFAChallengeTTL, + totpIssuer: defaultTOTPIssuer, } for _, opt := range opts { @@ -71,6 +88,9 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { auth.POST("/login", h.login) auth.GET("/session", h.session) auth.DELETE("/session", h.deleteSession) + auth.POST("/mfa/totp/provision", h.provisionTOTP) + auth.POST("/mfa/totp/verify", h.verifyTOTP) + auth.GET("/mfa/status", h.mfaStatus) } type registerRequest struct { @@ -80,8 +100,11 @@ type registerRequest struct { } type loginRequest struct { - Username string `json:"username"` - Password string `json:"password"` + Identifier string `json:"identifier"` + Username string `json:"username"` + Email string `json:"email"` + Password string `json:"password"` + TOTPCode string `json:"totpCode"` } func hasQueryParameter(c *gin.Context, keys ...string) bool { @@ -172,7 +195,7 @@ func (h *handler) register(c *gin.Context) { } func (h *handler) login(c *gin.Context) { - if hasQueryParameter(c, "username", "password") { + if hasQueryParameter(c, "username", "password", "identifier", "totp") { respondError(c, http.StatusBadRequest, "credentials_in_query", "sensitive credentials must not be sent in the query string") return } @@ -183,14 +206,23 @@ func (h *handler) login(c *gin.Context) { return } - username := strings.TrimSpace(req.Username) + identifier := strings.TrimSpace(req.Identifier) + if identifier == "" { + identifier = strings.TrimSpace(req.Username) + } + if identifier == "" { + identifier = strings.TrimSpace(req.Email) + } + password := strings.TrimSpace(req.Password) - if username == "" || password == "" { - respondError(c, http.StatusBadRequest, "missing_credentials", "username and password are required") + totpCode := strings.TrimSpace(req.TOTPCode) + + if identifier == "" { + respondError(c, http.StatusBadRequest, "missing_credentials", "identifier is required") return } - user, err := h.store.GetUserByName(c.Request.Context(), username) + user, err := h.findUserByIdentifier(c.Request.Context(), identifier) if err != nil { if errors.Is(err, store.ErrUserNotFound) { respondError(c, http.StatusNotFound, "user_not_found", "user not found") @@ -200,8 +232,53 @@ func (h *handler) login(c *gin.Context) { return } - if bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) != nil { - respondError(c, http.StatusUnauthorized, "invalid_credentials", "invalid credentials") + if password != "" { + if bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) != nil { + respondError(c, http.StatusUnauthorized, "invalid_credentials", "invalid credentials") + return + } + } else { + if totpCode == "" { + respondError(c, http.StatusBadRequest, "missing_credentials", "totp code is required") + return + } + if !strings.EqualFold(strings.TrimSpace(user.Email), identifier) { + respondError(c, http.StatusUnauthorized, "password_required", "password required for this identifier") + return + } + } + + if !user.MFAEnabled { + challengeToken, err := h.createMFAChallenge(user.ID) + if err != nil { + respondError(c, http.StatusInternalServerError, "mfa_challenge_failed", "failed to prepare mfa challenge") + return + } + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "mfa_setup_required", + "mfaToken": challengeToken, + "user": sanitizeUser(user), + }) + return + } + + if totpCode == "" { + respondError(c, http.StatusBadRequest, "mfa_code_required", "totp code is required") + return + } + + valid, err := totp.ValidateCustom(totpCode, user.MFATOTPSecret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + if err != nil { + respondError(c, http.StatusInternalServerError, "invalid_mfa_code", "invalid totp code") + return + } + if !valid { + respondError(c, http.StatusUnauthorized, "invalid_mfa_code", "invalid totp code") return } @@ -219,6 +296,17 @@ func (h *handler) login(c *gin.Context) { }) } +func (h *handler) findUserByIdentifier(ctx context.Context, identifier string) (*store.User, error) { + user, err := h.store.GetUserByName(ctx, identifier) + if err == nil { + return user, nil + } + if err != nil && !errors.Is(err, store.ErrUserNotFound) { + return nil, err + } + return h.store.GetUserByEmail(ctx, identifier) +} + func (h *handler) session(c *gin.Context) { token := extractToken(c.GetHeader("Authorization")) if token == "" { @@ -263,11 +351,10 @@ func (h *handler) deleteSession(c *gin.Context) { } func (h *handler) createSession(userID string) (string, time.Time, error) { - buffer := make([]byte, 32) - if _, err := rand.Read(buffer); err != nil { + token, err := h.newRandomToken() + if err != nil { return "", time.Time{}, err } - token := hex.EncodeToString(buffer) ttl := h.sessionTTL if ttl <= 0 { ttl = defaultSessionTTL @@ -300,17 +387,301 @@ func (h *handler) removeSession(token string) { h.mu.Unlock() } +func (h *handler) newRandomToken() (string, error) { + buffer := make([]byte, 32) + if _, err := rand.Read(buffer); err != nil { + return "", err + } + return hex.EncodeToString(buffer), nil +} + +func (h *handler) createMFAChallenge(userID string) (string, error) { + token, err := h.newRandomToken() + if err != nil { + return "", err + } + ttl := h.mfaChallengeTTL + if ttl <= 0 { + ttl = defaultMFAChallengeTTL + } + challenge := mfaChallenge{userID: userID, expiresAt: time.Now().Add(ttl)} + h.mfaMu.Lock() + h.mfaChallenges[token] = challenge + h.mfaMu.Unlock() + return token, nil +} + +func (h *handler) lookupMFAChallenge(token string) (mfaChallenge, bool) { + h.mfaMu.RLock() + challenge, ok := h.mfaChallenges[token] + h.mfaMu.RUnlock() + if !ok { + return mfaChallenge{}, false + } + if time.Now().After(challenge.expiresAt) { + h.removeMFAChallenge(token) + return mfaChallenge{}, false + } + return challenge, true +} + +func (h *handler) refreshMFAChallenge(token string) (mfaChallenge, bool) { + ttl := h.mfaChallengeTTL + if ttl <= 0 { + ttl = defaultMFAChallengeTTL + } + h.mfaMu.Lock() + challenge, ok := h.mfaChallenges[token] + if ok { + challenge.expiresAt = time.Now().Add(ttl) + h.mfaChallenges[token] = challenge + } + h.mfaMu.Unlock() + if !ok { + return mfaChallenge{}, false + } + if time.Now().After(challenge.expiresAt) { + h.removeMFAChallenge(token) + return mfaChallenge{}, false + } + return challenge, true +} + +func (h *handler) removeMFAChallenge(token string) { + h.mfaMu.Lock() + delete(h.mfaChallenges, token) + h.mfaMu.Unlock() +} + +func (h *handler) provisionTOTP(c *gin.Context) { + var req struct { + Token string `json:"token"` + Issuer string `json:"issuer"` + Account string `json:"account"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload") + return + } + + token := strings.TrimSpace(req.Token) + if token == "" { + respondError(c, http.StatusBadRequest, "mfa_token_required", "mfa token is required") + return + } + + challenge, ok := h.refreshMFAChallenge(token) + if !ok { + respondError(c, http.StatusUnauthorized, "invalid_mfa_token", "mfa token is invalid or expired") + return + } + + ctx := c.Request.Context() + user, err := h.store.GetUserByID(ctx, challenge.userID) + if err != nil { + respondError(c, http.StatusInternalServerError, "mfa_user_lookup_failed", "failed to load user for mfa provisioning") + return + } + + if user.MFAEnabled { + respondError(c, http.StatusBadRequest, "mfa_already_enabled", "mfa already enabled for this account") + return + } + + issuer := strings.TrimSpace(req.Issuer) + if issuer == "" { + issuer = h.totpIssuer + } + + accountName := strings.TrimSpace(req.Account) + if accountName == "" { + accountName = strings.TrimSpace(user.Email) + } + if accountName == "" { + accountName = strings.TrimSpace(user.Name) + } + if accountName == "" { + accountName = strings.TrimSpace(user.ID) + } + + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: accountName, + Period: 30, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + if err != nil { + respondError(c, http.StatusInternalServerError, "mfa_secret_generation_failed", "failed to generate totp secret") + return + } + + user.MFATOTPSecret = key.Secret() + user.MFAEnabled = false + user.MFASecretIssuedAt = time.Now().UTC() + user.MFAConfirmedAt = time.Time{} + + if err := h.store.UpdateUser(ctx, user); err != nil { + respondError(c, http.StatusInternalServerError, "mfa_secret_persist_failed", "failed to persist totp secret") + return + } + + c.JSON(http.StatusOK, gin.H{ + "secret": user.MFATOTPSecret, + "uri": key.URL(), + "issuer": issuer, + "account": accountName, + "mfaToken": token, + "user": sanitizeUser(user), + }) +} + +func (h *handler) verifyTOTP(c *gin.Context) { + var req struct { + Token string `json:"token"` + Code string `json:"code"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload") + return + } + + token := strings.TrimSpace(req.Token) + if token == "" { + respondError(c, http.StatusBadRequest, "mfa_token_required", "mfa token is required") + return + } + + challenge, ok := h.lookupMFAChallenge(token) + if !ok { + respondError(c, http.StatusUnauthorized, "invalid_mfa_token", "mfa token is invalid or expired") + return + } + + ctx := c.Request.Context() + user, err := h.store.GetUserByID(ctx, challenge.userID) + if err != nil { + respondError(c, http.StatusInternalServerError, "mfa_user_lookup_failed", "failed to load user for verification") + return + } + + if strings.TrimSpace(user.MFATOTPSecret) == "" { + respondError(c, http.StatusBadRequest, "mfa_secret_missing", "mfa secret has not been provisioned") + return + } + + code := strings.TrimSpace(req.Code) + if code == "" { + respondError(c, http.StatusBadRequest, "mfa_code_required", "totp code is required") + return + } + + if !totp.Validate(code, user.MFATOTPSecret) { + respondError(c, http.StatusUnauthorized, "invalid_mfa_code", "invalid totp code") + return + } + + user.MFAEnabled = true + user.MFAConfirmedAt = time.Now().UTC() + + if err := h.store.UpdateUser(ctx, user); err != nil { + respondError(c, http.StatusInternalServerError, "mfa_update_failed", "failed to enable mfa") + return + } + + h.removeMFAChallenge(token) + + sessionToken, expiresAt, err := h.createSession(user.ID) + if err != nil { + respondError(c, http.StatusInternalServerError, "session_creation_failed", "failed to create session") + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "mfa_verified", + "token": sessionToken, + "expiresAt": expiresAt.UTC(), + "user": sanitizeUser(user), + }) +} + +func (h *handler) mfaStatus(c *gin.Context) { + token := strings.TrimSpace(c.Query("token")) + if token == "" { + token = strings.TrimSpace(c.GetHeader("X-MFA-Token")) + } + + authToken := extractToken(c.GetHeader("Authorization")) + + var ( + user *store.User + err error + ) + + ctx := c.Request.Context() + + if authToken != "" { + if sess, ok := h.lookupSession(authToken); ok { + user, err = h.store.GetUserByID(ctx, sess.userID) + if err != nil { + respondError(c, http.StatusInternalServerError, "mfa_status_failed", "failed to load user for status") + return + } + } else if token == "" { + token = authToken + } + } + + if user == nil && token != "" { + if challenge, ok := h.lookupMFAChallenge(token); ok { + user, err = h.store.GetUserByID(ctx, challenge.userID) + if err != nil { + respondError(c, http.StatusInternalServerError, "mfa_status_failed", "failed to load user for status") + return + } + } + } + + if user == nil { + respondError(c, http.StatusUnauthorized, "mfa_token_required", "valid session or mfa token is required") + return + } + + c.JSON(http.StatusOK, gin.H{ + "mfa": buildMFAState(user), + "user": sanitizeUser(user), + }) +} + func sanitizeUser(user *store.User) gin.H { identifier := strings.TrimSpace(user.ID) return gin.H{ - "id": identifier, - "uuid": identifier, - "name": user.Name, - "username": user.Name, - "email": user.Email, + "id": identifier, + "uuid": identifier, + "name": user.Name, + "username": user.Name, + "email": user.Email, + "mfaEnabled": user.MFAEnabled, + "mfa": buildMFAState(user), } } +func buildMFAState(user *store.User) gin.H { + state := gin.H{ + "totpEnabled": user.MFAEnabled, + "totpPending": strings.TrimSpace(user.MFATOTPSecret) != "" && !user.MFAEnabled, + } + if !user.MFASecretIssuedAt.IsZero() { + state["totpSecretIssuedAt"] = user.MFASecretIssuedAt.UTC() + } + if !user.MFAConfirmedAt.IsZero() { + state["totpConfirmedAt"] = user.MFAConfirmedAt.UTC() + } + return state +} + func respondError(c *gin.Context, status int, code, message string) { c.JSON(status, gin.H{ "error": code, diff --git a/account/api/api_test.go b/account/api/api_test.go index b959172..f317580 100644 --- a/account/api/api_test.go +++ b/account/api/api_test.go @@ -6,10 +6,34 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" ) +type apiResponse struct { + Message string `json:"message"` + Error string `json:"error"` + Token string `json:"token"` + MFAToken string `json:"mfaToken"` + User map[string]interface{} `json:"user"` + MFA map[string]interface{} `json:"mfa"` + Secret string `json:"secret"` + URI string `json:"uri"` + ExpiresAt string `json:"expiresAt"` +} + +func decodeResponse(t *testing.T, rr *httptest.ResponseRecorder) apiResponse { + t.Helper() + var resp apiResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + return resp +} + func TestRegisterEndpoint(t *testing.T) { gin.SetMode(gin.TestMode) @@ -37,41 +61,38 @@ func TestRegisterEndpoint(t *testing.T) { t.Fatalf("expected status %d, got %d, body: %s", http.StatusCreated, rr.Code, rr.Body.String()) } - var response struct { - Message string `json:"message"` - User map[string]any `json:"user"` - } - - if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - - if response.User == nil { + resp := decodeResponse(t, rr) + if resp.User == nil { t.Fatalf("expected user object in response") } - if email, ok := response.User["email"].(string); !ok || email != payload["email"] { - t.Fatalf("expected email %q, got %#v", payload["email"], response.User["email"]) + if email, ok := resp.User["email"].(string); !ok || email != payload["email"] { + t.Fatalf("expected email %q, got %#v", payload["email"], resp.User["email"]) } - if id, ok := response.User["id"].(string); !ok || id == "" { - t.Fatalf("expected user id in response, got %#v", response.User["id"]) - } else { - if uuid, ok := response.User["uuid"].(string); !ok || uuid != id { - t.Fatalf("expected uuid to match id, got id=%q uuid=%#v", id, response.User["uuid"]) - } + if id, ok := resp.User["id"].(string); !ok || id == "" { + t.Fatalf("expected user id in response") + } else if uuid, ok := resp.User["uuid"].(string); !ok || uuid != id { + t.Fatalf("expected uuid to match id") } - if response.Message == "" { - t.Fatalf("expected success message in response") + if mfaEnabled, ok := resp.User["mfaEnabled"].(bool); !ok || mfaEnabled { + t.Fatalf("expected mfaEnabled to be false, got %#v", resp.User["mfaEnabled"]) } - if _, exists := response.User["password"]; exists { - t.Fatalf("response should not include password field") + mfaData, ok := resp.User["mfa"].(map[string]interface{}) + if !ok { + t.Fatalf("expected mfa state in user payload") + } + if enabled, ok := mfaData["totpEnabled"].(bool); !ok || enabled { + t.Fatalf("expected totpEnabled to be false, got %#v", mfaData["totpEnabled"]) + } + if pending, ok := mfaData["totpPending"].(bool); !ok || pending { + t.Fatalf("expected totpPending to be false, got %#v", mfaData["totpPending"]) } } -func TestLoginEndpoint(t *testing.T) { +func TestMFATOTPFlow(t *testing.T) { gin.SetMode(gin.TestMode) router := gin.New() @@ -82,7 +103,6 @@ func TestLoginEndpoint(t *testing.T) { "email": "login@example.com", "password": "supersecure", } - registerBody, err := json.Marshal(registerPayload) if err != nil { t.Fatalf("failed to marshal registration payload: %v", err) @@ -97,10 +117,9 @@ func TestLoginEndpoint(t *testing.T) { } loginPayload := map[string]string{ - "username": "Login User", - "password": registerPayload["password"], + "identifier": "Login User", + "password": registerPayload["password"], } - loginBody, err := json.Marshal(loginPayload) if err != nil { t.Fatalf("failed to marshal login payload: %v", err) @@ -111,164 +130,151 @@ func TestLoginEndpoint(t *testing.T) { rr = httptest.NewRecorder() router.ServeHTTP(rr, req) - if rr.Code != http.StatusOK { - t.Fatalf("expected login success, got %d: %s", rr.Code, rr.Body.String()) - } - - var loginResponse struct { - Message string `json:"message"` - Token string `json:"token"` - User map[string]interface{} `json:"user"` - } - if err := json.Unmarshal(rr.Body.Bytes(), &loginResponse); err != nil { - t.Fatalf("failed to decode login response: %v", err) - } - - if id, ok := loginResponse.User["id"].(string); !ok || id == "" { - t.Fatalf("expected user id in login response, got %#v", loginResponse.User["id"]) - } else { - if uuid, ok := loginResponse.User["uuid"].(string); !ok || uuid != id { - t.Fatalf("expected login uuid to match id, got id=%q uuid=%#v", id, loginResponse.User["uuid"]) - } - } - - if loginResponse.Message == "" { - t.Fatalf("expected login success message") - } - if loginResponse.Token == "" { - t.Fatalf("expected session token in login response") - } - if username, ok := loginResponse.User["username"].(string); !ok || username != registerPayload["name"] { - t.Fatalf("expected username %q in response, got %#v", registerPayload["name"], loginResponse.User["username"]) - } - - // Wrong password - loginPayload["password"] = "wrongpass" - loginBody, err = json.Marshal(loginPayload) - if err != nil { - t.Fatalf("failed to marshal invalid login payload: %v", err) - } - - req = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewReader(loginBody)) - req.Header.Set("Content-Type", "application/json") - rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) - if rr.Code != http.StatusUnauthorized { - t.Fatalf("expected unauthorized for wrong password, got %d", rr.Code) + t.Fatalf("expected login to require mfa setup, got %d", rr.Code) + } + resp := decodeResponse(t, rr) + if resp.Error != "mfa_setup_required" { + t.Fatalf("expected mfa_setup_required error, got %q", resp.Error) + } + if resp.MFAToken == "" { + t.Fatalf("expected mfa token in response") } - var errorResponse struct { - Error string `json:"error"` + provisionPayload := map[string]string{ + "token": resp.MFAToken, } - if err := json.Unmarshal(rr.Body.Bytes(), &errorResponse); err != nil { - t.Fatalf("failed to decode wrong password response: %v", err) - } - if errorResponse.Error != "invalid_credentials" { - t.Fatalf("expected invalid_credentials error, got %q", errorResponse.Error) - } - - // Unknown user - loginPayload["username"] = "missing-user" - loginPayload["password"] = registerPayload["password"] - loginBody, err = json.Marshal(loginPayload) + provisionBody, err := json.Marshal(provisionPayload) if err != nil { - t.Fatalf("failed to marshal missing user payload: %v", err) + t.Fatalf("failed to marshal provision payload: %v", err) } - req = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewReader(loginBody)) + req = httptest.NewRequest(http.MethodPost, "/api/auth/mfa/totp/provision", bytes.NewReader(provisionBody)) req.Header.Set("Content-Type", "application/json") rr = httptest.NewRecorder() router.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected provisioning success, got %d: %s", rr.Code, rr.Body.String()) + } + resp = decodeResponse(t, rr) + if resp.Secret == "" { + t.Fatalf("expected totp secret in provisioning response") + } + if resp.URI == "" { + t.Fatalf("expected otpauth uri in provisioning response") + } + secret := resp.Secret - if rr.Code != http.StatusNotFound { - t.Fatalf("expected not found for missing user, got %d", rr.Code) + generateCode := func(offset time.Duration) string { + code, err := totp.GenerateCodeCustom(secret, time.Now().UTC().Add(offset), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + if err != nil { + t.Fatalf("failed to generate verification code: %v", err) + } + return code } - if err := json.Unmarshal(rr.Body.Bytes(), &errorResponse); err != nil { - t.Fatalf("failed to decode missing user response: %v", err) + + code := generateCode(0) + + verifyPayload := map[string]string{ + "token": resp.MFAToken, + "code": code, } - if errorResponse.Error != "user_not_found" { - t.Fatalf("expected user_not_found error, got %q", errorResponse.Error) - } -} - -func TestRegisterRejectsDuplicateIdentifiers(t *testing.T) { - gin.SetMode(gin.TestMode) - - router := gin.New() - RegisterRoutes(router) - - basePayload := map[string]string{ - "name": "Existing User", - "email": "existing@example.com", - "password": "supersecure", - } - - body, err := json.Marshal(basePayload) - if err != nil { - t.Fatalf("failed to marshal payload: %v", err) - } - - req := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - if rr.Code != http.StatusCreated { - t.Fatalf("expected initial registration to succeed, got %d", rr.Code) - } - - // Duplicate email - payload := map[string]string{ - "name": "Another User", - "email": basePayload["email"], - "password": "supersecure", - } - body, err = json.Marshal(payload) - if err != nil { - t.Fatalf("failed to marshal payload: %v", err) - } - - req = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) - if rr.Code != http.StatusConflict { - t.Fatalf("expected conflict for duplicate email, got %d", rr.Code) - } - - var conflictResp struct { - Error string `json:"error"` - } - if err := json.Unmarshal(rr.Body.Bytes(), &conflictResp); err != nil { - t.Fatalf("failed to decode duplicate email response: %v", err) - } - if conflictResp.Error != "email_already_exists" { - t.Fatalf("expected email_already_exists error, got %q", conflictResp.Error) - } - - // Duplicate name - payload = map[string]string{ - "name": basePayload["name"], - "email": "unique@example.com", - "password": "supersecure", - } - body, err = json.Marshal(payload) - if err != nil { - t.Fatalf("failed to marshal payload: %v", err) - } - - req = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) - if rr.Code != http.StatusConflict { - t.Fatalf("expected conflict for duplicate name, got %d", rr.Code) - } - - if err := json.Unmarshal(rr.Body.Bytes(), &conflictResp); err != nil { - t.Fatalf("failed to decode duplicate name response: %v", err) - } - if conflictResp.Error != "name_already_exists" { - t.Fatalf("expected name_already_exists error, got %q", conflictResp.Error) + verifyBody, err := json.Marshal(verifyPayload) + if err != nil { + t.Fatalf("failed to marshal verify payload: %v", err) + } + + req = httptest.NewRequest(http.MethodPost, "/api/auth/mfa/totp/verify", bytes.NewReader(verifyBody)) + req.Header.Set("Content-Type", "application/json") + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected verification success, got %d: %s", rr.Code, rr.Body.String()) + } + resp = decodeResponse(t, rr) + if resp.Token == "" { + t.Fatalf("expected session token after verification") + } + if resp.User == nil || resp.User["mfaEnabled"] != true { + t.Fatalf("expected mfaEnabled true after verification") + } + + sessionReq := httptest.NewRequest(http.MethodGet, "/api/auth/session", nil) + sessionReq.Header.Set("Authorization", "Bearer "+resp.Token) + sessionRec := httptest.NewRecorder() + router.ServeHTTP(sessionRec, sessionReq) + if sessionRec.Code != http.StatusOK { + t.Fatalf("expected session lookup success, got %d", sessionRec.Code) + } + sessionResp := decodeResponse(t, sessionRec) + if sessionResp.User == nil { + t.Fatalf("expected user in session response") + } + if sessionResp.User["mfaEnabled"] != true { + t.Fatalf("expected session user to have mfaEnabled true") + } + + statusReq := httptest.NewRequest(http.MethodGet, "/api/auth/mfa/status", nil) + statusReq.Header.Set("Authorization", "Bearer "+resp.Token) + statusRec := httptest.NewRecorder() + router.ServeHTTP(statusRec, statusReq) + if statusRec.Code != http.StatusOK { + t.Fatalf("expected status success, got %d", statusRec.Code) + } + + loginWithTotp := func(body map[string]string) *httptest.ResponseRecorder { + payload, err := json.Marshal(body) + if err != nil { + t.Fatalf("failed to marshal login payload: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewReader(payload)) + request.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, request) + return recorder + } + + time.Sleep(1 * time.Second) + totpCode := generateCode(0) + if ok, _ := totp.ValidateCustom(totpCode, secret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }); !ok { + t.Fatalf("locally generated totp code is invalid") + } + + rr = loginWithTotp(map[string]string{ + "identifier": "Login User", + "password": registerPayload["password"], + "totpCode": totpCode, + }) + if rr.Code != http.StatusOK { + t.Fatalf("expected mfa login success, got %d: %s", rr.Code, rr.Body.String()) + } + + time.Sleep(1 * time.Second) + totpCode = generateCode(0) + if ok, _ := totp.ValidateCustom(totpCode, secret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }); !ok { + t.Fatalf("locally generated totp code is invalid (email login)") + } + + rr = loginWithTotp(map[string]string{ + "identifier": registerPayload["email"], + "totpCode": totpCode, + }) + if rr.Code != http.StatusOK { + t.Fatalf("expected email+totp login success, got %d: %s", rr.Code, rr.Body.String()) } } diff --git a/account/cmd/accountsvc/main.go b/account/cmd/accountsvc/main.go index cd74a8f..4472bac 100644 --- a/account/cmd/accountsvc/main.go +++ b/account/cmd/accountsvc/main.go @@ -2,8 +2,11 @@ package main import ( "context" + "crypto/tls" + "crypto/x509" "errors" "log/slog" + "net" "net/http" "os" "strings" @@ -86,6 +89,27 @@ var rootCmd = &cobra.Command{ addr = ":8080" } + tlsSettings := cfg.Server.TLS + certFile := strings.TrimSpace(tlsSettings.CertFile) + keyFile := strings.TrimSpace(tlsSettings.KeyFile) + clientCAFile := strings.TrimSpace(tlsSettings.ClientCAFile) + + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} + if clientCAFile != "" { + caBytes, err := os.ReadFile(clientCAFile) + if err != nil { + return err + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caBytes) { + return errors.New("failed to parse client CA file") + } + tlsConfig.ClientCAs = pool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + useTLS := certFile != "" && keyFile != "" + srv := &http.Server{ Addr: addr, Handler: r, @@ -93,11 +117,45 @@ var rootCmd = &cobra.Command{ WriteTimeout: cfg.Server.WriteTimeout, } - logger.Info("starting account service", "addr", addr) - if err := srv.ListenAndServe(); err != nil { - if !errors.Is(err, http.ErrServerClosed) { - logger.Error("account service shutdown", "err", err) - return err + if useTLS { + srv.TLSConfig = tlsConfig + } + + logger.Info("starting account service", "addr", addr, "tls", useTLS) + + if useTLS { + if tlsSettings.RedirectHTTP { + go func() { + redirectAddr := deriveRedirectAddr(addr) + redirectSrv := &http.Server{ + Addr: redirectAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := r.Host + if host == "" { + host = redirectAddr + } + target := "https://" + host + r.URL.RequestURI() + http.Redirect(w, r, target, http.StatusPermanentRedirect) + }), + } + if err := redirectSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error("http redirect listener exited", "err", err) + } + }() + } + + if err := srv.ListenAndServeTLS(certFile, keyFile); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + logger.Error("account service shutdown", "err", err) + return err + } + } + } else { + if err := srv.ListenAndServe(); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + logger.Error("account service shutdown", "err", err) + return err + } } } return nil @@ -114,3 +172,22 @@ func main() { os.Exit(1) } } + +func deriveRedirectAddr(addr string) string { + host, port, err := net.SplitHostPort(strings.TrimSpace(addr)) + if err != nil { + trimmed := strings.TrimSpace(addr) + if strings.HasPrefix(trimmed, ":") { + port = strings.TrimPrefix(trimmed, ":") + if port == "" || port == "443" { + return ":80" + } + return ":" + port + } + return ":80" + } + if port == "" || port == "443" { + port = "80" + } + return net.JoinHostPort(host, port) +} diff --git a/account/config/account.yaml b/account/config/account.yaml index 15e521b..23e59a3 100644 --- a/account/config/account.yaml +++ b/account/config/account.yaml @@ -5,6 +5,11 @@ server: addr: ":8080" readTimeout: 15s writeTimeout: 15s + tls: + certFile: "" + keyFile: "" + clientCAFile: "" + redirectHttp: false store: driver: "postgres" diff --git a/account/config/config.go b/account/config/config.go index 7621d49..34ce73c 100644 --- a/account/config/config.go +++ b/account/config/config.go @@ -26,9 +26,18 @@ type Config struct { // Server defines HTTP server configuration. type Server struct { - Addr string `yaml:"addr"` - ReadTimeout time.Duration `yaml:"readTimeout"` - WriteTimeout time.Duration `yaml:"writeTimeout"` + Addr string `yaml:"addr"` + ReadTimeout time.Duration `yaml:"readTimeout"` + WriteTimeout time.Duration `yaml:"writeTimeout"` + TLS TLS `yaml:"tls"` +} + +// TLS describes TLS configuration for the server listener. +type TLS struct { + CertFile string `yaml:"certFile"` + KeyFile string `yaml:"keyFile"` + ClientCAFile string `yaml:"clientCAFile"` + RedirectHTTP bool `yaml:"redirectHttp"` } // Store defines persistence configuration for the account service. diff --git a/account/internal/store/postgres.go b/account/internal/store/postgres.go index b18112d..af72811 100644 --- a/account/internal/store/postgres.go +++ b/account/internal/store/postgres.go @@ -100,11 +100,12 @@ func (s *postgresStore) CreateUser(ctx context.Context, user *User) error { query := `INSERT INTO users (username, email, password) VALUES ($1, $2, $3) - RETURNING uuid, coalesce(created_at, now())` + RETURNING uuid, coalesce(created_at, now()), coalesce(updated_at, now())` var idValue any var createdAt time.Time - err = s.db.QueryRowContext(ctx, query, normalizedName, normalizedEmail, user.PasswordHash).Scan(&idValue, &createdAt) + var updatedAt time.Time + err = s.db.QueryRowContext(ctx, query, normalizedName, normalizedEmail, user.PasswordHash).Scan(&idValue, &createdAt, &updatedAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { return ErrUserNotFound @@ -132,6 +133,7 @@ func (s *postgresStore) CreateUser(ctx context.Context, user *User) error { user.Name = normalizedName user.Email = normalizedEmail user.CreatedAt = createdAt.UTC() + user.UpdatedAt = updatedAt.UTC() return nil } @@ -141,7 +143,8 @@ func (s *postgresStore) GetUserByEmail(ctx context.Context, email string) (*User return nil, ErrUserNotFound } - query := `SELECT uuid, username, email, password, coalesce(created_at, now()) + query := `SELECT uuid, username, email, password, mfa_totp_secret, coalesce(mfa_enabled, false), + mfa_secret_issued_at, mfa_confirmed_at, coalesce(created_at, now()), coalesce(updated_at, now()) FROM users WHERE lower(email) = $1 LIMIT 1` row := s.db.QueryRowContext(ctx, query, normalized) @@ -154,7 +157,8 @@ func (s *postgresStore) GetUserByName(ctx context.Context, name string) (*User, return nil, ErrUserNotFound } - query := `SELECT uuid, username, email, password, coalesce(created_at, now()) + query := `SELECT uuid, username, email, password, mfa_totp_secret, coalesce(mfa_enabled, false), + mfa_secret_issued_at, mfa_confirmed_at, coalesce(created_at, now()), coalesce(updated_at, now()) FROM users WHERE lower(username) = lower($1) LIMIT 1` row := s.db.QueryRowContext(ctx, query, normalized) @@ -162,7 +166,8 @@ func (s *postgresStore) GetUserByName(ctx context.Context, name string) (*User, } func (s *postgresStore) GetUserByID(ctx context.Context, id string) (*User, error) { - query := `SELECT uuid, username, email, password, coalesce(created_at, now()) + query := `SELECT uuid, username, email, password, mfa_totp_secret, coalesce(mfa_enabled, false), + mfa_secret_issued_at, mfa_confirmed_at, coalesce(created_at, now()), coalesce(updated_at, now()) FROM users WHERE uuid = $1` row := s.db.QueryRowContext(ctx, query, id) @@ -207,14 +212,19 @@ type rowScanner interface { func scanUser(row rowScanner) (*User, error) { var ( - idValue any - username sql.NullString - email sql.NullString - password sql.NullString - createdAt time.Time + idValue any + username sql.NullString + email sql.NullString + password sql.NullString + mfaSecret sql.NullString + mfaEnabled sql.NullBool + mfaSecretIssued sql.NullTime + mfaConfirmed sql.NullTime + createdAt time.Time + updatedAt time.Time ) - if err := row.Scan(&idValue, &username, &email, &password, &createdAt); err != nil { + if err := row.Scan(&idValue, &username, &email, &password, &mfaSecret, &mfaEnabled, &mfaSecretIssued, &mfaConfirmed, &createdAt, &updatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrUserNotFound } @@ -227,15 +237,90 @@ func scanUser(row rowScanner) (*User, error) { } user := &User{ - ID: identifier, - Name: strings.TrimSpace(username.String), - Email: strings.ToLower(strings.TrimSpace(email.String)), - PasswordHash: password.String, - CreatedAt: createdAt.UTC(), + ID: identifier, + Name: strings.TrimSpace(username.String), + Email: strings.ToLower(strings.TrimSpace(email.String)), + PasswordHash: password.String, + MFATOTPSecret: strings.TrimSpace(mfaSecret.String), + MFAEnabled: mfaEnabled.Bool, + MFASecretIssuedAt: toUTCTime(mfaSecretIssued), + MFAConfirmedAt: toUTCTime(mfaConfirmed), + CreatedAt: createdAt.UTC(), + UpdatedAt: updatedAt.UTC(), } return user, nil } +func (s *postgresStore) UpdateUser(ctx context.Context, user *User) error { + normalizedName := strings.TrimSpace(user.Name) + if normalizedName == "" { + return ErrInvalidName + } + + normalizedEmail := strings.ToLower(strings.TrimSpace(user.Email)) + var issuedAt any + if !user.MFASecretIssuedAt.IsZero() { + issuedAt = user.MFASecretIssuedAt.UTC() + } + var confirmedAt any + if !user.MFAConfirmedAt.IsZero() { + confirmedAt = user.MFAConfirmedAt.UTC() + } + + query := `UPDATE users + SET username = $1, + email = $2, + password = $3, + mfa_totp_secret = $4, + mfa_enabled = $5, + mfa_secret_issued_at = $6, + mfa_confirmed_at = $7, + updated_at = now() + WHERE uuid = $8 + RETURNING coalesce(created_at, now()), coalesce(updated_at, now())` + + var createdAt time.Time + var updatedAt time.Time + err := s.db.QueryRowContext(ctx, query, normalizedName, normalizedEmail, user.PasswordHash, nullForEmpty(user.MFATOTPSecret), user.MFAEnabled, issuedAt, confirmedAt, user.ID).Scan(&createdAt, &updatedAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrUserNotFound + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + if pgErr.Code == "23505" { + switch { + case strings.Contains(pgErr.ConstraintName, "email"): + return ErrEmailExists + case strings.Contains(pgErr.ConstraintName, "name") || strings.Contains(pgErr.ConstraintName, "username"): + return ErrNameExists + } + } + } + return err + } + + user.Name = normalizedName + user.Email = normalizedEmail + user.CreatedAt = createdAt.UTC() + user.UpdatedAt = updatedAt.UTC() + return nil +} + +func nullForEmpty(value string) any { + if strings.TrimSpace(value) == "" { + return nil + } + return value +} + +func toUTCTime(value sql.NullTime) time.Time { + if !value.Valid { + return time.Time{} + } + return value.Time.UTC() +} + func formatIdentifier(value any) (string, error) { switch v := value.(type) { case nil: diff --git a/account/internal/store/store.go b/account/internal/store/store.go index 2d926fc..02ed43f 100644 --- a/account/internal/store/store.go +++ b/account/internal/store/store.go @@ -12,11 +12,16 @@ import ( // User represents an account within the account service domain. type User struct { - ID string - Name string - Email string - PasswordHash string - CreatedAt time.Time + ID string + Name string + Email string + PasswordHash string + MFATOTPSecret string + MFAEnabled bool + MFASecretIssuedAt time.Time + MFAConfirmedAt time.Time + CreatedAt time.Time + UpdatedAt time.Time } // Store provides persistence operations for users. @@ -25,6 +30,7 @@ type Store interface { GetUserByEmail(ctx context.Context, email string) (*User, error) GetUserByID(ctx context.Context, id string) (*User, error) GetUserByName(ctx context.Context, name string) (*User, error) + UpdateUser(ctx context.Context, user *User) error } // Domain level errors returned by the store implementation. @@ -77,13 +83,22 @@ func (s *memoryStore) CreateUser(ctx context.Context, user *User) error { userCopy.ID = uuid.NewString() } if userCopy.CreatedAt.IsZero() { - userCopy.CreatedAt = time.Now().UTC() + now := time.Now().UTC() + userCopy.CreatedAt = now + if userCopy.UpdatedAt.IsZero() { + userCopy.UpdatedAt = now + } + } + if userCopy.UpdatedAt.IsZero() { + userCopy.UpdatedAt = time.Now().UTC() } userCopy.Email = loweredEmail userCopy.Name = normalizedName stored := userCopy s.byID[userCopy.ID] = &stored - s.byEmail[loweredEmail] = &stored + if loweredEmail != "" { + s.byEmail[loweredEmail] = &stored + } s.byName[strings.ToLower(normalizedName)] = &stored *user = stored return nil @@ -137,3 +152,73 @@ func (s *memoryStore) GetUserByName(ctx context.Context, name string) (*User, er clone := *user return &clone, nil } + +// UpdateUser replaces the persisted user representation in memory. +func (s *memoryStore) UpdateUser(ctx context.Context, user *User) error { + _ = ctx + s.mu.Lock() + defer s.mu.Unlock() + + existing, ok := s.byID[user.ID] + if !ok { + return ErrUserNotFound + } + + normalizedName := strings.TrimSpace(user.Name) + loweredEmail := strings.ToLower(strings.TrimSpace(user.Email)) + + if normalizedName == "" { + return ErrInvalidName + } + + // Re-index username if it changed. + oldNameKey := strings.ToLower(existing.Name) + newNameKey := strings.ToLower(normalizedName) + if oldNameKey != newNameKey { + if _, exists := s.byName[newNameKey]; exists { + return ErrNameExists + } + delete(s.byName, oldNameKey) + } + + // Re-index email if it changed. + oldEmailKey := strings.ToLower(existing.Email) + if oldEmailKey != loweredEmail { + if loweredEmail != "" { + if _, exists := s.byEmail[loweredEmail]; exists { + return ErrEmailExists + } + } + if oldEmailKey != "" { + delete(s.byEmail, oldEmailKey) + } + } + + updated := *existing + updated.Name = normalizedName + updated.Email = loweredEmail + updated.PasswordHash = user.PasswordHash + updated.MFATOTPSecret = user.MFATOTPSecret + updated.MFAEnabled = user.MFAEnabled + updated.MFASecretIssuedAt = user.MFASecretIssuedAt + updated.MFAConfirmedAt = user.MFAConfirmedAt + if user.CreatedAt.IsZero() { + updated.CreatedAt = existing.CreatedAt + } else { + updated.CreatedAt = user.CreatedAt + } + if user.UpdatedAt.IsZero() { + updated.UpdatedAt = time.Now().UTC() + } else { + updated.UpdatedAt = user.UpdatedAt + } + + s.byID[user.ID] = &updated + s.byName[newNameKey] = &updated + if loweredEmail != "" { + s.byEmail[loweredEmail] = &updated + } + + *user = updated + return nil +} diff --git a/account/sql/20251002-add-mfa-columns.sql b/account/sql/20251002-add-mfa-columns.sql new file mode 100644 index 0000000..615b56c --- /dev/null +++ b/account/sql/20251002-add-mfa-columns.sql @@ -0,0 +1,6 @@ +ALTER TABLE users + ADD COLUMN IF NOT EXISTS mfa_totp_secret TEXT, + ADD COLUMN IF NOT EXISTS mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE, + ADD COLUMN IF NOT EXISTS mfa_secret_issued_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS mfa_confirmed_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT now(); diff --git a/account/sql/schema.sql b/account/sql/schema.sql index 5fe3970..c9cdf02 100644 --- a/account/sql/schema.sql +++ b/account/sql/schema.sql @@ -8,7 +8,12 @@ CREATE TABLE IF NOT EXISTS users ( username TEXT NOT NULL UNIQUE, password TEXT NOT NULL, email TEXT, - created_at TIMESTAMPTZ DEFAULT now() + mfa_totp_secret TEXT, + mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE, + mfa_secret_issued_at TIMESTAMPTZ, + mfa_confirmed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() ); CREATE TABLE IF NOT EXISTS identities ( diff --git a/docs/account-service-configuration.md b/docs/account-service-configuration.md index 8043d0a..4d01249 100644 --- a/docs/account-service-configuration.md +++ b/docs/account-service-configuration.md @@ -4,121 +4,99 @@ ## 1. 配置加载策略 -当前服务入口(`account/cmd/accountsvc/main.go`)直接创建 Gin 引擎并注册路由,尚未接入统一的配置加载逻辑。【F:account/cmd/accountsvc/main.go†L1-L12】 +账号服务入口(`account/cmd/accountsvc/main.go`)会调用 `config.Load` 读取 YAML 配置,并允许通过命令行参数覆盖默认路径。当未提供配置文件时,服务会以零值启动,此时可结合环境变量填充关键字段。 -为满足生产需求,建议按以下优先级加载配置: +当前推荐的覆盖顺序如下: -1. **命令行参数**:覆盖性最高,用于临时指定端口或配置文件路径。 -2. **环境变量**:适用于容器化部署,通过 `ACCOUNT_*` 前缀管理。 -3. **配置文件**:默认从 `config/account.yaml` 或 `config/account.json` 中读取。 -4. **内置默认值**:在 `account/config/config.go` 中定义结构体并赋予默认值,保证在缺省配置下仍可运行。【F:account/config/config.go†L1-L5】 +1. **命令行参数**:用于指定配置文件路径或运行模式。 +2. **配置文件**:默认从 `account/config/account.yaml` 读取,适合提交到仓库或挂载到容器内。 +3. **代码默认值**:`config.Config` 结构体中的零值,保证最小可运行。 -## 2. 建议的配置结构 +> 注:目前服务尚未内置环境变量映射逻辑,如需按环境注入配置,可在部署流程中提前生成 YAML 文件或扩展 `config.Load`。 -未来扩展时,可按以下结构扩充 `Config`: +## 2. 配置字段参考 + +`account/config/config.go` 定义了配置结构,主要包含以下几个部分: ```yaml +log: + level: info # 可选:debug、info、warn、error + server: - addr: ":8080" - readTimeout: 10s - writeTimeout: 10s - idleTimeout: 60s + addr: ":8080" # 监听地址 + readTimeout: 15s # 读取超时 + writeTimeout: 15s # 写入超时 + tls: # 启用 HTTPS 时的证书配置 + certFile: "/etc/ssl/certs/account.pem" + keyFile: "/etc/ssl/private/account.key" + clientCAFile: "" # (可选)双向 TLS CA + redirectHttp: false # 当启用 TLS 时是否同时监听 HTTP 做 301 重定向 store: - driver: "memory" # 可选:memory、postgres、mysql - dsn: "postgres://user:pass@host:5432/account?sslmode=disable" - maxOpenConns: 20 - maxIdleConns: 5 - connMaxLifetime: 30m - -session: - ttl: 24h - cache: "memory" # 可选:memory、redis - redis: - addr: "redis:6379" - password: "" - db: 0 - -authProviders: - - name: "oidc" - issuer: "https://idp.example.com" - clientID: "xcontrol" - clientSecret: "${OIDC_CLIENT_SECRET}" - - name: "ldap" - addr: "ldap://ldap.example.com:389" - baseDN: "dc=example,dc=com" - bindDN: "cn=admin,dc=example,dc=com" - bindPassword: "${LDAP_BIND_PASSWORD}" -``` - -## 3. 环境变量示例 - -| 变量名 | 说明 | 示例 | -| ------ | ---- | ---- | -| `ACCOUNT_SERVER_ADDR` | 服务监听地址 | `:8080` | -| `ACCOUNT_STORE_DRIVER` | 存储驱动类型 | `postgres` | -| `ACCOUNT_STORE_DSN` | 存储连接串 | `postgres://user:pass@db:5432/account` | -| `ACCOUNT_SESSION_TTL` | 会话有效期(秒或 Go duration) | `24h` | -| `ACCOUNT_REDIS_ADDR` | Redis 地址(当 cache=redis 时使用) | `redis:6379` | -| `ACCOUNT_LOG_LEVEL` | 日志级别 | `info` | - -在容器或 CI/CD 中,可借助 Secret/ConfigMap 注入敏感值,避免直接写入镜像。 - -## 4. 配置示例 - -### 4.1 开发环境 - -```yaml -server: - addr: ":8080" - -store: - driver: "memory" - -session: - ttl: 24h - cache: "memory" -``` - -### 4.2 测试/预生产环境 - -```yaml -server: - addr: ":8080" - readTimeout: 15s - writeTimeout: 15s - -store: - driver: "postgres" - dsn: "postgres://acct:acctpass@postgres:5432/account?sslmode=disable" + driver: "postgres" # 可选:memory、postgres + dsn: "postgres://user:pass@db:5432/account?sslmode=disable" maxOpenConns: 30 maxIdleConns: 10 session: - ttl: 24h - cache: "redis" - redis: - addr: "redis:6379" - password: "${REDIS_PASSWORD}" - -authProviders: - - name: "oidc" - issuer: "https://idp-pre.example.com" - clientID: "xcontrol" - clientSecret: "${OIDC_SECRET}" + ttl: 24h # 登录会话有效期 ``` -## 5. 配置校验与回滚 +**TLS 提示**:当 `certFile` 和 `keyFile` 都非空时,`accountsvc` 会调用 `ListenAndServeTLS` 启动 HTTPS。如果同时希望保留 80 端口,可将 `redirectHttp` 置为 `true`,服务会开启一个额外的明文监听,将请求 301 重定向到 HTTPS。 -- 在服务启动时验证必需字段是否填写,例如当 `driver=postgres` 时必须提供 `dsn`。 -- 提供配置热加载或版本化策略,例如通过 GitOps 将配置存储于仓库,变更可回滚。 -- 通过单元测试验证不同配置组合的解析结果,确保新字段向下兼容。 +**MFA 相关接口**:账号服务在 `/api/auth/mfa/*` 下提供 MFA 绑定与验证接口,默认无需额外配置即可使用,但生产环境建议将 `server.tls` 打开,确保 MFA 秘钥与 TOTP 码在传输过程中被加密。 -## 6. 与代码协同 +## 3. 配置示例 -- 在 `account/api` 中读取 `session.ttl` 替换硬编码的 `24 * time.Hour`,实现配置化。【F:account/api/api.go†L18-L171】 -- 在 `account/internal/store` 中根据 `store.driver` 实例化不同实现,实现从内存到数据库的无缝切换。【F:account/internal/store/store.go†L31-L109】 -- 在 `account/internal/auth` 中根据 `authProviders` 列表注册外部认证方式,实现多身份源并行校验。【F:account/internal/auth/auth.go†L1-L6】 +### 3.1 开发环境(HTTP + 内存存储) ---- -随着服务演进,应持续完善 `Config` 结构与加载逻辑,并在此文档中同步更新字段说明。 +```yaml +log: + level: debug +server: + addr: ":8080" + readTimeout: 0s + writeTimeout: 0s +store: + driver: "memory" +session: + ttl: 8h +``` + +### 3.2 生产环境(PostgreSQL + HTTPS + MFA) + +```yaml +log: + level: info +server: + addr: ":8443" + readTimeout: 15s + writeTimeout: 15s + tls: + certFile: "/etc/ssl/certs/account.pem" + keyFile: "/etc/ssl/private/account.key" + redirectHttp: true +store: + driver: "postgres" + dsn: "postgres://account:strongpass@db:5432/account?sslmode=require" + maxOpenConns: 50 + maxIdleConns: 10 +session: + ttl: 24h +``` + +在生产环境中,建议通过 Kubernetes Secret、Vault 等方式挂载证书文件,并使用 `redirectHttp` 确保历史链接能够自动切换到 HTTPS。 + +## 4. 配置校验与回滚 + +- 启动时若启用 PostgreSQL,请确保 `dsn` 可用,否则服务会在初始化阶段返回错误。 +- TLS 文件路径错误会导致启动失败,建议在 CI/CD 中加入探针验证。 +- 通过 Git 管理配置文件,配合版本标签可实现快速回滚。 + +## 5. 与其他模块的协同 + +- 登录会话 TTL 会同步影响 `/api/auth/login`、`/api/auth/session` 等接口返回的 cookie 过期时间。 +- 新增的 MFA 接口(`/api/auth/mfa/totp/provision`、`/api/auth/mfa/totp/verify`、`/api/auth/mfa/status`)在 HTTPS 环境下可与前端 MFA 向导配合使用,保证首次登录后必须完成绑定。 +- 如果部署了前端 Next.js 应用,请确保其 `.env` 中的 `ACCOUNT_API_BASE` 指向启用了 TLS 的账号服务地址。 + +随着服务演进,请在更新配置结构或新字段时同步维护本文档。 diff --git a/docs/account-service-deployment.md b/docs/account-service-deployment.md index 4c36ec7..7022bea 100644 --- a/docs/account-service-deployment.md +++ b/docs/account-service-deployment.md @@ -4,9 +4,9 @@ ## 1. 运行时依赖 -- Go 1.22 及以上版本,用于编译服务。 -- (可选)PostgreSQL、Redis 等外部组件,当前实现默认使用内存存储,可在后续扩展中替换。 -- Make 与 Git(可选),用于辅助构建与版本管理。 +- Go 1.22 及以上版本,用于编译和运行服务。 +- PostgreSQL(推荐)或内存存储:MFA 状态、TOTP 秘钥等信息会持久化在用户表中,生产环境请使用数据库。 +- (可选)反向代理或负载均衡器,用于在 TLS 终止后分发流量。 ## 2. 本地开发部署 @@ -16,75 +16,150 @@ cd XControl ``` -2. **启动服务** +2. **准备配置** + 使用仓库提供的 `account/config/account.yaml`,或根据需要拷贝一份修改端口、数据库连接等字段。 + +3. **启动服务(HTTP)** ```bash - go run ./account/cmd/accountsvc + go run ./account/cmd/accountsvc --config account/config/account.yaml ``` 默认监听 `:8080`,可通过 `curl http://127.0.0.1:8080/healthz` 检查服务状态。 -3. **交互测试** - - 注册账号: - ```bash - curl -X POST http://127.0.0.1:8080/v1/register \ - -H 'Content-Type: application/json' \ - -d '{"name":"demo","email":"demo@example.com","password":"Secret123"}' - ``` - - 登录获取 token: - ```bash - curl -X POST http://127.0.0.1:8080/v1/login \ - -H 'Content-Type: application/json' \ - -d '{"email":"demo@example.com","password":"Secret123"}' - ``` +4. **交互测试:注册、绑定 MFA 与登录** -## 3. Docker 镜像部署 + ```bash + # 注册账号 + curl -X POST http://127.0.0.1:8080/api/auth/register \ + -H 'Content-Type: application/json' \ + -d '{"name":"demo","email":"demo@example.com","password":"Secret123"}' -1. **构建镜像(示例 Dockerfile 需后续补充)** + # 初次登录以获取 MFA 挑战 token(返回 401,并携带 mfaToken) + curl -X POST http://127.0.0.1:8080/api/auth/login \ + -H 'Content-Type: application/json' \ + -d '{"identifier":"demo@example.com","password":"Secret123"}' + + # 请求 TOTP 秘钥(返回二维码和 Base32 密钥) + curl -X POST http://127.0.0.1:8080/api/auth/mfa/totp/provision \ + -H 'Content-Type: application/json' \ + -d '{"token":""}' + + # 使用 oathtool 或 Google Authenticator 生成一次性验证码 + oathtool --totp -b + + # 验证并启用 MFA(首次会返回会话 token) + curl -X POST http://127.0.0.1:8080/api/auth/mfa/totp/verify \ + -H 'Content-Type: application/json' \ + -d '{"token":"","code":"123456"}' + + # 带口令 + TOTP 登录 + curl -X POST http://127.0.0.1:8080/api/auth/login \ + -H 'Content-Type: application/json' \ + -c cookies.txt \ + -d '{"identifier":"demo@example.com","password":"Secret123","totpCode":"123456"}' + + # 或使用邮箱 + TOTP 极简模式 + curl -X POST http://127.0.0.1:8080/api/auth/login \ + -H 'Content-Type: application/json' \ + -c cookies.txt \ + -d '{"identifier":"demo@example.com","totpCode":"123456"}' + + # 查看当前会话 + curl -b cookies.txt http://127.0.0.1:8080/api/auth/session + ``` + + 若需要重新绑定 MFA,可再次发起登录以获取新的 `mfaToken`,然后重复 `provision` → `verify` 流程;如需彻底重置,可在数据库中清理相关 MFA 字段后重新执行上述步骤。 + +## 3. 启用 HTTPS/TLS + +账号服务内置 TLS 支持,只要在配置文件中提供证书即可: + +```yaml +server: + addr: ":8443" + tls: + certFile: "/etc/ssl/certs/account.pem" + keyFile: "/etc/ssl/private/account.key" + redirectHttp: true +``` + +启动命令保持不变: + +```bash +go run ./account/cmd/accountsvc --config /path/to/secure-account.yaml +``` + +常见验证步骤: + +```bash +# 生成测试证书(示例) +openssl req -x509 -nodes -days 365 -newkey rsa:2048 \ + -keyout account.key -out account.crt \ + -subj "/CN=localhost" + +# 更新配置后启动服务 +ACCOUNT_CONFIG=/tmp/account-secure.yaml go run ./account/cmd/accountsvc --config $ACCOUNT_CONFIG + +# 使用 curl 验证 HTTPS(开发环境可加 -k 跳过校验) +curl -k https://127.0.0.1:8443/healthz +``` + +当 `redirectHttp` 为 `true` 时,服务会自动监听对应的 HTTP 端口(通常是 80),并将请求 301 重定向到 HTTPS,方便旧链接或未更新的客户端。 + +## 4. Docker 部署 + +1. **构建镜像(示例)** ```bash docker build -t xcontrol/account-service -f deploy/account/Dockerfile . ``` -2. **运行容器** +2. **运行容器(挂载配置与证书)** ```bash docker run -d \ --name account-service \ + -p 8443:8443 \ -p 8080:8080 \ - xcontrol/account-service + -v $(pwd)/account.yaml:/etc/xcontrol/account.yaml \ + -v $(pwd)/certs:/etc/ssl/xcontrol \ + xcontrol/account-service \ + --config /etc/xcontrol/account.yaml ``` + 如果未启用 `redirectHttp`,可省略 `-p 8080:8080`。 + 3. **查看日志** ```bash docker logs -f account-service ``` -若需与 PostgreSQL、Redis 集成,可通过环境变量或配置文件挂载方式将连接信息传入容器。 +确保容器内路径与配置文件中的 `certFile`/`keyFile` 一致,必要时可通过 Docker Secret 或 Kubernetes Secret 注入敏感文件。 -## 4. Kubernetes/Helm 部署(建议) +## 5. Kubernetes/Helm 部署 - 在 `deploy/account` 目录中维护 Helm Chart 或 Kustomize 模板,定义 Service、Deployment、ConfigMap 等资源。 - 关键参数: - 副本数 `replicaCount`,生产环境建议至少 2 个副本以实现高可用。 - 探针:配置 `livenessProbe` 与 `readinessProbe` 指向 `/healthz`。 - - 资源限制:根据用户规模设置 CPU/内存请求与限制。 - - Secret 管理:通过 Kubernetes Secret 注入数据库、缓存或第三方身份源的凭据。 + - 证书管理:使用 Secret 存储 TLS 证书与私钥,挂载到容器后与配置文件对应。 + - 数据库凭证:同样通过 Secret 注入 `ACCOUNT_STORE_DSN` 或配置文件。 -## 5. 灰度与回滚策略 +## 6. 灰度与回滚策略 -- 采用 RollingUpdate 策略滚动发布,确保新旧副本并行运行。 +- 建议采用 RollingUpdate 策略滚动发布,确保新旧副本并行运行。 - 配置 `maxUnavailable=0`、`maxSurge=1`(或按需调整),避免服务中断。 - 通过标记镜像版本或 Git Commit Hash 追踪上线版本,出问题时可快速回滚至上一版本。 -## 6. 监控与日志 +## 7. 监控与日志 - 日志:默认输出到标准输出,可挂载至日志采集系统(如 Loki、ELK)。 -- 指标:后续可集成 Prometheus 指标暴露,便于观察登录成功率、请求延迟、会话数量等关键指标。 -- 告警:基于探针失败、登录失败率飙升、token 生成错误等指标配置告警。 +- 指标:可在后续版本中集成 Prometheus 指标,关注登录成功率、MFA 启用率等核心指标。 +- 告警:基于探针失败、登录失败率飙升、TOTP 验证异常等指标配置告警策略。 -## 7. 安全加固建议 +## 8. 安全加固建议 - 在容器或集群层启用网络策略,仅开放必要端口。 -- 配置 HTTPS/TLS 网关,保证传输安全。 -- 对外部依赖(数据库、缓存)使用专用账号与最小权限策略。 -- 部署前进行漏洞扫描与依赖安全检查。 +- 对外提供服务时务必启用 HTTPS,保护登录口令与 TOTP 码。 +- 对数据库、证书等敏感资源使用最小权限原则,并定期轮换。 +- 定期回顾 `account/api/api_test.go` 中的场景测试,确保关键登录链路持续可用。 --- -以上步骤仅覆盖核心流程,实际生产部署需根据企业环境补充网络、合规等细节。 +以上步骤覆盖从开发到生产的核心流程,可根据企业环境补充额外的安全、审计或合规要求。 diff --git a/docs/api-endpoints.md b/docs/api-endpoints.md index a5d7c59..f29deb0 100644 --- a/docs/api-endpoints.md +++ b/docs/api-endpoints.md @@ -1,6 +1,81 @@ # API Endpoints -This document describes the HTTP endpoints provided by the XControl server. Each entry lists the request method and path, required parameters, and a sample curl command for verification. +This document describes the HTTP endpoints provided by the XControl platform. Each entry lists the request method and path, required parameters, and a sample curl command for verification. + +## Account Service(MFA/TLS 支持) + +The standalone account service exposes user registration, MFA provisioning, and login endpoints on its configured host (default `http://localhost:8080`). + +### POST /api/auth/register +- **Description:** Create a new local user with email/password credentials. +- **Body Parameters (JSON):** + - `name` – Display name. + - `email` – Unique email address. + - `password` – Password with at least 8 characters. +- **Test:** + ```bash + curl -X POST http://localhost:8080/api/auth/register \ + -H "Content-Type: application/json" \ + -d '{"name":"demo","email":"demo@example.com","password":"Secret123"}' + ``` + +### POST /api/auth/mfa/totp/provision +- **Description:** Issue a temporary TOTP secret (and QR code) for Google Authenticator binding. Requires an MFA challenge token returned by the login flow. +- **Body Parameters (JSON):** + - `token` – MFA challenge token obtained from a prior `/api/auth/login` attempt. + - `issuer` – Optional override for the TOTP issuer label. + - `account` – Optional override for the account label in authenticator apps. +- **Test:** + ```bash + curl -X POST http://localhost:8080/api/auth/mfa/totp/provision \ + -H "Content-Type: application/json" \ + -d '{"token":""}' + ``` + +### POST /api/auth/mfa/totp/verify +- **Description:** Confirm the generated one-time passcode and activate MFA for the user. +- **Body Parameters (JSON):** + - `token` – MFA challenge token used during provisioning. + - `code` – 6-digit TOTP from Google Authenticator/oathtool. +- **Test:** + ```bash + curl -X POST http://localhost:8080/api/auth/mfa/totp/verify \ + -H "Content-Type: application/json" \ + -d '{"token":"","code":"123456"}' + ``` + +### POST /api/auth/login +- **Description:** Issue a session cookie after validating credentials and MFA. The first request after registration returns `401 mfa_setup_required` with an `mfaToken` used for provisioning. Once MFA is enabled, supports both password+TOTP and email+TOTP-only flows. +- **Body Parameters (JSON):** + - `identifier` – Email or username. + - `password` – Optional when performing email+TOTP-only login. + - `totpCode` – Required once MFA is enabled. +- **Test:** + ```bash + curl -X POST http://localhost:8080/api/auth/login \ + -H "Content-Type: application/json" \ + -c cookies.txt \ + -d '{"identifier":"demo@example.com","password":"Secret123","totpCode":"123456"}' + ``` + +### GET /api/auth/mfa/status +- **Description:** Inspect MFA status for a user using either a session token or the pending `mfaToken`. +- **Parameters:** + - Query `token` or header `X-MFA-Token` when checking a pending MFA challenge. +- **Test:** + ```bash + curl "http://localhost:8080/api/auth/mfa/status?token=" + ``` + +### GET /api/auth/session +- **Description:** Return sanitized user information for the active session, including MFA status. +- **Headers:** `Cookie` header with `account_session` value. +- **Test:** + ```bash + curl -b cookies.txt http://localhost:8080/api/auth/session + ``` + +> **TLS note:** When `accountsvc` is started with certificates, replace `http://` with `https://` and add `-k` for curl if using self-signed certificates during development. ## GET /api/users - **Description:** Return all users. diff --git a/go.mod b/go.mod index f4b2ca5..2f3be42 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.5 github.com/pgvector/pgvector-go v0.3.0 + github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.12.0 github.com/spf13/cobra v1.9.1 github.com/yuin/goldmark v1.7.13 @@ -23,6 +24,7 @@ require ( dario.cat/mergo v1.0.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.1.6 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index 4d6b9ad..110d8db 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -129,6 +131,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/redis/go-redis/v9 v9.12.0 h1:XlVPGlflh4nxfhsNXPA8Qp6EmEfTo0rp8oaBzPipXnU= github.com/redis/go-redis/v9 v9.12.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= diff --git a/ui/homepage/app/api/auth/login/route.ts b/ui/homepage/app/api/auth/login/route.ts index 6b2ac47..e687040 100644 --- a/ui/homepage/app/api/auth/login/route.ts +++ b/ui/homepage/app/api/auth/login/route.ts @@ -4,19 +4,27 @@ import { getAccountServiceBaseUrl } from '@lib/serviceConfig' const ACCOUNT_SERVICE_URL = getAccountServiceBaseUrl() const SESSION_COOKIE_NAME = 'account_session' +const MFA_COOKIE_NAME = 'account_mfa_token' -async function authenticateWithAccountService(username: string, password: string) { +type AccountLoginResponse = { + token?: string + error?: string + mfaToken?: string + message?: string +} + +async function authenticateWithAccountService(payload: Record) { try { const response = await fetch(`${ACCOUNT_SERVICE_URL}/api/auth/login`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify({ username, password }), + body: JSON.stringify(payload), cache: 'no-store', }) - const data = await response.json().catch(() => ({})) + const data = (await response.json().catch(() => ({}))) as AccountLoginResponse return { response, data } } catch (error) { console.error('Login request failed', error) @@ -43,17 +51,28 @@ export async function POST(request: NextRequest) { const { credentials, remember } = await extractCredentials(request) - if (!credentials.username || !credentials.password) { + if (!credentials.identifier) { return handleErrorResponse(request, 'missing_credentials') } - const { response, data } = await authenticateWithAccountService( - credentials.username, - credentials.password, - ) + const { response, data } = await authenticateWithAccountService(credentials) if (!response || !response.ok || !data?.token) { const message = typeof data?.error === 'string' ? data.error : 'invalid_credentials' - return handleErrorResponse(request, message) + const errorResponse = handleErrorResponse(request, message, data) + if (message === 'mfa_setup_required' && typeof data?.mfaToken === 'string') { + errorResponse.cookies.set({ + name: MFA_COOKIE_NAME, + value: data.mfaToken, + httpOnly: true, + sameSite: 'lax', + secure: process.env.NODE_ENV === 'production', + path: '/', + maxAge: 60 * 10, + }) + } else { + errorResponse.cookies.set({ name: MFA_COOKIE_NAME, value: '', maxAge: 0, path: '/' }) + } + return errorResponse } const cookieMaxAge = remember ? 60 * 60 * 24 * 30 : 60 * 60 * 24 @@ -75,15 +94,11 @@ export async function POST(request: NextRequest) { maxAge: cookieMaxAge, path: '/', }) + successResponse.cookies.set({ name: MFA_COOKIE_NAME, value: '', maxAge: 0, path: '/' }) return successResponse } -type CredentialPayload = { - username: string - password: string -} - function prefersJson(request: NextRequest) { const accept = request.headers.get('accept')?.toLowerCase() ?? '' const contentType = request.headers.get('content-type')?.toLowerCase() ?? '' @@ -99,34 +114,47 @@ async function extractCredentials(request: NextRequest) { } return { credentials: { - username: String(body?.username ?? '').trim(), + identifier: normalizeIdentifier(body), password: String(body?.password ?? ''), + totpCode: String(body?.totpCode ?? '').trim(), }, remember: Boolean(body?.remember), } } const formData = await request.formData() - const username = String(formData.get('username') ?? '').trim() + const identifier = normalizeIdentifier({ + identifier: formData.get('identifier'), + username: formData.get('username'), + email: formData.get('email'), + }) const password = String(formData.get('password') ?? '') + const totpCode = String(formData.get('totpCode') ?? '').trim() const remember = formData.get('remember') === 'on' return { - credentials: { username, password }, + credentials: { identifier, password, totpCode }, remember, } } -function handleErrorResponse(request: NextRequest, errorCode: string) { +function handleErrorResponse( + request: NextRequest, + errorCode: string, + data?: AccountLoginResponse, +) { if (prefersJson(request)) { const statusMap: Record = { user_not_found: 404, invalid_credentials: 401, missing_credentials: 400, credentials_in_query: 400, + mfa_setup_required: 401, + mfa_code_required: 400, } return NextResponse.json( { error: errorCode, + mfaToken: data?.mfaToken, }, { status: statusMap[errorCode] ?? 400 }, ) @@ -136,3 +164,20 @@ function handleErrorResponse(request: NextRequest, errorCode: string) { redirectURL.searchParams.set('error', errorCode) return NextResponse.redirect(redirectURL, { status: 303 }) } + +type CredentialPayload = { + identifier?: string + username?: string + email?: string + password?: string + totpCode?: string + remember?: boolean +} + +function normalizeIdentifier(payload: Partial) { + const candidate = + String(payload?.identifier ?? '').trim() || + String(payload?.username ?? '').trim() || + String(payload?.email ?? '').trim() + return candidate +} diff --git a/ui/homepage/app/api/auth/mfa/status/route.ts b/ui/homepage/app/api/auth/mfa/status/route.ts new file mode 100644 index 0000000..51f907f --- /dev/null +++ b/ui/homepage/app/api/auth/mfa/status/route.ts @@ -0,0 +1,38 @@ +import { cookies } from 'next/headers' +import { NextRequest, NextResponse } from 'next/server' + +import { getAccountServiceBaseUrl } from '@lib/serviceConfig' + +const ACCOUNT_SERVICE_URL = getAccountServiceBaseUrl() +const SESSION_COOKIE_NAME = 'account_session' +const MFA_COOKIE_NAME = 'account_mfa_token' + +export async function GET(request: NextRequest) { + const cookieStore = cookies() + const sessionToken = cookieStore.get(SESSION_COOKIE_NAME)?.value ?? '' + const storedMfaToken = cookieStore.get(MFA_COOKIE_NAME)?.value ?? '' + + const url = new URL(request.url) + const queryToken = String(url.searchParams.get('token') ?? '').trim() + const token = queryToken || storedMfaToken + + const headers: Record = { + Accept: 'application/json', + } + if (sessionToken) { + headers.Authorization = `Bearer ${sessionToken}` + } + + const endpoint = token + ? `${ACCOUNT_SERVICE_URL}/api/auth/mfa/status?token=${encodeURIComponent(token)}` + : `${ACCOUNT_SERVICE_URL}/api/auth/mfa/status` + + const response = await fetch(endpoint, { + method: 'GET', + headers, + cache: 'no-store', + }) + + const payload = await response.json().catch(() => ({})) + return NextResponse.json(payload, { status: response.status }) +} diff --git a/ui/homepage/app/api/auth/mfa/totp/provision/route.ts b/ui/homepage/app/api/auth/mfa/totp/provision/route.ts new file mode 100644 index 0000000..4b07c89 --- /dev/null +++ b/ui/homepage/app/api/auth/mfa/totp/provision/route.ts @@ -0,0 +1,51 @@ +import { cookies } from 'next/headers' +import { NextRequest, NextResponse } from 'next/server' + +import { getAccountServiceBaseUrl } from '@lib/serviceConfig' + +const ACCOUNT_SERVICE_URL = getAccountServiceBaseUrl() +const MFA_COOKIE_NAME = 'account_mfa_token' + +export async function POST(request: NextRequest) { + const cookieStore = cookies() + const currentToken = cookieStore.get(MFA_COOKIE_NAME)?.value ?? '' + const body = (await request.json().catch(() => ({}))) as { + token?: string + issuer?: string + account?: string + } + + const token = String(body?.token ?? currentToken ?? '').trim() + if (!token) { + return NextResponse.json({ error: 'mfa_token_required' }, { status: 400 }) + } + + const response = await fetch(`${ACCOUNT_SERVICE_URL}/api/auth/mfa/totp/provision`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ token, issuer: body?.issuer, account: body?.account }), + cache: 'no-store', + }) + + const payload = await response.json().catch(() => ({})) + if (!response.ok) { + return NextResponse.json(payload, { status: response.status }) + } + + const res = NextResponse.json(payload) + if (!currentToken || currentToken !== token) { + res.cookies.set({ + name: MFA_COOKIE_NAME, + value: token, + httpOnly: true, + sameSite: 'lax', + secure: process.env.NODE_ENV === 'production', + path: '/', + maxAge: 60 * 10, + }) + } + + return res +} diff --git a/ui/homepage/app/api/auth/mfa/totp/verify/route.ts b/ui/homepage/app/api/auth/mfa/totp/verify/route.ts new file mode 100644 index 0000000..eaa2466 --- /dev/null +++ b/ui/homepage/app/api/auth/mfa/totp/verify/route.ts @@ -0,0 +1,60 @@ +import { cookies } from 'next/headers' +import { NextRequest, NextResponse } from 'next/server' + +import { getAccountServiceBaseUrl } from '@lib/serviceConfig' + +const ACCOUNT_SERVICE_URL = getAccountServiceBaseUrl() +const SESSION_COOKIE_NAME = 'account_session' +const MFA_COOKIE_NAME = 'account_mfa_token' + +export async function POST(request: NextRequest) { + const cookieStore = cookies() + const currentToken = cookieStore.get(MFA_COOKIE_NAME)?.value ?? '' + const body = (await request.json().catch(() => ({}))) as { + token?: string + code?: string + } + + const token = String(body?.token ?? currentToken ?? '').trim() + const code = String(body?.code ?? '').trim() + + if (!token) { + return NextResponse.json({ error: 'mfa_token_required' }, { status: 400 }) + } + if (!code) { + return NextResponse.json({ error: 'mfa_code_required' }, { status: 400 }) + } + + const response = await fetch(`${ACCOUNT_SERVICE_URL}/api/auth/mfa/totp/verify`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ token, code }), + cache: 'no-store', + }) + + const payload = await response.json().catch(() => ({})) + if (!response.ok || !payload?.token) { + return NextResponse.json(payload, { status: response.status }) + } + + const res = NextResponse.json(payload) + const expiresAt = typeof payload?.expiresAt === 'string' ? Date.parse(payload.expiresAt) : NaN + const ttl = Number.isFinite(expiresAt) + ? Math.max(60, Math.floor((expiresAt - Date.now()) / 1000)) + : 60 * 60 * 24 + + res.cookies.set({ + name: SESSION_COOKIE_NAME, + value: String(payload.token), + httpOnly: true, + sameSite: 'lax', + secure: process.env.NODE_ENV === 'production', + path: '/', + maxAge: ttl, + }) + res.cookies.set({ name: MFA_COOKIE_NAME, value: '', maxAge: 0, path: '/' }) + + return res +} diff --git a/ui/homepage/app/api/auth/session/route.ts b/ui/homepage/app/api/auth/session/route.ts index 71302ae..6353919 100644 --- a/ui/homepage/app/api/auth/session/route.ts +++ b/ui/homepage/app/api/auth/session/route.ts @@ -12,6 +12,13 @@ type AccountUser = { name?: string username?: string email: string + mfaEnabled?: boolean + mfa?: { + totpEnabled?: boolean + totpPending?: boolean + totpSecretIssuedAt?: string + totpConfirmedAt?: string + } } async function fetchSession(token: string) { diff --git a/ui/homepage/app/login/LoginForm.tsx b/ui/homepage/app/login/LoginForm.tsx index 626cb47..a44da39 100644 --- a/ui/homepage/app/login/LoginForm.tsx +++ b/ui/homepage/app/login/LoginForm.tsx @@ -15,8 +15,10 @@ export function LoginForm() { const authCopy = translations[language].auth.login const navCopy = translations[language].nav.account const { user, login } = useUser() - const [username, setUsername] = useState('') + const [identifier, setIdentifier] = useState('') const [password, setPassword] = useState('') + const [totpCode, setTotpCode] = useState('') + const [loginMode, setLoginMode] = useState<'password_totp' | 'email_totp'>('password_totp') const [remember, setRemember] = useState(false) const [error, setError] = useState(null) const [isSubmitting, setIsSubmitting] = useState(false) @@ -24,15 +26,19 @@ export function LoginForm() { const handleSubmit = async (event: FormEvent) => { event.preventDefault() - const trimmedUsername = username.trim() - if (!trimmedUsername) { + const trimmedIdentifier = identifier.trim() + if (!trimmedIdentifier) { setError(pageCopy.missingUsername) return } - if (!password) { + if (loginMode === 'password_totp' && !password) { setError(pageCopy.missingPassword) return } + if (!totpCode.trim()) { + setError(pageCopy.missingTotp ?? authCopy.alerts.mfa.missing) + return + } setError(null) setIsSubmitting(true) @@ -43,12 +49,17 @@ export function LoginForm() { 'Content-Type': 'application/json', Accept: 'application/json', }, - body: JSON.stringify({ username: trimmedUsername, password, remember }), + body: JSON.stringify({ + identifier: trimmedIdentifier, + password: loginMode === 'password_totp' ? password : undefined, + totpCode: totpCode.trim(), + remember, + }), credentials: 'include', }) if (!response.ok) { - const payload = (await response.json().catch(() => ({}))) as { error?: string } + const payload = (await response.json().catch(() => ({}))) as { error?: string; mfaToken?: string } const messageKey = payload.error ?? 'generic_error' switch (messageKey) { case 'missing_credentials': @@ -60,6 +71,19 @@ export function LoginForm() { case 'user_not_found': setError(pageCopy.userNotFound) break + case 'mfa_code_required': + setError(authCopy.alerts.mfa.missing) + break + case 'invalid_mfa_code': + setError(authCopy.alerts.mfa.invalid) + break + case 'mfa_setup_required': + if (typeof window !== 'undefined' && typeof payload.mfaToken === 'string') { + sessionStorage.setItem('account_mfa_token', payload.mfaToken) + } + router.replace('/panel/account?setupMfa=1') + router.refresh() + return default: setError(pageCopy.genericError) break @@ -116,37 +140,80 @@ export function LoginForm() { {!user ? (
-
-
-
-