// accounts/internal/store/postgres.go
//
// 1238 lines
// 31 KiB
// Go

package store
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
_ "github.com/jackc/pgx/v5/stdlib"
)
// Config describes how to construct a Store implementation.
type Config struct {
	// Driver selects the backend. New lower-cases and trims it: "" or
	// "memory" selects the in-memory store; "postgres", "postgresql" or
	// "pgx" selects PostgreSQL. Anything else is rejected by New.
	Driver string
	// DSN is the connection string passed to sql.Open("pgx", ...). It is
	// required (non-blank) for the PostgreSQL drivers.
	DSN string
	// MaxOpenConns is applied with (*sql.DB).SetMaxOpenConns when > 0.
	MaxOpenConns int
	// MaxIdleConns is applied with (*sql.DB).SetMaxIdleConns when > 0.
	MaxIdleConns int
	// ConnMaxLifetime is applied with (*sql.DB).SetConnMaxLifetime when > 0.
	ConnMaxLifetime time.Duration
	// ConnMaxIdleTime is applied with (*sql.DB).SetConnMaxIdleTime when > 0.
	ConnMaxIdleTime time.Duration
	// AllowSuperAdminCounting is forwarded to the constructed store; the
	// PostgreSQL store's CountSuperAdmins returns
	// ErrSuperAdminCountingDisabled while it is false.
	AllowSuperAdminCounting bool
}
// New creates a Store implementation based on the provided configuration.
//
// An empty Driver or "memory" (case-insensitive, whitespace ignored) yields
// the in-memory store and a no-op cleanup function. "postgres", "postgresql"
// and "pgx" open a database/sql pool with the pgx driver, apply any positive
// pool limits from cfg, and ping the server before returning; the returned
// cleanup closes the pool. Any other driver is rejected.
//
// Open and ping failures are wrapped with context (use errors.Is/As to
// inspect the cause). If closing the pool after a failed ping also fails,
// that error is reported alongside the ping error instead of being dropped.
func New(ctx context.Context, cfg Config) (Store, func(context.Context) error, error) {
	driver := strings.ToLower(strings.TrimSpace(cfg.Driver))
	switch driver {
	case "", "memory":
		ms := newMemoryStore(cfg.AllowSuperAdminCounting)
		return ms, func(context.Context) error { return nil }, nil
	case "postgres", "postgresql", "pgx":
		if strings.TrimSpace(cfg.DSN) == "" {
			return nil, nil, errors.New("store dsn is required for postgres driver")
		}
		db, err := sql.Open("pgx", cfg.DSN)
		if err != nil {
			return nil, nil, fmt.Errorf("open postgres store: %w", err)
		}
		// Non-positive values leave database/sql's defaults untouched.
		if cfg.MaxOpenConns > 0 {
			db.SetMaxOpenConns(cfg.MaxOpenConns)
		}
		if cfg.MaxIdleConns > 0 {
			db.SetMaxIdleConns(cfg.MaxIdleConns)
		}
		if cfg.ConnMaxLifetime > 0 {
			db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
		}
		if cfg.ConnMaxIdleTime > 0 {
			db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
		}
		// sql.Open may not establish a connection; ping so an unreachable
		// server or bad DSN fails here rather than on first use.
		if err := db.PingContext(ctx); err != nil {
			if closeErr := db.Close(); closeErr != nil {
				return nil, nil, fmt.Errorf("ping postgres store: %w (closing pool: %v)", err, closeErr)
			}
			return nil, nil, fmt.Errorf("ping postgres store: %w", err)
		}
		cleanup := func(context.Context) error {
			return db.Close()
		}
		return &postgresStore{db: db, allowSuperAdminCounting: cfg.AllowSuperAdminCounting}, cleanup, nil
	default:
		return nil, nil, fmt.Errorf("unsupported store driver %q", cfg.Driver)
	}
}
// schemaCapabilities records which optional columns exist on the users
// table, as probed by postgresStore.capabilities. Query builders consult
// these flags so they only reference columns the connected schema has.
type schemaCapabilities struct {
	// MFA columns: mfa_totp_secret, mfa_enabled, mfa_secret_issued_at,
	// mfa_confirmed_at.
	hasMFATOTPSecret     bool
	hasMFAEnabled        bool
	hasMFASecretIssuedAt bool
	hasMFAConfirmedAt    bool
	// Timestamp columns: created_at, updated_at.
	hasCreatedAt bool
	hasUpdatedAt bool
	// Authorization columns: level, role, groups, permissions, active.
	hasLevel       bool
	hasRole        bool
	hasGroups      bool
	hasPermissions bool
	hasActive      bool
	// Proxy columns: proxy_uuid, proxy_uuid_expires_at.
	hasProxyUUID          bool
	hasProxyUUIDExpiresAt bool
}

