1238 lines
31 KiB
Go
1238 lines
31 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
)
|
|
|
|
// Config describes how to construct a Store implementation.
//
// Zero values for the pool settings leave the database/sql defaults
// untouched; New only applies values greater than zero.
type Config struct {
	// Driver selects the backend: "" or "memory" for the in-memory store,
	// "postgres", "postgresql" or "pgx" for PostgreSQL. Matching is
	// case-insensitive and ignores surrounding whitespace.
	Driver string
	// DSN is the connection string; it is required for the postgres driver.
	DSN string

	// Connection-pool tuning for the postgres driver.
	MaxOpenConns, MaxIdleConns       int
	ConnMaxLifetime, ConnMaxIdleTime time.Duration

	// AllowSuperAdminCounting is passed to the constructed store; the
	// postgres store returns ErrSuperAdminCountingDisabled from
	// CountSuperAdmins when it is false.
	AllowSuperAdminCounting bool
}
|
|
|
|
// New creates a Store implementation based on the provided configuration.
|
|
func New(ctx context.Context, cfg Config) (Store, func(context.Context) error, error) {
|
|
driver := strings.ToLower(strings.TrimSpace(cfg.Driver))
|
|
if driver == "" || driver == "memory" {
|
|
ms := newMemoryStore(cfg.AllowSuperAdminCounting)
|
|
return ms, func(context.Context) error { return nil }, nil
|
|
}
|
|
|
|
switch driver {
|
|
case "postgres", "postgresql", "pgx":
|
|
if strings.TrimSpace(cfg.DSN) == "" {
|
|
return nil, nil, errors.New("store dsn is required for postgres driver")
|
|
}
|
|
|
|
db, err := sql.Open("pgx", cfg.DSN)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if cfg.MaxOpenConns > 0 {
|
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
}
|
|
if cfg.MaxIdleConns > 0 {
|
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
}
|
|
if cfg.ConnMaxLifetime > 0 {
|
|
db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
|
}
|
|
if cfg.ConnMaxIdleTime > 0 {
|
|
db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
|
|
}
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
db.Close()
|
|
return nil, nil, err
|
|
}
|
|
|
|
cleanup := func(context.Context) error {
|
|
return db.Close()
|
|
}
|
|
|
|
return &postgresStore{db: db, allowSuperAdminCounting: cfg.AllowSuperAdminCounting}, cleanup, nil
|
|
default:
|
|
return nil, nil, fmt.Errorf("unsupported store driver %q", cfg.Driver)
|
|
}
|
|
}
|
|
|
|
// schemaCapabilities records which optional columns exist on the users
// table, so that queries can be built against schemas that lack some of them.
type schemaCapabilities struct {
	// MFA columns; supportsMFA requires all four.
	hasMFATOTPSecret, hasMFAEnabled, hasMFASecretIssuedAt, hasMFAConfirmedAt bool

	// Timestamp columns.
	hasCreatedAt, hasUpdatedAt bool

	// Authorization-related columns.
	hasLevel, hasRole, hasGroups, hasPermissions, hasActive bool

	// Proxy credential columns.
	hasProxyUUID, hasProxyUUIDExpiresAt bool
}
|
|
|
|
func (c schemaCapabilities) supportsMFA() bool {
|
|
return c.hasMFATOTPSecret && c.hasMFAEnabled && c.hasMFASecretIssuedAt && c.hasMFAConfirmedAt
|
|
}
|
|
|
|
type postgresStore struct {
|
|
db *sql.DB
|
|
allowSuperAdminCounting bool
|
|
|
|
capsMu sync.RWMutex
|
|
caps schemaCapabilities
|
|
capsLoaded bool
|
|
}
|
|
|
|
func (s *postgresStore) CreateUser(ctx context.Context, user *User) error {
|
|
normalizedEmail := strings.ToLower(strings.TrimSpace(user.Email))
|
|
normalizedName := strings.TrimSpace(user.Name)
|
|
if normalizedName == "" {
|
|
return ErrInvalidName
|
|
}
|
|
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
normalizeUserRoleFields(user)
|
|
|
|
var (
|
|
verifiedAt any
|
|
)
|
|
if user.EmailVerified {
|
|
verifiedAt = time.Now().UTC()
|
|
}
|
|
|
|
if normalizedEmail != "" {
|
|
const emailExistsQuery = "SELECT EXISTS(SELECT 1 FROM users WHERE lower(email) = $1)"
|
|
emailExists, err := s.userExists(ctx, emailExistsQuery, normalizedEmail)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if emailExists {
|
|
return ErrEmailExists
|
|
}
|
|
}
|
|
|
|
const nameExistsQuery = "SELECT EXISTS(SELECT 1 FROM users WHERE lower(username) = lower($1))"
|
|
nameExists, err := s.userExists(ctx, nameExistsQuery, normalizedName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if nameExists {
|
|
return ErrNameExists
|
|
}
|
|
|
|
columns := []string{"username", "password", "email", "email_verified_at"}
|
|
placeholders := []string{"$1", "$2", "$3", "$4"}
|
|
args := []any{normalizedName, user.PasswordHash, normalizedEmail, verifiedAt}
|
|
|
|
idx := len(args) + 1
|
|
|
|
if caps.hasLevel {
|
|
columns = append(columns, "level")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, user.Level)
|
|
idx++
|
|
}
|
|
if caps.hasRole {
|
|
columns = append(columns, "role")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, user.Role)
|
|
idx++
|
|
}
|
|
if caps.hasGroups {
|
|
encoded, err := encodeStringSlice(user.Groups)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
columns = append(columns, "groups")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, encoded)
|
|
idx++
|
|
}
|
|
if caps.hasPermissions {
|
|
encoded, err := encodeStringSlice(user.Permissions)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
columns = append(columns, "permissions")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, encoded)
|
|
idx++
|
|
}
|
|
|
|
if caps.hasActive {
|
|
columns = append(columns, "active")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, user.Active)
|
|
idx++
|
|
}
|
|
if caps.hasProxyUUID {
|
|
columns = append(columns, "proxy_uuid")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, nullForEmpty(user.ProxyUUID))
|
|
idx++
|
|
}
|
|
if caps.hasProxyUUIDExpiresAt {
|
|
columns = append(columns, "proxy_uuid_expires_at")
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
|
args = append(args, user.ProxyUUIDExpiresAt)
|
|
idx++
|
|
}
|
|
|
|
query := fmt.Sprintf(`INSERT INTO users (%s)
|
|
VALUES (%s)
|
|
RETURNING uuid, coalesce(created_at, now()), coalesce(updated_at, now()), email_verified`, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
|
|
|
|
var idValue any
|
|
var createdAt time.Time
|
|
var updatedAt time.Time
|
|
var emailVerified sql.NullBool
|
|
err = s.db.QueryRowContext(ctx, query, args...).Scan(&idValue, &createdAt, &updatedAt, &emailVerified)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return ErrUserNotFound
|
|
}
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
if pgErr.Code == "23505" { // unique_violation
|
|
switch {
|
|
case strings.Contains(pgErr.ConstraintName, "email"):
|
|
return ErrEmailExists
|
|
case strings.Contains(pgErr.ConstraintName, "name") || strings.Contains(pgErr.ConstraintName, "username"):
|
|
return ErrNameExists
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
identifier, err := formatIdentifier(idValue)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user.ID = identifier
|
|
user.Name = normalizedName
|
|
user.Email = normalizedEmail
|
|
user.CreatedAt = createdAt.UTC()
|
|
user.UpdatedAt = updatedAt.UTC()
|
|
user.EmailVerified = emailVerified.Bool
|
|
return nil
|
|
}
|
|
|
|
func (s *postgresStore) userExists(ctx context.Context, query string, arg any) (bool, error) {
|
|
var exists bool
|
|
err := s.db.QueryRowContext(ctx, query, arg).Scan(&exists)
|
|
if err != nil {
|
|
if isDatabaseEmptyError(err) {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return exists, nil
|
|
}
|
|
|
|
func isDatabaseEmptyError(err error) bool {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return true
|
|
}
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
if pgErr.Code == "42P01" { // undefined_table
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s *postgresStore) GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
normalized := strings.ToLower(strings.TrimSpace(email))
|
|
if normalized == "" {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := s.selectUserQuery(caps, "WHERE lower(email) = $1 LIMIT 1")
|
|
|
|
row := s.db.QueryRowContext(ctx, query, normalized)
|
|
return scanUser(row)
|
|
}
|
|
|
|
func (s *postgresStore) GetUserByName(ctx context.Context, name string) (*User, error) {
|
|
normalized := strings.TrimSpace(name)
|
|
if normalized == "" {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := s.selectUserQuery(caps, "WHERE lower(username) = lower($1) LIMIT 1")
|
|
|
|
row := s.db.QueryRowContext(ctx, query, normalized)
|
|
return scanUser(row)
|
|
}
|
|
|
|
func (s *postgresStore) GetUserByID(ctx context.Context, id string) (*User, error) {
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := s.selectUserQuery(caps, "WHERE uuid = $1")
|
|
|
|
row := s.db.QueryRowContext(ctx, query, id)
|
|
return scanUser(row)
|
|
}
|
|
|
|
// rowScanner is the single method scanUser needs. It lets scanUser decode
// both the result of QueryRowContext and each row of a QueryContext iteration.
type rowScanner interface {
	Scan(targets ...any) error
}
|
|
|
|
func scanUser(row rowScanner) (*User, error) {
|
|
var (
|
|
idValue any
|
|
username sql.NullString
|
|
email sql.NullString
|
|
emailVerified sql.NullBool
|
|
password sql.NullString
|
|
mfaSecret sql.NullString
|
|
mfaEnabled sql.NullBool
|
|
mfaSecretIssued sql.NullTime
|
|
mfaConfirmed sql.NullTime
|
|
createdAt time.Time
|
|
updatedAt time.Time
|
|
levelValue sql.NullInt64
|
|
roleValue sql.NullString
|
|
groupsRaw []byte
|
|
permissionsRaw []byte
|
|
activeValue sql.NullBool
|
|
proxyUUID sql.NullString
|
|
proxyExpiresAt sql.NullTime
|
|
)
|
|
|
|
if err := row.Scan(&idValue, &username, &email, &emailVerified, &password, &mfaSecret, &mfaEnabled, &mfaSecretIssued, &mfaConfirmed, &createdAt, &updatedAt, &levelValue, &roleValue, &groupsRaw, &permissionsRaw, &activeValue, &proxyUUID, &proxyExpiresAt); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
identifier, err := formatIdentifier(idValue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user := &User{
|
|
ID: identifier,
|
|
Name: strings.TrimSpace(username.String),
|
|
Email: strings.ToLower(strings.TrimSpace(email.String)),
|
|
EmailVerified: emailVerified.Bool,
|
|
PasswordHash: password.String,
|
|
MFATOTPSecret: strings.TrimSpace(mfaSecret.String),
|
|
MFAEnabled: mfaEnabled.Bool,
|
|
MFASecretIssuedAt: toUTCTime(mfaSecretIssued),
|
|
MFAConfirmedAt: toUTCTime(mfaConfirmed),
|
|
CreatedAt: createdAt.UTC(),
|
|
UpdatedAt: updatedAt.UTC(),
|
|
}
|
|
if levelValue.Valid {
|
|
user.Level = int(levelValue.Int64)
|
|
}
|
|
user.Role = strings.TrimSpace(roleValue.String)
|
|
user.Groups = decodeStringSlice(groupsRaw)
|
|
user.Permissions = decodeStringSlice(permissionsRaw)
|
|
user.Active = activeValue.Valid && activeValue.Bool
|
|
user.ProxyUUID = strings.TrimSpace(proxyUUID.String)
|
|
if proxyExpiresAt.Valid {
|
|
t := proxyExpiresAt.Time.UTC()
|
|
user.ProxyUUIDExpiresAt = &t
|
|
}
|
|
normalizeUserRoleFields(user)
|
|
return user, nil
|
|
}
|
|
|
|
func (s *postgresStore) UpdateUser(ctx context.Context, user *User) error {
|
|
normalizedName := strings.TrimSpace(user.Name)
|
|
if normalizedName == "" {
|
|
return ErrInvalidName
|
|
}
|
|
|
|
normalizedEmail := strings.ToLower(strings.TrimSpace(user.Email))
|
|
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var issuedAt any
|
|
if !user.MFASecretIssuedAt.IsZero() {
|
|
issuedAt = user.MFASecretIssuedAt.UTC()
|
|
}
|
|
var confirmedAt any
|
|
if !user.MFAConfirmedAt.IsZero() {
|
|
confirmedAt = user.MFAConfirmedAt.UTC()
|
|
}
|
|
|
|
builder := strings.Builder{}
|
|
builder.WriteString("UPDATE users SET username = $1, email = $2, password = $3")
|
|
|
|
if user.EmailVerified {
|
|
builder.WriteString(", email_verified_at = COALESCE(email_verified_at, now())")
|
|
} else {
|
|
builder.WriteString(", email_verified_at = NULL")
|
|
}
|
|
|
|
normalizeUserRoleFields(user)
|
|
|
|
args := []any{normalizedName, normalizedEmail, user.PasswordHash}
|
|
idx := 4
|
|
|
|
if caps.hasMFATOTPSecret {
|
|
builder.WriteString(fmt.Sprintf(", mfa_totp_secret = $%d", idx))
|
|
args = append(args, nullForEmpty(user.MFATOTPSecret))
|
|
idx++
|
|
} else if strings.TrimSpace(user.MFATOTPSecret) != "" {
|
|
return ErrMFANotSupported
|
|
}
|
|
|
|
if caps.hasMFAEnabled {
|
|
builder.WriteString(fmt.Sprintf(", mfa_enabled = $%d", idx))
|
|
args = append(args, user.MFAEnabled)
|
|
idx++
|
|
} else if user.MFAEnabled {
|
|
return ErrMFANotSupported
|
|
}
|
|
|
|
if caps.hasMFASecretIssuedAt {
|
|
builder.WriteString(fmt.Sprintf(", mfa_secret_issued_at = $%d", idx))
|
|
args = append(args, issuedAt)
|
|
idx++
|
|
} else if !user.MFASecretIssuedAt.IsZero() {
|
|
return ErrMFANotSupported
|
|
}
|
|
|
|
if caps.hasMFAConfirmedAt {
|
|
builder.WriteString(fmt.Sprintf(", mfa_confirmed_at = $%d", idx))
|
|
args = append(args, confirmedAt)
|
|
idx++
|
|
} else if !user.MFAConfirmedAt.IsZero() {
|
|
return ErrMFANotSupported
|
|
}
|
|
|
|
if caps.hasUpdatedAt {
|
|
builder.WriteString(", updated_at = now()")
|
|
}
|
|
|
|
if caps.hasLevel {
|
|
builder.WriteString(fmt.Sprintf(", level = $%d", idx))
|
|
args = append(args, user.Level)
|
|
idx++
|
|
}
|
|
|
|
if caps.hasRole {
|
|
builder.WriteString(fmt.Sprintf(", role = $%d", idx))
|
|
args = append(args, user.Role)
|
|
idx++
|
|
}
|
|
|
|
if caps.hasGroups {
|
|
encoded, err := encodeStringSlice(user.Groups)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
builder.WriteString(fmt.Sprintf(", groups = $%d", idx))
|
|
args = append(args, encoded)
|
|
idx++
|
|
}
|
|
|
|
if caps.hasPermissions {
|
|
encoded, err := encodeStringSlice(user.Permissions)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
builder.WriteString(fmt.Sprintf(", permissions = $%d", idx))
|
|
args = append(args, encoded)
|
|
idx++
|
|
}
|
|
|
|
if caps.hasActive {
|
|
builder.WriteString(fmt.Sprintf(", active = $%d", idx))
|
|
args = append(args, user.Active)
|
|
idx++
|
|
}
|
|
if caps.hasProxyUUID {
|
|
builder.WriteString(fmt.Sprintf(", proxy_uuid = $%d", idx))
|
|
args = append(args, nullForEmpty(user.ProxyUUID))
|
|
idx++
|
|
}
|
|
if caps.hasProxyUUIDExpiresAt {
|
|
builder.WriteString(fmt.Sprintf(", proxy_uuid_expires_at = $%d", idx))
|
|
args = append(args, user.ProxyUUIDExpiresAt)
|
|
idx++
|
|
}
|
|
|
|
builder.WriteString(fmt.Sprintf(" WHERE uuid = $%d RETURNING ", idx))
|
|
args = append(args, user.ID)
|
|
idx++
|
|
|
|
if caps.hasCreatedAt {
|
|
builder.WriteString("coalesce(created_at, now())")
|
|
} else {
|
|
builder.WriteString("now()")
|
|
}
|
|
|
|
if caps.hasUpdatedAt {
|
|
builder.WriteString(", coalesce(updated_at, now())")
|
|
} else {
|
|
builder.WriteString(", now()")
|
|
}
|
|
|
|
query := builder.String()
|
|
|
|
var createdAt time.Time
|
|
var updatedAt time.Time
|
|
err = s.db.QueryRowContext(ctx, query, args...).Scan(&createdAt, &updatedAt)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return ErrUserNotFound
|
|
}
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
if pgErr.Code == "23505" {
|
|
switch {
|
|
case strings.Contains(pgErr.ConstraintName, "email"):
|
|
return ErrEmailExists
|
|
case strings.Contains(pgErr.ConstraintName, "name") || strings.Contains(pgErr.ConstraintName, "username"):
|
|
return ErrNameExists
|
|
}
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
user.Name = normalizedName
|
|
user.Email = normalizedEmail
|
|
user.CreatedAt = createdAt.UTC()
|
|
user.UpdatedAt = updatedAt.UTC()
|
|
return nil
|
|
}
|
|
|
|
func (s *postgresStore) CountSuperAdmins(ctx context.Context) (int, error) {
|
|
if !s.allowSuperAdminCounting {
|
|
return 0, ErrSuperAdminCountingDisabled
|
|
}
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
roleClauses := make([]string, 0, 2)
|
|
if caps.hasRole {
|
|
roleClauses = append(roleClauses, "lower(role) IN ('root','admin')")
|
|
}
|
|
if caps.hasLevel {
|
|
roleClauses = append(roleClauses, fmt.Sprintf("level = %d", LevelAdmin))
|
|
}
|
|
if len(roleClauses) == 0 {
|
|
return 0, errors.New("postgres store schema does not expose role or level columns")
|
|
}
|
|
|
|
conditions := []string{fmt.Sprintf("(%s)", strings.Join(roleClauses, " OR "))}
|
|
|
|
if caps.hasGroups {
|
|
conditions = append(conditions, "groups @> '[\"Admin\"]'::jsonb")
|
|
}
|
|
if caps.hasPermissions {
|
|
conditions = append(conditions, "permissions @> '[\"*\"]'::jsonb")
|
|
}
|
|
|
|
query := fmt.Sprintf("SELECT COUNT(*) FROM users WHERE %s", strings.Join(conditions, " AND "))
|
|
|
|
var count int
|
|
if err := s.db.QueryRowContext(ctx, query).Scan(&count); err != nil {
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// UpsertSubscription creates or updates a subscription row.
|
|
func (s *postgresStore) UpsertSubscription(ctx context.Context, subscription *Subscription) error {
|
|
if subscription == nil {
|
|
return errors.New("subscription is required")
|
|
}
|
|
|
|
normalizedUserID := strings.TrimSpace(subscription.UserID)
|
|
if normalizedUserID == "" {
|
|
return ErrUserNotFound
|
|
}
|
|
|
|
externalID := strings.TrimSpace(subscription.ExternalID)
|
|
if externalID == "" {
|
|
return errors.New("external id is required")
|
|
}
|
|
if strings.TrimSpace(subscription.PaymentMethod) == "" {
|
|
subscription.PaymentMethod = strings.TrimSpace(subscription.Provider)
|
|
}
|
|
subscription.PaymentQRCode = strings.TrimSpace(subscription.PaymentQRCode)
|
|
|
|
encodedMeta, err := json.Marshal(subscription.Meta)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var cancelledAt any
|
|
if subscription.CancelledAt != nil {
|
|
cancelledAt = subscription.CancelledAt.UTC()
|
|
}
|
|
|
|
const query = `INSERT INTO subscriptions (user_uuid, provider, payment_method, kind, plan_id, external_id, status, payment_qr, meta, cancelled_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, COALESCE($9, '{}'::jsonb), $10)
|
|
ON CONFLICT (user_uuid, external_id) DO UPDATE SET
|
|
provider = EXCLUDED.provider,
|
|
payment_method = EXCLUDED.payment_method,
|
|
kind = EXCLUDED.kind,
|
|
plan_id = EXCLUDED.plan_id,
|
|
status = EXCLUDED.status,
|
|
payment_qr = EXCLUDED.payment_qr,
|
|
meta = EXCLUDED.meta,
|
|
cancelled_at = EXCLUDED.cancelled_at,
|
|
updated_at = now()
|
|
RETURNING uuid, created_at, updated_at, cancelled_at`
|
|
|
|
var (
|
|
idValue any
|
|
createdAt time.Time
|
|
updatedAt time.Time
|
|
cancelled sql.NullTime
|
|
)
|
|
|
|
err = s.db.QueryRowContext(
|
|
ctx,
|
|
query,
|
|
normalizedUserID,
|
|
strings.TrimSpace(subscription.Provider),
|
|
strings.TrimSpace(subscription.PaymentMethod),
|
|
strings.TrimSpace(subscription.Kind),
|
|
strings.TrimSpace(subscription.PlanID),
|
|
externalID,
|
|
strings.TrimSpace(subscription.Status),
|
|
strings.TrimSpace(subscription.PaymentQRCode),
|
|
encodedMeta,
|
|
cancelledAt,
|
|
).Scan(&idValue, &createdAt, &updatedAt, &cancelled)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return ErrSubscriptionNotFound
|
|
}
|
|
return err
|
|
}
|
|
|
|
identifier, err := formatIdentifier(idValue)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
subscription.ID = identifier
|
|
subscription.UserID = normalizedUserID
|
|
subscription.ExternalID = externalID
|
|
subscription.CreatedAt = createdAt.UTC()
|
|
subscription.UpdatedAt = updatedAt.UTC()
|
|
subscription.Meta, _ = decodeSubscriptionMeta(encodedMeta)
|
|
if cancelled.Valid {
|
|
subscription.CancelledAt = &cancelled.Time
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListSubscriptionsByUser returns all subscriptions for a user ordered by recency.
|
|
func (s *postgresStore) ListSubscriptionsByUser(ctx context.Context, userID string) ([]Subscription, error) {
|
|
normalizedUserID := strings.TrimSpace(userID)
|
|
if normalizedUserID == "" {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
const query = `SELECT uuid, user_uuid, provider, payment_method, kind, plan_id, external_id, status, payment_qr, meta, created_at, updated_at, cancelled_at
|
|
FROM subscriptions WHERE user_uuid = $1 ORDER BY created_at DESC`
|
|
|
|
rows, err := s.db.QueryContext(ctx, query, normalizedUserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var subs []Subscription
|
|
for rows.Next() {
|
|
var (
|
|
idValue any
|
|
provider string
|
|
paymentMethod string
|
|
kind string
|
|
planID sql.NullString
|
|
externalID string
|
|
status string
|
|
paymentQR sql.NullString
|
|
metaBytes []byte
|
|
createdAt time.Time
|
|
updatedAt time.Time
|
|
cancelled sql.NullTime
|
|
)
|
|
if err := rows.Scan(&idValue, &normalizedUserID, &provider, &paymentMethod, &kind, &planID, &externalID, &status, &paymentQR, &metaBytes, &createdAt, &updatedAt, &cancelled); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
identifier, err := formatIdentifier(idValue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
meta, err := decodeSubscriptionMeta(metaBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sub := Subscription{
|
|
ID: identifier,
|
|
UserID: userID,
|
|
Provider: provider,
|
|
PaymentMethod: paymentMethod,
|
|
PaymentQRCode: paymentQR.String,
|
|
Kind: kind,
|
|
PlanID: planID.String,
|
|
ExternalID: externalID,
|
|
Status: status,
|
|
Meta: meta,
|
|
CreatedAt: createdAt.UTC(),
|
|
UpdatedAt: updatedAt.UTC(),
|
|
}
|
|
if cancelled.Valid {
|
|
sub.CancelledAt = &cancelled.Time
|
|
}
|
|
|
|
subs = append(subs, sub)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return subs, nil
|
|
}
|
|
|
|
// CancelSubscription marks the subscription as cancelled.
|
|
func (s *postgresStore) CancelSubscription(ctx context.Context, userID, externalID string, cancelledAt time.Time) (*Subscription, error) {
|
|
normalizedUserID := strings.TrimSpace(userID)
|
|
if normalizedUserID == "" {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
key := strings.TrimSpace(externalID)
|
|
if key == "" {
|
|
return nil, ErrSubscriptionNotFound
|
|
}
|
|
|
|
const query = `UPDATE subscriptions
|
|
SET status = 'cancelled', cancelled_at = $3, updated_at = now()
|
|
WHERE user_uuid = $1 AND external_id = $2
|
|
RETURNING uuid, provider, payment_method, kind, plan_id, status, payment_qr, meta, created_at, updated_at, cancelled_at`
|
|
|
|
var (
|
|
idValue any
|
|
provider string
|
|
paymentMethod string
|
|
kind string
|
|
planID sql.NullString
|
|
status string
|
|
paymentQR sql.NullString
|
|
metaBytes []byte
|
|
createdAt time.Time
|
|
updatedAt time.Time
|
|
cancelled sql.NullTime
|
|
)
|
|
|
|
err := s.db.QueryRowContext(ctx, query, normalizedUserID, key, cancelledAt.UTC()).Scan(
|
|
&idValue,
|
|
&provider,
|
|
&paymentMethod,
|
|
&kind,
|
|
&planID,
|
|
&status,
|
|
&paymentQR,
|
|
&metaBytes,
|
|
&createdAt,
|
|
&updatedAt,
|
|
&cancelled,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrSubscriptionNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
identifier, err := formatIdentifier(idValue)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
meta, err := decodeSubscriptionMeta(metaBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sub := &Subscription{
|
|
ID: identifier,
|
|
UserID: normalizedUserID,
|
|
Provider: provider,
|
|
PaymentMethod: paymentMethod,
|
|
PaymentQRCode: paymentQR.String,
|
|
Kind: kind,
|
|
PlanID: planID.String,
|
|
ExternalID: key,
|
|
Status: status,
|
|
Meta: meta,
|
|
CreatedAt: createdAt.UTC(),
|
|
UpdatedAt: updatedAt.UTC(),
|
|
}
|
|
if cancelled.Valid {
|
|
sub.CancelledAt = &cancelled.Time
|
|
}
|
|
|
|
return sub, nil
|
|
}
|
|
|
|
// decodeSubscriptionMeta parses a subscription's JSON meta value. Empty input
// and JSON null both decode to an empty, non-nil map. Input that is not a
// JSON object is returned as an error with a nil map.
func decodeSubscriptionMeta(raw []byte) (map[string]any, error) {
	var decoded map[string]any
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &decoded); err != nil {
			return nil, err
		}
	}
	if decoded == nil {
		decoded = make(map[string]any)
	}
	return decoded, nil
}
|
|
|
|
// nullForEmpty maps a blank (empty or whitespace-only) string to nil so it is
// bound as SQL NULL. Any other value is returned unchanged and untrimmed.
func nullForEmpty(value string) any {
	if strings.TrimSpace(value) != "" {
		return value
	}
	return nil
}
|
|
|
|
// toUTCTime converts a nullable timestamp to UTC, returning the zero
// time.Time for SQL NULL.
func toUTCTime(value sql.NullTime) time.Time {
	var out time.Time
	if value.Valid {
		out = value.Time.UTC()
	}
	return out
}
|
|
|
|
func formatIdentifier(value any) (string, error) {
|
|
switch v := value.(type) {
|
|
case nil:
|
|
return "", errors.New("user id is nil")
|
|
case string:
|
|
return v, nil
|
|
case []byte:
|
|
return string(v), nil
|
|
case [16]byte:
|
|
id := uuid.UUID(v)
|
|
return id.String(), nil
|
|
case *[16]byte:
|
|
if v == nil {
|
|
return "", errors.New("user id is nil")
|
|
}
|
|
id := uuid.UUID(*v)
|
|
return id.String(), nil
|
|
case int64:
|
|
return strconv.FormatInt(v, 10), nil
|
|
case int32:
|
|
return strconv.FormatInt(int64(v), 10), nil
|
|
case int:
|
|
return strconv.FormatInt(int64(v), 10), nil
|
|
case uint64:
|
|
return strconv.FormatUint(v, 10), nil
|
|
case uint32:
|
|
return strconv.FormatUint(uint64(v), 10), nil
|
|
case pgtype.UUID:
|
|
if !v.Valid {
|
|
return "", errors.New("user id is nil")
|
|
}
|
|
return v.String(), nil
|
|
case *pgtype.UUID:
|
|
if v == nil || !v.Valid {
|
|
return "", errors.New("user id is nil")
|
|
}
|
|
return v.String(), nil
|
|
case fmt.Stringer:
|
|
return v.String(), nil
|
|
default:
|
|
return "", fmt.Errorf("unsupported identifier type %T", value)
|
|
}
|
|
}
|
|
|
|
func (s *postgresStore) capabilities(ctx context.Context) (schemaCapabilities, error) {
|
|
s.capsMu.RLock()
|
|
if s.capsLoaded {
|
|
caps := s.caps
|
|
s.capsMu.RUnlock()
|
|
return caps, nil
|
|
}
|
|
s.capsMu.RUnlock()
|
|
|
|
s.capsMu.Lock()
|
|
defer s.capsMu.Unlock()
|
|
if s.capsLoaded {
|
|
return s.caps, nil
|
|
}
|
|
|
|
query := `SELECT
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'mfa_totp_secret'
|
|
) AS has_mfa_totp_secret,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'mfa_enabled'
|
|
) AS has_mfa_enabled,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'mfa_secret_issued_at'
|
|
) AS has_mfa_secret_issued_at,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'mfa_confirmed_at'
|
|
) AS has_mfa_confirmed_at,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'created_at'
|
|
) AS has_created_at,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'updated_at'
|
|
) AS has_updated_at,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'level'
|
|
) AS has_level,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'role'
|
|
) AS has_role,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'groups'
|
|
) AS has_groups,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'permissions'
|
|
) AS has_permissions,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'active'
|
|
) AS has_active,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'proxy_uuid'
|
|
) AS has_proxy_uuid,
|
|
EXISTS (
|
|
SELECT 1 FROM information_schema.columns
|
|
WHERE table_name = 'users'
|
|
AND table_schema = ANY (current_schemas(false))
|
|
AND column_name = 'proxy_uuid_expires_at'
|
|
) AS has_proxy_uuid_expires_at`
|
|
|
|
row := s.db.QueryRowContext(ctx, query)
|
|
var caps schemaCapabilities
|
|
if err := row.Scan(
|
|
&caps.hasMFATOTPSecret,
|
|
&caps.hasMFAEnabled,
|
|
&caps.hasMFASecretIssuedAt,
|
|
&caps.hasMFAConfirmedAt,
|
|
&caps.hasCreatedAt,
|
|
&caps.hasUpdatedAt,
|
|
&caps.hasLevel,
|
|
&caps.hasRole,
|
|
&caps.hasGroups,
|
|
&caps.hasPermissions,
|
|
&caps.hasActive,
|
|
&caps.hasProxyUUID,
|
|
&caps.hasProxyUUIDExpiresAt,
|
|
); err != nil {
|
|
return schemaCapabilities{}, err
|
|
}
|
|
|
|
s.caps = caps
|
|
s.capsLoaded = true
|
|
return caps, nil
|
|
}
|
|
|
|
func (s *postgresStore) selectUserQuery(caps schemaCapabilities, whereClause string) string {
|
|
secretExpr := "NULL::text"
|
|
if caps.hasMFATOTPSecret {
|
|
secretExpr = "mfa_totp_secret"
|
|
}
|
|
|
|
enabledExpr := "false"
|
|
if caps.hasMFAEnabled {
|
|
enabledExpr = "coalesce(mfa_enabled, false)"
|
|
}
|
|
|
|
issuedExpr := "NULL::timestamptz"
|
|
if caps.hasMFASecretIssuedAt {
|
|
issuedExpr = "mfa_secret_issued_at"
|
|
}
|
|
|
|
confirmedExpr := "NULL::timestamptz"
|
|
if caps.hasMFAConfirmedAt {
|
|
confirmedExpr = "mfa_confirmed_at"
|
|
}
|
|
|
|
createdExpr := "now()"
|
|
if caps.hasCreatedAt {
|
|
createdExpr = "coalesce(created_at, now())"
|
|
}
|
|
|
|
updatedExpr := "now()"
|
|
if caps.hasUpdatedAt {
|
|
updatedExpr = "coalesce(updated_at, now())"
|
|
}
|
|
|
|
levelExpr := fmt.Sprintf("%d", LevelUser)
|
|
if caps.hasLevel {
|
|
levelExpr = fmt.Sprintf("coalesce(level, %d)", LevelUser)
|
|
}
|
|
|
|
roleExpr := fmt.Sprintf("'%s'", RoleUser)
|
|
if caps.hasRole {
|
|
roleExpr = fmt.Sprintf("coalesce(role, '%s')", RoleUser)
|
|
}
|
|
|
|
groupsExpr := "'[]'::jsonb"
|
|
if caps.hasGroups {
|
|
groupsExpr = "coalesce(groups, '[]'::jsonb)"
|
|
}
|
|
|
|
permissionsExpr := "'[]'::jsonb"
|
|
if caps.hasPermissions {
|
|
permissionsExpr = "coalesce(permissions, '[]'::jsonb)"
|
|
}
|
|
|
|
activeExpr := "true"
|
|
if caps.hasActive {
|
|
activeExpr = "coalesce(active, true)"
|
|
}
|
|
|
|
proxyUUIDExpr := "NULL::uuid"
|
|
if caps.hasProxyUUID {
|
|
proxyUUIDExpr = "proxy_uuid"
|
|
}
|
|
|
|
proxyExpiresAtExpr := "NULL::timestamptz"
|
|
if caps.hasProxyUUIDExpiresAt {
|
|
proxyExpiresAtExpr = "proxy_uuid_expires_at"
|
|
}
|
|
|
|
return fmt.Sprintf(`SELECT uuid, username, email, email_verified, password, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s FROM users %s`,
|
|
secretExpr, enabledExpr, issuedExpr, confirmedExpr, createdExpr, updatedExpr, levelExpr, roleExpr, groupsExpr, permissionsExpr, activeExpr, proxyUUIDExpr, proxyExpiresAtExpr, whereClause)
|
|
}
|
|
|
|
func encodeStringSlice(values []string) ([]byte, error) {
|
|
normalized := normalizeStringSlice(values)
|
|
if len(normalized) == 0 {
|
|
return []byte("[]"), nil
|
|
}
|
|
return json.Marshal(normalized)
|
|
}
|
|
|
|
func decodeStringSlice(raw []byte) []string {
|
|
if len(raw) == 0 {
|
|
return nil
|
|
}
|
|
var values []string
|
|
if err := json.Unmarshal(raw, &values); err != nil {
|
|
return nil
|
|
}
|
|
return normalizeStringSlice(values)
|
|
}
|
|
func (s *postgresStore) CreateIdentity(ctx context.Context, identity *Identity) error {
|
|
if identity == nil {
|
|
return errors.New("identity is required")
|
|
}
|
|
|
|
normalizedUserID := strings.TrimSpace(identity.UserID)
|
|
if normalizedUserID == "" {
|
|
return ErrUserNotFound
|
|
}
|
|
|
|
provider := strings.TrimSpace(identity.Provider)
|
|
externalID := strings.TrimSpace(identity.ExternalID)
|
|
if provider == "" || externalID == "" {
|
|
return errors.New("provider and external_id are required")
|
|
}
|
|
|
|
const query = `INSERT INTO identities (user_uuid, provider, external_id)
|
|
VALUES ($1, $2, $3)
|
|
RETURNING uuid, created_at, updated_at`
|
|
|
|
var (
|
|
idValue any
|
|
createdAt time.Time
|
|
updatedAt time.Time
|
|
)
|
|
|
|
err := s.db.QueryRowContext(ctx, query, normalizedUserID, provider, externalID).Scan(&idValue, &createdAt, &updatedAt)
|
|
if err != nil {
|
|
var pgErr *pgconn.PgError
|
|
if errors.As(err, &pgErr) {
|
|
if pgErr.Code == "23505" { // unique_violation
|
|
return errors.New("identity already exists")
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
identifier, err := formatIdentifier(idValue)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
identity.ID = identifier
|
|
identity.CreatedAt = createdAt.UTC()
|
|
identity.UpdatedAt = updatedAt.UTC()
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListUsers returns all users from the postgres store.
|
|
func (s *postgresStore) ListUsers(ctx context.Context) ([]User, error) {
|
|
caps, err := s.capabilities(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := s.selectUserQuery(caps, "ORDER BY created_at ASC")
|
|
|
|
rows, err := s.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []User
|
|
for rows.Next() {
|
|
user, err := scanUser(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, *user)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func (s *postgresStore) DeleteUser(ctx context.Context, id string) error {
|
|
const query = "DELETE FROM users WHERE uuid = $1"
|
|
_, err := s.db.ExecContext(ctx, query, id)
|
|
return err
|
|
}
|
|
|
|
func (s *postgresStore) AddToBlacklist(ctx context.Context, email string) error {
|
|
const query = "INSERT INTO email_blacklist (email) VALUES ($1) ON CONFLICT (email) DO NOTHING"
|
|
_, err := s.db.ExecContext(ctx, query, strings.ToLower(email))
|
|
return err
|
|
}
|
|
|
|
func (s *postgresStore) RemoveFromBlacklist(ctx context.Context, email string) error {
|
|
const query = "DELETE FROM email_blacklist WHERE email = $1"
|
|
_, err := s.db.ExecContext(ctx, query, strings.ToLower(email))
|
|
return err
|
|
}
|
|
|
|
func (s *postgresStore) IsBlacklisted(ctx context.Context, email string) (bool, error) {
|
|
const query = "SELECT EXISTS(SELECT 1 FROM email_blacklist WHERE email = $1)"
|
|
var exists bool
|
|
err := s.db.QueryRowContext(ctx, query, strings.ToLower(email)).Scan(&exists)
|
|
return exists, err
|
|
}
|
|
|
|
func (s *postgresStore) ListBlacklist(ctx context.Context) ([]string, error) {
|
|
const query = "SELECT email FROM email_blacklist ORDER BY created_at DESC"
|
|
rows, err := s.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var emails []string
|
|
for rows.Next() {
|
|
var email string
|
|
if err := rows.Scan(&email); err != nil {
|
|
return nil, err
|
|
}
|
|
emails = append(emails, email)
|
|
}
|
|
return emails, nil
|
|
}
|