fix: move mfa/status endpoint outside auth middleware and implement persistent session storage

- Moved /api/auth/mfa/status outside authProtected group to allow pre-login MFA checks
- Added session management to Store interface with CreateSession, GetSession, DeleteSession
- Implemented session persistence in both memoryStore and postgresStore
- Updated handler to use store-based sessions instead of in-memory map
- Added database schema for users, sessions, agents, and email_blacklist tables
- This fixes the 401 error when checking MFA status before login
This commit is contained in:
Haitao Pan 2026-02-05 09:37:04 +08:00
parent 29bb103aa3
commit bc2562b877
4 changed files with 123 additions and 18 deletions

View File

@ -44,7 +44,6 @@ type session struct {
type handler struct {
store store.Store
sessions map[string]session
mu sync.RWMutex
sessionTTL time.Duration
mfaChallenges map[string]mfaChallenge
@ -207,7 +206,6 @@ func WithOAuthFrontendURL(url string) Option {
func RegisterRoutes(r *gin.Engine, opts ...Option) {
h := &handler{
store: store.NewMemoryStore(),
sessions: make(map[string]session),
sessionTTL: defaultSessionTTL,
mfaChallenges: make(map[string]mfaChallenge),
mfaChallengeTTL: defaultMFAChallengeTTL,
@ -247,6 +245,8 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
// Token refresh endpoint - generates new access token using refresh token
authGroup.POST("/token/refresh", h.refreshToken)
authGroup.GET("/mfa/status", h.mfaStatus)
// Protected routes requiring authentication
authProtected := authGroup.Group("")
if h.tokenService != nil {
@ -260,7 +260,6 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) {
authProtected.POST("/mfa/totp/provision", h.provisionTOTP)
authProtected.POST("/mfa/totp/verify", h.verifyTOTP)
authProtected.POST("/mfa/disable", h.disableMFA)
authProtected.GET("/mfa/status", h.mfaStatus)
authProtected.POST("/password/reset", h.requestPasswordReset)
authProtected.POST("/password/reset/confirm", h.confirmPasswordReset)
@ -1246,9 +1245,9 @@ func (h *handler) createSession(userID string) (string, time.Time, error) {
}
expiresAt := time.Now().Add(ttl)
h.mu.Lock()
defer h.mu.Unlock()
h.sessions[token] = session{userID: userID, expiresAt: expiresAt}
if err := h.store.CreateSession(context.Background(), token, userID, expiresAt); err != nil {
return "", time.Time{}, err
}
return token, expiresAt, nil
}
@ -1263,23 +1262,15 @@ func (h *handler) setSessionCookie(c *gin.Context, token string, expiresAt time.
}
func (h *handler) lookupSession(token string) (session, bool) {
h.mu.RLock()
sess, ok := h.sessions[token]
h.mu.RUnlock()
if !ok {
userID, expiresAt, err := h.store.GetSession(context.Background(), token)
if err != nil {
return session{}, false
}
if time.Now().After(sess.expiresAt) {
h.removeSession(token)
return session{}, false
}
return sess, true
return session{userID: userID, expiresAt: expiresAt}, true
}
func (h *handler) removeSession(token string) {
h.mu.Lock()
delete(h.sessions, token)
h.mu.Unlock()
h.store.DeleteSession(context.Background(), token)
}
func (h *handler) newRandomToken() (string, error) {

View File

@ -428,6 +428,47 @@ func applyRBACSchema(ctx context.Context, db *gorm.DB, driver string) error {
}
statements := []string{
`CREATE TABLE IF NOT EXISTS public.users (
uuid UUID PRIMARY KEY DEFAULT gen_random_uuid(),
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
email_verified BOOLEAN NOT NULL DEFAULT FALSE,
password TEXT NOT NULL,
mfa_totp_secret TEXT,
mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE,
mfa_secret_issued_at TIMESTAMPTZ,
mfa_confirmed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
level INTEGER NOT NULL DEFAULT 20,
role TEXT NOT NULL DEFAULT 'user',
groups JSONB NOT NULL DEFAULT '[]'::jsonb,
permissions JSONB NOT NULL DEFAULT '[]'::jsonb,
active BOOLEAN NOT NULL DEFAULT TRUE,
proxy_uuid UUID NOT NULL DEFAULT gen_random_uuid(),
proxy_uuid_expires_at TIMESTAMPTZ
)`,
`CREATE TABLE IF NOT EXISTS public.email_blacklist (
email TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`,
`CREATE TABLE IF NOT EXISTS public.agents (
id TEXT PRIMARY KEY,
name TEXT NOT NULL DEFAULT '',
groups JSONB NOT NULL DEFAULT '[]'::jsonb,
healthy BOOLEAN NOT NULL DEFAULT FALSE,
last_heartbeat TIMESTAMPTZ,
clients_count INTEGER NOT NULL DEFAULT 0,
sync_revision TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`,
`CREATE TABLE IF NOT EXISTS public.sessions (
token TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`,
`CREATE TABLE IF NOT EXISTS public.rbac_roles (
role_key TEXT PRIMARY KEY,
description TEXT NOT NULL DEFAULT '',

View File

@ -1322,3 +1322,32 @@ func (s *postgresStore) DeleteStaleAgents(ctx context.Context, staleThreshold ti
count, _ := result.RowsAffected()
return int(count), nil
}
func (s *postgresStore) CreateSession(ctx context.Context, token, userID string, expiresAt time.Time) error {
const query = "INSERT INTO sessions (token, user_id, expires_at) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET user_id = EXCLUDED.user_id, expires_at = EXCLUDED.expires_at"
_, err := s.db.ExecContext(ctx, query, token, userID, expiresAt.UTC())
return err
}
func (s *postgresStore) GetSession(ctx context.Context, token string) (string, time.Time, error) {
const query = "SELECT user_id, expires_at FROM sessions WHERE token = $1"
var userID string
var expiresAt time.Time
err := s.db.QueryRowContext(ctx, query, token).Scan(&userID, &expiresAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", time.Time{}, ErrSessionNotFound
}
return "", time.Time{}, err
}
if time.Now().After(expiresAt) {
return "", time.Time{}, ErrSessionNotFound
}
return userID, expiresAt.UTC(), nil
}
func (s *postgresStore) DeleteSession(ctx context.Context, token string) error {
const query = "DELETE FROM sessions WHERE token = $1"
_, err := s.db.ExecContext(ctx, query, token)
return err
}

View File

@ -94,6 +94,11 @@ type Store interface {
IsBlacklisted(ctx context.Context, email string) (bool, error)
ListBlacklist(ctx context.Context) ([]string, error)
// Session management
CreateSession(ctx context.Context, token, userID string, expiresAt time.Time) error
GetSession(ctx context.Context, token string) (string, time.Time, error)
DeleteSession(ctx context.Context, token string) error
// Agent management
UpsertAgent(ctx context.Context, agent *Agent) error
GetAgent(ctx context.Context, id string) (*Agent, error)
@ -125,8 +130,16 @@ type memoryStore struct {
subscriptions map[string]map[string]*Subscription
identities map[string]*Identity
agents map[string]*Agent
sessions map[string]*sessionRecord
}
type sessionRecord struct {
UserID string
ExpiresAt time.Time
}
var ErrSessionNotFound = errors.New("session not found")
// NewMemoryStore creates a new in-memory store implementation with super
// administrator counting disabled by default to avoid accidental exposure of
// privileged metadata in environments where the caller has not explicitly
@ -151,6 +164,7 @@ func newMemoryStore(allowSuperAdminCounting bool) Store {
subscriptions: make(map[string]map[string]*Subscription),
identities: make(map[string]*Identity),
agents: make(map[string]*Agent),
sessions: make(map[string]*sessionRecord),
}
}
@ -805,3 +819,33 @@ func (s *memoryStore) DeleteStaleAgents(ctx context.Context, staleThreshold time
}
return count, nil
}
func (s *memoryStore) CreateSession(ctx context.Context, token, userID string, expiresAt time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[token] = &sessionRecord{
UserID: userID,
ExpiresAt: expiresAt,
}
return nil
}
func (s *memoryStore) GetSession(ctx context.Context, token string) (string, time.Time, error) {
s.mu.RLock()
defer s.mu.RUnlock()
sess, ok := s.sessions[token]
if !ok {
return "", time.Time{}, ErrSessionNotFound
}
if time.Now().After(sess.ExpiresAt) {
return "", time.Time{}, ErrSessionNotFound
}
return sess.UserID, sess.ExpiresAt, nil
}
func (s *memoryStore) DeleteSession(ctx context.Context, token string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, token)
return nil
}