// supportsMFA reports whether all four MFA columns are present; a schema
// carrying only some of them counts as not supporting MFA.
func (c schemaCapabilities) supportsMFA() bool {
	required := [...]bool{
		c.hasMFATOTPSecret,
		c.hasMFAEnabled,
		c.hasMFASecretIssuedAt,
		c.hasMFAConfirmedAt,
	}
	for _, present := range required {
		if !present {
			return false
		}
	}
	return true
}
// postgresStore is the PostgreSQL-backed Store returned by New for the
// "postgres", "postgresql" and "pgx" drivers.
type postgresStore struct {
	// db is the database/sql pool opened with the pgx driver.
	db *sql.DB
	// allowSuperAdminCounting gates CountSuperAdmins (copied from Config).
	allowSuperAdminCounting bool
	// capsMu guards caps and capsLoaded, the lazily loaded result of the
	// users-table column probe performed by capabilities.
	capsMu sync.RWMutex
	caps schemaCapabilities
	capsLoaded bool
}
// CreateUser inserts user into the users table.
//
// The name is trimmed and must be non-empty (ErrInvalidName); the email is
// trimmed and lower-cased. Case-insensitive duplicates are rejected up front
// with ErrEmailExists / ErrNameExists, and unique-constraint violations raised
// by the INSERT itself (e.g. a concurrent insert) are mapped to the same
// errors. Optional columns are written only when the schema has them. On
// success user.ID, Name, Email, CreatedAt, UpdatedAt and EmailVerified are
// refreshed from the inserted row.
func (s *postgresStore) CreateUser(ctx context.Context, user *User) error {
	if user == nil {
		return errors.New("user is required")
	}
	normalizedEmail := strings.ToLower(strings.TrimSpace(user.Email))
	normalizedName := strings.TrimSpace(user.Name)
	if normalizedName == "" {
		return ErrInvalidName
	}
	caps, err := s.capabilities(ctx)
	if err != nil {
		return err
	}
	normalizeUserRoleFields(user)

	// email_verified_at stays NULL unless the caller marks the email verified.
	var verifiedAt any
	if user.EmailVerified {
		verifiedAt = time.Now().UTC()
	}

	// Pre-check duplicates so callers get typed errors; the unique-violation
	// mapping after the INSERT still covers races between check and insert.
	if normalizedEmail != "" {
		const emailExistsQuery = "SELECT EXISTS(SELECT 1 FROM users WHERE lower(email) = $1)"
		emailExists, err := s.userExists(ctx, emailExistsQuery, normalizedEmail)
		if err != nil {
			return err
		}
		if emailExists {
			return ErrEmailExists
		}
	}
	const nameExistsQuery = "SELECT EXISTS(SELECT 1 FROM users WHERE lower(username) = lower($1))"
	nameExists, err := s.userExists(ctx, nameExistsQuery, normalizedName)
	if err != nil {
		return err
	}
	if nameExists {
		return ErrNameExists
	}

	columns := []string{"username", "password", "email", "email_verified_at"}
	args := []any{normalizedName, user.PasswordHash, normalizedEmail, verifiedAt}
	// addColumn appends an optional column together with its bind value,
	// keeping columns and args index-aligned.
	addColumn := func(column string, value any) {
		columns = append(columns, column)
		args = append(args, value)
	}
	if caps.hasLevel {
		addColumn("level", user.Level)
	}
	if caps.hasRole {
		addColumn("role", user.Role)
	}
	if caps.hasGroups {
		encoded, err := encodeStringSlice(user.Groups)
		if err != nil {
			return err
		}
		addColumn("groups", encoded)
	}
	if caps.hasPermissions {
		encoded, err := encodeStringSlice(user.Permissions)
		if err != nil {
			return err
		}
		addColumn("permissions", encoded)
	}
	if caps.hasActive {
		addColumn("active", user.Active)
	}
	if caps.hasProxyUUID {
		addColumn("proxy_uuid", nullForEmpty(user.ProxyUUID))
	}
	if caps.hasProxyUUIDExpiresAt {
		addColumn("proxy_uuid_expires_at", user.ProxyUUIDExpiresAt)
	}

	placeholders := make([]string, len(args))
	for i := range placeholders {
		placeholders[i] = "$" + strconv.Itoa(i+1)
	}

	// Like selectUserQuery and UpdateUser, only reference timestamp columns
	// the schema actually has; fall back to now() otherwise.
	createdExpr, updatedExpr := "now()", "now()"
	if caps.hasCreatedAt {
		createdExpr = "coalesce(created_at, now())"
	}
	if caps.hasUpdatedAt {
		updatedExpr = "coalesce(updated_at, now())"
	}
	query := fmt.Sprintf(`INSERT INTO users (%s)
VALUES (%s)
RETURNING uuid, %s, %s, email_verified`,
		strings.Join(columns, ", "), strings.Join(placeholders, ", "), createdExpr, updatedExpr)

	var idValue any
	var createdAt time.Time
	var updatedAt time.Time
	var emailVerified sql.NullBool
	err = s.db.QueryRowContext(ctx, query, args...).Scan(&idValue, &createdAt, &updatedAt, &emailVerified)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return ErrUserNotFound
		}
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) {
			if pgErr.Code == "23505" { // unique_violation
				switch {
				case strings.Contains(pgErr.ConstraintName, "email"):
					return ErrEmailExists
				case strings.Contains(pgErr.ConstraintName, "name") || strings.Contains(pgErr.ConstraintName, "username"):
					return ErrNameExists
				}
			}
		}
		return err
	}
	identifier, err := formatIdentifier(idValue)
	if err != nil {
		return err
	}
	user.ID = identifier
	user.Name = normalizedName
	user.Email = normalizedEmail
	user.CreatedAt = createdAt.UTC()
	user.UpdatedAt = updatedAt.UTC()
	user.EmailVerified = emailVerified.Bool
	return nil
}
// userExists runs a single-row EXISTS-style query with one bind argument and
// returns its boolean result. A missing row or a missing table (as classified
// by isDatabaseEmptyError) is reported as "does not exist" with no error.
func (s *postgresStore) userExists(ctx context.Context, query string, arg any) (bool, error) {
	var found bool
	scanErr := s.db.QueryRowContext(ctx, query, arg).Scan(&found)
	switch {
	case scanErr == nil:
		return found, nil
	case isDatabaseEmptyError(scanErr):
		return false, nil
	default:
		return false, scanErr
	}
}
// isDatabaseEmptyError reports whether err means there is nothing to find:
// either sql.ErrNoRows (possibly wrapped) or a PostgreSQL undefined_table
// error (SQLSTATE 42P01).
func isDatabaseEmptyError(err error) bool {
	if errors.Is(err, sql.ErrNoRows) {
		return true
	}
	var pgErr *pgconn.PgError
	return errors.As(err, &pgErr) && pgErr.Code == "42P01" // undefined_table
}
// GetUserByEmail looks up a single user by email, case-insensitively and
// ignoring surrounding whitespace. A blank email yields ErrUserNotFound
// without querying the database.
func (s *postgresStore) GetUserByEmail(ctx context.Context, email string) (*User, error) {
	key := strings.ToLower(strings.TrimSpace(email))
	if len(key) == 0 {
		return nil, ErrUserNotFound
	}
	caps, err := s.capabilities(ctx)
	if err != nil {
		return nil, err
	}
	stmt := s.selectUserQuery(caps, "WHERE lower(email) = $1 LIMIT 1")
	return scanUser(s.db.QueryRowContext(ctx, stmt, key))
}
// GetUserByName looks up a single user by username, case-insensitively and
// ignoring surrounding whitespace. A blank name yields ErrUserNotFound
// without querying the database.
func (s *postgresStore) GetUserByName(ctx context.Context, name string) (*User, error) {
	trimmed := strings.TrimSpace(name)
	if len(trimmed) == 0 {
		return nil, ErrUserNotFound
	}
	caps, err := s.capabilities(ctx)
	if err != nil {
		return nil, err
	}
	stmt := s.selectUserQuery(caps, "WHERE lower(username) = lower($1) LIMIT 1")
	return scanUser(s.db.QueryRowContext(ctx, stmt, trimmed))
}
// GetUserByID looks up a user by the uuid column. Surrounding whitespace is
// ignored, and a blank id yields ErrUserNotFound without querying, matching
// GetUserByEmail and GetUserByName.
func (s *postgresStore) GetUserByID(ctx context.Context, id string) (*User, error) {
	normalized := strings.TrimSpace(id)
	if normalized == "" {
		// A blank key cannot match any row; querying it would likely surface a
		// database type-cast error for the uuid column instead of not-found.
		return nil, ErrUserNotFound
	}
	caps, err := s.capabilities(ctx)
	if err != nil {
		return nil, err
	}
	query := s.selectUserQuery(caps, "WHERE uuid = $1")
	row := s.db.QueryRowContext(ctx, query, normalized)
	return scanUser(row)
}
// rowScanner is the Scan method shared by *sql.Row and *sql.Rows, letting
// scanUser decode results from both QueryRowContext and QueryContext.
type rowScanner interface {
	Scan(dst ...any) error
}
// scanUser decodes one row produced by selectUserQuery into a User. The
// destination order below must match selectUserQuery's SELECT list exactly.
//
// sql.ErrNoRows becomes ErrUserNotFound; other scan errors are returned as
// is. Name, role, proxy UUID and TOTP secret are trimmed, the email is
// trimmed and lower-cased, NULL MFA timestamps become the zero time, a NULL
// proxy expiry leaves ProxyUUIDExpiresAt nil, and role-related fields are
// finally passed through normalizeUserRoleFields.
func scanUser(row rowScanner) (*User, error) {
	var (
		rawID                          any
		name, mail, pass, totpSecret   sql.NullString
		role, proxyID                  sql.NullString
		verified, mfaOn, active        sql.NullBool
		secretIssuedAt, mfaConfirmedAt sql.NullTime
		proxyExpiry                    sql.NullTime
		created, updated               time.Time
		level                          sql.NullInt64
		groupsJSON, permissionsJSON    []byte
	)
	dest := []any{
		&rawID, &name, &mail, &verified, &pass,
		&totpSecret, &mfaOn, &secretIssuedAt, &mfaConfirmedAt,
		&created, &updated, &level, &role,
		&groupsJSON, &permissionsJSON, &active, &proxyID, &proxyExpiry,
	}
	if scanErr := row.Scan(dest...); scanErr != nil {
		if errors.Is(scanErr, sql.ErrNoRows) {
			return nil, ErrUserNotFound
		}
		return nil, scanErr
	}
	id, err := formatIdentifier(rawID)
	if err != nil {
		return nil, err
	}
	u := &User{
		ID:                id,
		Name:              strings.TrimSpace(name.String),
		Email:             strings.ToLower(strings.TrimSpace(mail.String)),
		EmailVerified:     verified.Bool,
		PasswordHash:      pass.String,
		MFATOTPSecret:     strings.TrimSpace(totpSecret.String),
		MFAEnabled:        mfaOn.Bool,
		MFASecretIssuedAt: toUTCTime(secretIssuedAt),
		MFAConfirmedAt:    toUTCTime(mfaConfirmedAt),
		CreatedAt:         created.UTC(),
		UpdatedAt:         updated.UTC(),
		Role:              strings.TrimSpace(role.String),
		Groups:            decodeStringSlice(groupsJSON),
		Permissions:       decodeStringSlice(permissionsJSON),
		Active:            active.Valid && active.Bool,
		ProxyUUID:         strings.TrimSpace(proxyID.String),
	}
	if level.Valid {
		u.Level = int(level.Int64)
	}
	if proxyExpiry.Valid {
		expiry := proxyExpiry.Time.UTC()
		u.ProxyUUIDExpiresAt = &expiry
	}
	normalizeUserRoleFields(u)
	return u, nil
}
// UpdateUser overwrites the users row identified by user.ID with the fields
// of user.
//
// The name is trimmed and must be non-empty (ErrInvalidName). A blank ID
// yields ErrUserNotFound without querying. The email is trimmed and
// lower-cased. email_verified_at keeps its existing value (or becomes now())
// when EmailVerified is true and is cleared otherwise. Optional columns are
// written only when the schema has them. Non-empty MFA state aimed at a
// schema lacking the matching column fails with ErrMFANotSupported.
// Unique-constraint violations map to ErrEmailExists / ErrNameExists, and a
// missing row maps to ErrUserNotFound. On success user.Name, Email, CreatedAt
// and UpdatedAt are refreshed.
func (s *postgresStore) UpdateUser(ctx context.Context, user *User) error {
	if user == nil {
		return errors.New("user is required")
	}
	normalizedName := strings.TrimSpace(user.Name)
	if normalizedName == "" {
		return ErrInvalidName
	}
	userID := strings.TrimSpace(user.ID)
	if userID == "" {
		// No row can match a blank key; report not-found rather than letting
		// the database reject the empty identifier literal.
		return ErrUserNotFound
	}
	normalizedEmail := strings.ToLower(strings.TrimSpace(user.Email))
	caps, err := s.capabilities(ctx)
	if err != nil {
		return err
	}

	// Zero MFA timestamps are stored as NULL.
	var issuedAt, confirmedAt any
	if !user.MFASecretIssuedAt.IsZero() {
		issuedAt = user.MFASecretIssuedAt.UTC()
	}
	if !user.MFAConfirmedAt.IsZero() {
		confirmedAt = user.MFAConfirmedAt.UTC()
	}
	normalizeUserRoleFields(user)

	var query strings.Builder
	args := []any{normalizedName, normalizedEmail, user.PasswordHash}
	query.WriteString("UPDATE users SET username = $1, email = $2, password = $3")
	// set appends ", column = $n", where n is value's 1-based position in args.
	set := func(column string, value any) {
		args = append(args, value)
		fmt.Fprintf(&query, ", %s = $%d", column, len(args))
	}

	if user.EmailVerified {
		// Preserve the original verification time when there is one.
		query.WriteString(", email_verified_at = COALESCE(email_verified_at, now())")
	} else {
		query.WriteString(", email_verified_at = NULL")
	}

	// MFA columns: write them when present, and refuse to silently drop MFA
	// state the schema cannot store.
	if caps.hasMFATOTPSecret {
		set("mfa_totp_secret", nullForEmpty(user.MFATOTPSecret))
	} else if strings.TrimSpace(user.MFATOTPSecret) != "" {
		return ErrMFANotSupported
	}
	if caps.hasMFAEnabled {
		set("mfa_enabled", user.MFAEnabled)
	} else if user.MFAEnabled {
		return ErrMFANotSupported
	}
	if caps.hasMFASecretIssuedAt {
		set("mfa_secret_issued_at", issuedAt)
	} else if issuedAt != nil {
		return ErrMFANotSupported
	}
	if caps.hasMFAConfirmedAt {
		set("mfa_confirmed_at", confirmedAt)
	} else if confirmedAt != nil {
		return ErrMFANotSupported
	}

	if caps.hasUpdatedAt {
		query.WriteString(", updated_at = now()")
	}
	if caps.hasLevel {
		set("level", user.Level)
	}
	if caps.hasRole {
		set("role", user.Role)
	}
	if caps.hasGroups {
		encoded, err := encodeStringSlice(user.Groups)
		if err != nil {
			return err
		}
		set("groups", encoded)
	}
	if caps.hasPermissions {
		encoded, err := encodeStringSlice(user.Permissions)
		if err != nil {
			return err
		}
		set("permissions", encoded)
	}
	if caps.hasActive {
		set("active", user.Active)
	}
	if caps.hasProxyUUID {
		set("proxy_uuid", nullForEmpty(user.ProxyUUID))
	}
	if caps.hasProxyUUIDExpiresAt {
		set("proxy_uuid_expires_at", user.ProxyUUIDExpiresAt)
	}

	args = append(args, userID)
	fmt.Fprintf(&query, " WHERE uuid = $%d RETURNING ", len(args))
	if caps.hasCreatedAt {
		query.WriteString("coalesce(created_at, now())")
	} else {
		query.WriteString("now()")
	}
	if caps.hasUpdatedAt {
		query.WriteString(", coalesce(updated_at, now())")
	} else {
		query.WriteString(", now()")
	}

	var createdAt time.Time
	var updatedAt time.Time
	err = s.db.QueryRowContext(ctx, query.String(), args...).Scan(&createdAt, &updatedAt)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return ErrUserNotFound
		}
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) {
			if pgErr.Code == "23505" { // unique_violation
				switch {
				case strings.Contains(pgErr.ConstraintName, "email"):
					return ErrEmailExists
				case strings.Contains(pgErr.ConstraintName, "name") || strings.Contains(pgErr.ConstraintName, "username"):
					return ErrNameExists
				}
			}
		}
		return err
	}
	user.Name = normalizedName
	user.Email = normalizedEmail
	user.CreatedAt = createdAt.UTC()
	user.UpdatedAt = updatedAt.UTC()
	return nil
}
// CountSuperAdmins counts users matching the super-admin criteria, which
// require all of the following:
//   - role is root/admin (case-insensitive) or level equals LevelAdmin,
//     depending on which of those columns exist;
//   - groups contain "Admin", when the groups column exists;
//   - permissions contain "*", when the permissions column exists.
//
// It returns ErrSuperAdminCountingDisabled unless counting was enabled for
// this store. It fails when the schema has neither a role nor a level column.
func (s *postgresStore) CountSuperAdmins(ctx context.Context) (int, error) {
	if !s.allowSuperAdminCounting {
		return 0, ErrSuperAdminCountingDisabled
	}
	caps, err := s.capabilities(ctx)
	if err != nil {
		return 0, err
	}
	var privileged []string
	if caps.hasRole {
		privileged = append(privileged, "lower(role) IN ('root','admin')")
	}
	if caps.hasLevel {
		privileged = append(privileged, fmt.Sprintf("level = %d", LevelAdmin))
	}
	if len(privileged) == 0 {
		return 0, errors.New("postgres store schema does not expose role or level columns")
	}
	where := "(" + strings.Join(privileged, " OR ") + ")"
	if caps.hasGroups {
		where += " AND groups @> '[\"Admin\"]'::jsonb"
	}
	if caps.hasPermissions {
		where += " AND permissions @> '[\"*\"]'::jsonb"
	}
	var total int
	if scanErr := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM users WHERE "+where).Scan(&total); scanErr != nil {
		return 0, scanErr
	}
	return total, nil
}
// UpsertSubscription creates or updates a subscription row.
//
// Rows are keyed by (user_uuid, external_id). For an existing pair, the
// provider, payment method, kind, plan, status, QR payload, meta and
// cancellation time are replaced. A blank UserID yields ErrUserNotFound and a
// blank ExternalID an error. A blank PaymentMethod defaults to the trimmed
// Provider. On success the subscription's ID, UserID, ExternalID, timestamps,
// Meta and CancelledAt (UTC) reflect the stored row.
func (s *postgresStore) UpsertSubscription(ctx context.Context, subscription *Subscription) error {
	if subscription == nil {
		return errors.New("subscription is required")
	}
	normalizedUserID := strings.TrimSpace(subscription.UserID)
	if normalizedUserID == "" {
		return ErrUserNotFound
	}
	externalID := strings.TrimSpace(subscription.ExternalID)
	if externalID == "" {
		return errors.New("external id is required")
	}
	if strings.TrimSpace(subscription.PaymentMethod) == "" {
		subscription.PaymentMethod = strings.TrimSpace(subscription.Provider)
	}
	subscription.PaymentQRCode = strings.TrimSpace(subscription.PaymentQRCode)
	encodedMeta, err := json.Marshal(subscription.Meta)
	if err != nil {
		return err
	}
	// A nil Meta marshals to the JSON literal null. That is a non-NULL jsonb
	// value, so it would bypass the COALESCE($9, '{}') default and store
	// `null`. Store an empty object instead.
	if string(encodedMeta) == "null" {
		encodedMeta = []byte("{}")
	}
	var cancelledAt any
	if subscription.CancelledAt != nil {
		cancelledAt = subscription.CancelledAt.UTC()
	}
	const query = `INSERT INTO subscriptions (user_uuid, provider, payment_method, kind, plan_id, external_id, status, payment_qr, meta, cancelled_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, COALESCE($9, '{}'::jsonb), $10)
ON CONFLICT (user_uuid, external_id) DO UPDATE SET
provider = EXCLUDED.provider,
payment_method = EXCLUDED.payment_method,
kind = EXCLUDED.kind,
plan_id = EXCLUDED.plan_id,
status = EXCLUDED.status,
payment_qr = EXCLUDED.payment_qr,
meta = EXCLUDED.meta,
cancelled_at = EXCLUDED.cancelled_at,
updated_at = now()
RETURNING uuid, created_at, updated_at, cancelled_at`
	var (
		idValue   any
		createdAt time.Time
		updatedAt time.Time
		cancelled sql.NullTime
	)
	err = s.db.QueryRowContext(
		ctx,
		query,
		normalizedUserID,
		strings.TrimSpace(subscription.Provider),
		strings.TrimSpace(subscription.PaymentMethod),
		strings.TrimSpace(subscription.Kind),
		strings.TrimSpace(subscription.PlanID),
		externalID,
		strings.TrimSpace(subscription.Status),
		strings.TrimSpace(subscription.PaymentQRCode),
		encodedMeta,
		cancelledAt,
	).Scan(&idValue, &createdAt, &updatedAt, &cancelled)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return ErrSubscriptionNotFound
		}
		return err
	}
	identifier, err := formatIdentifier(idValue)
	if err != nil {
		return err
	}
	subscription.ID = identifier
	subscription.UserID = normalizedUserID
	subscription.ExternalID = externalID
	subscription.CreatedAt = createdAt.UTC()
	subscription.UpdatedAt = updatedAt.UTC()
	// Round-trip Meta through JSON so callers see the same value types a later
	// read via decodeSubscriptionMeta yields. encodedMeta is json.Marshal
	// output, so decoding should not fail. If it somehow does, keep the
	// caller's Meta rather than failing after the row was already written.
	if meta, decodeErr := decodeSubscriptionMeta(encodedMeta); decodeErr == nil {
		subscription.Meta = meta
	}
	subscription.CancelledAt = nil
	if cancelled.Valid {
		cancelledUTC := cancelled.Time.UTC()
		subscription.CancelledAt = &cancelledUTC
	}
	return nil
}
// ListSubscriptionsByUser returns all subscriptions for a user ordered by
// recency (newest created_at first). A blank userID yields ErrUserNotFound;
// a user without subscriptions yields a nil slice and no error. Each entry's
// UserID is the user_uuid stored on the row, and timestamps are UTC.
func (s *postgresStore) ListSubscriptionsByUser(ctx context.Context, userID string) ([]Subscription, error) {
	normalizedUserID := strings.TrimSpace(userID)
	if normalizedUserID == "" {
		return nil, ErrUserNotFound
	}
	const query = `SELECT uuid, user_uuid, provider, payment_method, kind, plan_id, external_id, status, payment_qr, meta, created_at, updated_at, cancelled_at
FROM subscriptions WHERE user_uuid = $1 ORDER BY created_at DESC`
	rows, err := s.db.QueryContext(ctx, query, normalizedUserID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var subs []Subscription
	for rows.Next() {
		var (
			idValue       any
			rowUserID     string
			provider      string
			paymentMethod string
			kind          string
			planID        sql.NullString
			externalID    string
			status        string
			paymentQR     sql.NullString
			metaBytes     []byte
			createdAt     time.Time
			updatedAt     time.Time
			cancelled     sql.NullTime
		)
		// Scan user_uuid into a per-row variable. Previously it overwrote the
		// query key, and the untrimmed input was reported as UserID.
		if err := rows.Scan(&idValue, &rowUserID, &provider, &paymentMethod, &kind, &planID, &externalID, &status, &paymentQR, &metaBytes, &createdAt, &updatedAt, &cancelled); err != nil {
			return nil, err
		}
		identifier, err := formatIdentifier(idValue)
		if err != nil {
			return nil, err
		}
		meta, err := decodeSubscriptionMeta(metaBytes)
		if err != nil {
			return nil, err
		}
		sub := Subscription{
			ID:            identifier,
			UserID:        rowUserID,
			Provider:      provider,
			PaymentMethod: paymentMethod,
			PaymentQRCode: paymentQR.String,
			Kind:          kind,
			PlanID:        planID.String,
			ExternalID:    externalID,
			Status:        status,
			Meta:          meta,
			CreatedAt:     createdAt.UTC(),
			UpdatedAt:     updatedAt.UTC(),
		}
		if cancelled.Valid {
			cancelledUTC := cancelled.Time.UTC()
			sub.CancelledAt = &cancelledUTC
		}
		subs = append(subs, sub)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return subs, nil
}
// CancelSubscription marks the subscription identified by (userID,
// externalID) as cancelled. It sets status to 'cancelled', stores
// cancelledAt in UTC and bumps updated_at, then returns the updated row with
// all timestamps in UTC. A blank userID yields ErrUserNotFound. A blank
// externalID, or no matching row, yields ErrSubscriptionNotFound.
func (s *postgresStore) CancelSubscription(ctx context.Context, userID, externalID string, cancelledAt time.Time) (*Subscription, error) {
	normalizedUserID := strings.TrimSpace(userID)
	if normalizedUserID == "" {
		return nil, ErrUserNotFound
	}
	key := strings.TrimSpace(externalID)
	if key == "" {
		return nil, ErrSubscriptionNotFound
	}
	const query = `UPDATE subscriptions
SET status = 'cancelled', cancelled_at = $3, updated_at = now()
WHERE user_uuid = $1 AND external_id = $2
RETURNING uuid, provider, payment_method, kind, plan_id, status, payment_qr, meta, created_at, updated_at, cancelled_at`
	var (
		idValue       any
		provider      string
		paymentMethod string
		kind          string
		planID        sql.NullString
		status        string
		paymentQR     sql.NullString
		metaBytes     []byte
		createdAt     time.Time
		updatedAt     time.Time
		cancelled     sql.NullTime
	)
	err := s.db.QueryRowContext(ctx, query, normalizedUserID, key, cancelledAt.UTC()).Scan(
		&idValue,
		&provider,
		&paymentMethod,
		&kind,
		&planID,
		&status,
		&paymentQR,
		&metaBytes,
		&createdAt,
		&updatedAt,
		&cancelled,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrSubscriptionNotFound
		}
		return nil, err
	}
	identifier, err := formatIdentifier(idValue)
	if err != nil {
		return nil, err
	}
	meta, err := decodeSubscriptionMeta(metaBytes)
	if err != nil {
		return nil, err
	}
	sub := &Subscription{
		ID:            identifier,
		UserID:        normalizedUserID,
		Provider:      provider,
		PaymentMethod: paymentMethod,
		PaymentQRCode: paymentQR.String,
		Kind:          kind,
		PlanID:        planID.String,
		ExternalID:    key,
		Status:        status,
		Meta:          meta,
		CreatedAt:     createdAt.UTC(),
		UpdatedAt:     updatedAt.UTC(),
	}
	if cancelled.Valid {
		// Normalize like CreatedAt/UpdatedAt; the driver may return local time.
		cancelledUTC := cancelled.Time.UTC()
		sub.CancelledAt = &cancelledUTC
	}
	return sub, nil
}
// decodeSubscriptionMeta parses subscription metadata stored as a JSON
// object. Empty input and a JSON null both yield an empty, non-nil map.
// Malformed JSON, or JSON that is not an object, returns (nil, err).
func decodeSubscriptionMeta(raw []byte) (map[string]any, error) {
	if len(raw) == 0 {
		return map[string]any{}, nil
	}
	var decoded map[string]any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, err
	}
	if decoded != nil {
		return decoded, nil
	}
	// JSON null unmarshals to a nil map; hand callers a usable empty map.
	return map[string]any{}, nil
}
// nullForEmpty returns nil (bound as SQL NULL) for an empty or
// whitespace-only string. Any other value is returned unchanged, including
// its surrounding whitespace.
func nullForEmpty(value string) any {
	if len(strings.TrimSpace(value)) > 0 {
		return value
	}
	return nil
}
// toUTCTime converts a nullable timestamp to a UTC time.Time, mapping NULL
// (Valid == false) to the zero time.
func toUTCTime(value sql.NullTime) time.Time {
	var out time.Time
	if value.Valid {
		out = value.Time.UTC()
	}
	return out
}
// formatIdentifier renders a database-generated identifier, as scanned into
// an `any`, as a string:
//   - strings and byte slices are used verbatim;
//   - 16-byte arrays are formatted with uuid.UUID;
//   - signed and unsigned integers are rendered in base 10;
//   - pgtype.UUID values, and anything else implementing fmt.Stringer, use
//     their String method.
//
// A nil value, a nil pointer, or an invalid pgtype.UUID is an error, as is
// any other type. Case order matters: concrete types must be matched before
// the fmt.Stringer fallback.
func formatIdentifier(value any) (string, error) {
	nilID := func() (string, error) { return "", errors.New("user id is nil") }
	switch v := value.(type) {
	case nil:
		return nilID()
	case string:
		return v, nil
	case []byte:
		return string(v), nil
	case [16]byte:
		return uuid.UUID(v).String(), nil
	case *[16]byte:
		if v == nil {
			return nilID()
		}
		return uuid.UUID(*v).String(), nil
	case int64:
		return strconv.FormatInt(v, 10), nil
	case int32:
		return strconv.FormatInt(int64(v), 10), nil
	case int:
		return strconv.Itoa(v), nil
	case uint64:
		return strconv.FormatUint(v, 10), nil
	case uint32:
		return strconv.FormatUint(uint64(v), 10), nil
	case pgtype.UUID:
		if !v.Valid {
			return nilID()
		}
		return v.String(), nil
	case *pgtype.UUID:
		if v == nil || !v.Valid {
			return nilID()
		}
		return v.String(), nil
	case fmt.Stringer:
		return v.String(), nil
	default:
		return "", fmt.Errorf("unsupported identifier type %T", value)
	}
}
// capabilities reports which optional users-table columns exist, probing
// information_schema once and caching the result on the store.
//
// Columns are matched on tables named "users" in any schema on the current
// search path (current_schemas(false)). A failed probe is not cached, so the
// next call retries. The read-lock fast path plus re-check under the write
// lock ensures concurrent first callers run the probe only once.
func (s *postgresStore) capabilities(ctx context.Context) (schemaCapabilities, error) {
	s.capsMu.RLock()
	if s.capsLoaded {
		caps := s.caps
		s.capsMu.RUnlock()
		return caps, nil
	}
	s.capsMu.RUnlock()

	s.capsMu.Lock()
	defer s.capsMu.Unlock()
	// Another goroutine may have loaded the capabilities while we waited.
	if s.capsLoaded {
		return s.caps, nil
	}

	// One listing of the users-table columns replaces a separate EXISTS
	// subquery per column; flags below maps each column we care about to
	// its capability bit.
	const query = `SELECT column_name::text
FROM information_schema.columns
WHERE table_name = 'users'
AND table_schema = ANY (current_schemas(false))`
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return schemaCapabilities{}, err
	}
	defer rows.Close()

	var caps schemaCapabilities
	flags := map[string]*bool{
		"mfa_totp_secret":       &caps.hasMFATOTPSecret,
		"mfa_enabled":           &caps.hasMFAEnabled,
		"mfa_secret_issued_at":  &caps.hasMFASecretIssuedAt,
		"mfa_confirmed_at":      &caps.hasMFAConfirmedAt,
		"created_at":            &caps.hasCreatedAt,
		"updated_at":            &caps.hasUpdatedAt,
		"level":                 &caps.hasLevel,
		"role":                  &caps.hasRole,
		"groups":                &caps.hasGroups,
		"permissions":           &caps.hasPermissions,
		"active":                &caps.hasActive,
		"proxy_uuid":            &caps.hasProxyUUID,
		"proxy_uuid_expires_at": &caps.hasProxyUUIDExpiresAt,
	}
	for rows.Next() {
		var column string
		if err := rows.Scan(&column); err != nil {
			return schemaCapabilities{}, err
		}
		if flag, ok := flags[column]; ok {
			*flag = true
		}
	}
	if err := rows.Err(); err != nil {
		return schemaCapabilities{}, err
	}

	s.caps = caps
	s.capsLoaded = true
	return caps, nil
}
// selectUserQuery builds the SELECT that scanUser decodes, followed by
// whereClause (any trailing WHERE/ORDER BY/LIMIT text). For every optional
// column the schema lacks, a typed constant or default stands in so the
// column list always has the same 18 positions.
func (s *postgresStore) selectUserQuery(caps schemaCapabilities, whereClause string) string {
	// choose returns present when the column exists and fallback otherwise.
	choose := func(exists bool, present, fallback string) string {
		if exists {
			return present
		}
		return fallback
	}
	optional := []string{
		choose(caps.hasMFATOTPSecret, "mfa_totp_secret", "NULL::text"),
		choose(caps.hasMFAEnabled, "coalesce(mfa_enabled, false)", "false"),
		choose(caps.hasMFASecretIssuedAt, "mfa_secret_issued_at", "NULL::timestamptz"),
		choose(caps.hasMFAConfirmedAt, "mfa_confirmed_at", "NULL::timestamptz"),
		choose(caps.hasCreatedAt, "coalesce(created_at, now())", "now()"),
		choose(caps.hasUpdatedAt, "coalesce(updated_at, now())", "now()"),
		choose(caps.hasLevel, fmt.Sprintf("coalesce(level, %d)", LevelUser), fmt.Sprintf("%d", LevelUser)),
		choose(caps.hasRole, fmt.Sprintf("coalesce(role, '%s')", RoleUser), fmt.Sprintf("'%s'", RoleUser)),
		choose(caps.hasGroups, "coalesce(groups, '[]'::jsonb)", "'[]'::jsonb"),
		choose(caps.hasPermissions, "coalesce(permissions, '[]'::jsonb)", "'[]'::jsonb"),
		choose(caps.hasActive, "coalesce(active, true)", "true"),
		choose(caps.hasProxyUUID, "proxy_uuid", "NULL::uuid"),
		choose(caps.hasProxyUUIDExpiresAt, "proxy_uuid_expires_at", "NULL::timestamptz"),
	}
	return "SELECT uuid, username, email, email_verified, password, " +
		strings.Join(optional, ", ") +
		" FROM users " + whereClause
}
// encodeStringSlice normalizes values with normalizeStringSlice and encodes
// the result as a JSON array. An empty result is encoded as "[]", never
// "null".
func encodeStringSlice(values []string) ([]byte, error) {
	if normalized := normalizeStringSlice(values); len(normalized) > 0 {
		return json.Marshal(normalized)
	}
	return []byte("[]"), nil
}
// decodeStringSlice parses a JSON array of strings and normalizes it with
// normalizeStringSlice. Empty or malformed input yields nil. Decoding errors
// are deliberately swallowed, presumably so one bad column value does not
// fail a whole user read.
func decodeStringSlice(raw []byte) []string {
	var values []string
	if len(raw) == 0 || json.Unmarshal(raw, &values) != nil {
		return nil
	}
	return normalizeStringSlice(values)
}
// CreateIdentity inserts a (user, provider, external_id) row into identities
// and fills identity.ID, CreatedAt and UpdatedAt from the inserted row.
// UserID, Provider and ExternalID are trimmed. A blank UserID yields
// ErrUserNotFound, and a blank provider or external id yields an error. A
// unique-constraint violation is reported as "identity already exists".
func (s *postgresStore) CreateIdentity(ctx context.Context, identity *Identity) error {
	if identity == nil {
		return errors.New("identity is required")
	}
	userID := strings.TrimSpace(identity.UserID)
	if userID == "" {
		return ErrUserNotFound
	}
	provider, externalID := strings.TrimSpace(identity.Provider), strings.TrimSpace(identity.ExternalID)
	if provider == "" || externalID == "" {
		return errors.New("provider and external_id are required")
	}
	const insertIdentity = `INSERT INTO identities (user_uuid, provider, external_id)
VALUES ($1, $2, $3)
RETURNING uuid, created_at, updated_at`
	var (
		rawID            any
		created, updated time.Time
	)
	if err := s.db.QueryRowContext(ctx, insertIdentity, userID, provider, externalID).Scan(&rawID, &created, &updated); err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "23505" { // unique_violation
			return errors.New("identity already exists")
		}
		return err
	}
	id, err := formatIdentifier(rawID)
	if err != nil {
		return err
	}
	identity.ID = id
	identity.CreatedAt = created.UTC()
	identity.UpdatedAt = updated.UTC()
	return nil
}
// ListUsers returns all users from the postgres store. When the schema has a
// created_at column they are ordered oldest first (NULL created_at last);
// otherwise they are ordered by username. An empty table yields a nil slice.
func (s *postgresStore) ListUsers(ctx context.Context) ([]User, error) {
	caps, err := s.capabilities(ctx)
	if err != nil {
		return nil, err
	}
	// selectUserQuery substitutes now() for a missing created_at column, but
	// a literal ORDER BY created_at would still reference it and fail.
	orderClause := "ORDER BY username ASC"
	if caps.hasCreatedAt {
		orderClause = "ORDER BY created_at ASC"
	}
	rows, err := s.db.QueryContext(ctx, s.selectUserQuery(caps, orderClause))
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var users []User
	for rows.Next() {
		user, err := scanUser(rows)
		if err != nil {
			return nil, err
		}
		users = append(users, *user)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}
// DeleteUser removes the users row whose uuid equals id. Deleting an id that
// matches no row is not an error, because the affected-row count is not
// inspected.
func (s *postgresStore) DeleteUser(ctx context.Context, id string) error {
	if _, execErr := s.db.ExecContext(ctx, "DELETE FROM users WHERE uuid = $1", id); execErr != nil {
		return execErr
	}
	return nil
}
// AddToBlacklist records email in email_blacklist. The address is trimmed and
// lower-cased, matching how CreateUser and UpdateUser normalize user emails.
// Adding an address that is already listed is a no-op.
func (s *postgresStore) AddToBlacklist(ctx context.Context, email string) error {
	const query = "INSERT INTO email_blacklist (email) VALUES ($1) ON CONFLICT (email) DO NOTHING"
	_, err := s.db.ExecContext(ctx, query, strings.ToLower(strings.TrimSpace(email)))
	return err
}

// RemoveFromBlacklist deletes email (trimmed and lower-cased) from
// email_blacklist. Removing an address that is not listed is not an error.
func (s *postgresStore) RemoveFromBlacklist(ctx context.Context, email string) error {
	const query = "DELETE FROM email_blacklist WHERE email = $1"
	_, err := s.db.ExecContext(ctx, query, strings.ToLower(strings.TrimSpace(email)))
	return err
}

// IsBlacklisted reports whether email (trimmed and lower-cased) is present in
// email_blacklist.
func (s *postgresStore) IsBlacklisted(ctx context.Context, email string) (bool, error) {
	const query = "SELECT EXISTS(SELECT 1 FROM email_blacklist WHERE email = $1)"
	var exists bool
	if err := s.db.QueryRowContext(ctx, query, strings.ToLower(strings.TrimSpace(email))).Scan(&exists); err != nil {
		return false, err
	}
	return exists, nil
}
// ListBlacklist returns every blacklisted email ordered by created_at, newest
// first. An empty blacklist yields a nil slice.
func (s *postgresStore) ListBlacklist(ctx context.Context) ([]string, error) {
	const query = "SELECT email FROM email_blacklist ORDER BY created_at DESC"
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var emails []string
	for rows.Next() {
		var email string
		if err := rows.Scan(&email); err != nil {
			return nil, err
		}
		emails = append(emails, email)
	}
	// rows.Next returns false both on exhaustion and on failure. Without this
	// check a mid-iteration error would return a truncated list as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return emails, nil
}