From bc2562b8772078e5a2100c78ce49af8140f2f012 Mon Sep 17 00:00:00 2001 From: Haitao Pan Date: Thu, 5 Feb 2026 09:37:04 +0800 Subject: [PATCH] fix: move mfa/status endpoint outside auth middleware and implement persistent session storage - Moved /api/auth/mfa/status outside authProtected group to allow pre-login MFA checks - Added session management to Store interface with CreateSession, GetSession, DeleteSession - Implemented session persistence in both memoryStore and postgresStore - Updated handler to use store-based sessions instead of in-memory map - Added database schema for users, sessions, agents, and email_blacklist tables - This fixes the 401 error when checking MFA status before login --- api/api.go | 27 ++++++++--------------- cmd/accountsvc/main.go | 41 +++++++++++++++++++++++++++++++++++ internal/store/postgres.go | 29 +++++++++++++++++++++++++ internal/store/store.go | 44 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 18 deletions(-) diff --git a/api/api.go b/api/api.go index 6611f7c..348aca5 100644 --- a/api/api.go +++ b/api/api.go @@ -44,7 +44,6 @@ type session struct { type handler struct { store store.Store - sessions map[string]session mu sync.RWMutex sessionTTL time.Duration mfaChallenges map[string]mfaChallenge @@ -207,7 +206,6 @@ func WithOAuthFrontendURL(url string) Option { func RegisterRoutes(r *gin.Engine, opts ...Option) { h := &handler{ store: store.NewMemoryStore(), - sessions: make(map[string]session), sessionTTL: defaultSessionTTL, mfaChallenges: make(map[string]mfaChallenge), mfaChallengeTTL: defaultMFAChallengeTTL, @@ -247,6 +245,8 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { // Token refresh endpoint - generates new access token using refresh token authGroup.POST("/token/refresh", h.refreshToken) + authGroup.GET("/mfa/status", h.mfaStatus) + // Protected routes requiring authentication authProtected := authGroup.Group("") if h.tokenService != nil { @@ -260,7 +260,6 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { authProtected.POST("/mfa/totp/provision", h.provisionTOTP) authProtected.POST("/mfa/totp/verify", h.verifyTOTP) authProtected.POST("/mfa/disable", h.disableMFA) - authProtected.GET("/mfa/status", h.mfaStatus) authProtected.POST("/password/reset", h.requestPasswordReset) authProtected.POST("/password/reset/confirm", h.confirmPasswordReset) @@ -1246,9 +1245,9 @@ func (h *handler) createSession(userID string) (string, time.Time, error) { } expiresAt := time.Now().Add(ttl) - h.mu.Lock() - defer h.mu.Unlock() - h.sessions[token] = session{userID: userID, expiresAt: expiresAt} + if err := h.store.CreateSession(context.Background(), token, userID, expiresAt); err != nil { + return "", time.Time{}, err + } return token, expiresAt, nil } @@ -1263,23 +1262,15 @@ func (h *handler) setSessionCookie(c *gin.Context, token string, expiresAt time. } func (h *handler) lookupSession(token string) (session, bool) { - h.mu.RLock() - sess, ok := h.sessions[token] - h.mu.RUnlock() - if !ok { + userID, expiresAt, err := h.store.GetSession(context.Background(), token) + if err != nil { return session{}, false } - if time.Now().After(sess.expiresAt) { - h.removeSession(token) - return session{}, false - } - return sess, true + return session{userID: userID, expiresAt: expiresAt}, true } func (h *handler) removeSession(token string) { - h.mu.Lock() - delete(h.sessions, token) - h.mu.Unlock() + h.store.DeleteSession(context.Background(), token) } func (h *handler) newRandomToken() (string, error) { diff --git a/cmd/accountsvc/main.go b/cmd/accountsvc/main.go index 6e16dad..655fcdb 100644 --- a/cmd/accountsvc/main.go +++ b/cmd/accountsvc/main.go @@ -428,6 +428,47 @@ func applyRBACSchema(ctx context.Context, db *gorm.DB, driver string) error { } statements := []string{ + `CREATE TABLE IF NOT EXISTS public.users ( + uuid UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username TEXT NOT NULL UNIQUE, + email TEXT NOT NULL UNIQUE, + email_verified BOOLEAN NOT NULL DEFAULT FALSE, + password TEXT NOT NULL, + mfa_totp_secret TEXT, + mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE, + mfa_secret_issued_at TIMESTAMPTZ, + mfa_confirmed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + level INTEGER NOT NULL DEFAULT 20, + role TEXT NOT NULL DEFAULT 'user', + groups JSONB NOT NULL DEFAULT '[]'::jsonb, + permissions JSONB NOT NULL DEFAULT '[]'::jsonb, + active BOOLEAN NOT NULL DEFAULT TRUE, + proxy_uuid UUID NOT NULL DEFAULT gen_random_uuid(), + proxy_uuid_expires_at TIMESTAMPTZ +)`, + `CREATE TABLE IF NOT EXISTS public.email_blacklist ( + email TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +)`, + `CREATE TABLE IF NOT EXISTS public.agents ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL DEFAULT '', + groups JSONB NOT NULL DEFAULT '[]'::jsonb, + healthy BOOLEAN NOT NULL DEFAULT FALSE, + last_heartbeat TIMESTAMPTZ, + clients_count INTEGER NOT NULL DEFAULT 0, + sync_revision TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +)`, + `CREATE TABLE IF NOT EXISTS public.sessions ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +)`, `CREATE TABLE IF NOT EXISTS public.rbac_roles ( role_key TEXT PRIMARY KEY, description TEXT NOT NULL DEFAULT '', diff --git a/internal/store/postgres.go b/internal/store/postgres.go index 61ed2e8..e556852 100644 --- a/internal/store/postgres.go +++ b/internal/store/postgres.go @@ -1322,3 +1322,32 @@ func (s *postgresStore) DeleteStaleAgents(ctx context.Context, staleThreshold ti count, _ := result.RowsAffected() return int(count), nil } + +func (s *postgresStore) CreateSession(ctx context.Context, token, userID string, expiresAt time.Time) error { + const query = "INSERT INTO sessions (token, user_id, expires_at) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET user_id = EXCLUDED.user_id, expires_at = EXCLUDED.expires_at" + _, err := s.db.ExecContext(ctx, query, token, userID, expiresAt.UTC()) + return err +} + +func (s *postgresStore) GetSession(ctx context.Context, token string) (string, time.Time, error) { + const query = "SELECT user_id, expires_at FROM sessions WHERE token = $1" + var userID string + var expiresAt time.Time + err := s.db.QueryRowContext(ctx, query, token).Scan(&userID, &expiresAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", time.Time{}, ErrSessionNotFound + } + return "", time.Time{}, err + } + if time.Now().After(expiresAt) { + return "", time.Time{}, ErrSessionNotFound + } + return userID, expiresAt.UTC(), nil +} + +func (s *postgresStore) DeleteSession(ctx context.Context, token string) error { + const query = "DELETE FROM sessions WHERE token = $1" + _, err := s.db.ExecContext(ctx, query, token) + return err +} diff --git a/internal/store/store.go b/internal/store/store.go index 5d2d5e1..28556d3 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -94,6 +94,11 @@ type Store interface { IsBlacklisted(ctx context.Context, email string) (bool, error) ListBlacklist(ctx context.Context) ([]string, error) + // Session management + CreateSession(ctx context.Context, token, userID string, expiresAt time.Time) error + GetSession(ctx context.Context, token string) (string, time.Time, error) + DeleteSession(ctx context.Context, token string) error + // Agent management UpsertAgent(ctx context.Context, agent *Agent) error GetAgent(ctx context.Context, id string) (*Agent, error) @@ -125,8 +130,16 @@ type memoryStore struct { subscriptions map[string]map[string]*Subscription identities map[string]*Identity agents map[string]*Agent + sessions map[string]*sessionRecord } +type sessionRecord struct { + UserID string + ExpiresAt time.Time +} + +var ErrSessionNotFound = errors.New("session not found") + // NewMemoryStore creates a new in-memory store implementation with super // administrator counting disabled by default to avoid accidental exposure of // privileged metadata in environments where the caller has not explicitly @@ -151,6 +164,7 @@ func newMemoryStore(allowSuperAdminCounting bool) Store { subscriptions: make(map[string]map[string]*Subscription), identities: make(map[string]*Identity), agents: make(map[string]*Agent), + sessions: make(map[string]*sessionRecord), } } @@ -805,3 +819,33 @@ func (s *memoryStore) DeleteStaleAgents(ctx context.Context, staleThreshold time } return count, nil } + +func (s *memoryStore) CreateSession(ctx context.Context, token, userID string, expiresAt time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions[token] = &sessionRecord{ + UserID: userID, + ExpiresAt: expiresAt, + } + return nil +} + +func (s *memoryStore) GetSession(ctx context.Context, token string) (string, time.Time, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[token] + if !ok { + return "", time.Time{}, ErrSessionNotFound + } + if time.Now().After(sess.ExpiresAt) { + return "", time.Time{}, ErrSessionNotFound + } + return sess.UserID, sess.ExpiresAt, nil +} + +func (s *memoryStore) DeleteSession(ctx context.Context, token string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, token) + return nil +}