feat(account): add merge-aware importer with dry-run support (#507)
This commit is contained in:
parent
1c4adf2ce6
commit
e2ebbe2b19
@ -188,8 +188,13 @@ account-export:
|
||||
@go run ./cmd/migratectl/main.go export --dsn "$(DB_URL)" --output "$(ACCOUNT_EXPORT_FILE)" $(if $(ACCOUNT_EMAIL_KEYWORD),--email "$(ACCOUNT_EMAIL_KEYWORD)")
|
||||
|
||||
account-import:
|
||||
@[ -f "$(ACCOUNT_IMPORT_FILE)" ] || (echo "❌ 未找到文件 $(ACCOUNT_IMPORT_FILE)"; exit 1)
|
||||
@go run ./cmd/migratectl/main.go import --dsn "$(DB_URL)" --file "$(ACCOUNT_IMPORT_FILE)"
|
||||
@[ -f "$(ACCOUNT_IMPORT_FILE)" ] || (echo "❌ 未找到文件 $(ACCOUNT_IMPORT_FILE)"; exit 1)
|
||||
@go run ./cmd/migratectl/main.go import --dsn "$(DB_URL)" --file "$(ACCOUNT_IMPORT_FILE)" \
|
||||
$(if $(ACCOUNT_IMPORT_MERGE),--merge) \
|
||||
$(if $(ACCOUNT_IMPORT_MERGE_STRATEGY),--merge-strategy "$(ACCOUNT_IMPORT_MERGE_STRATEGY)") \
|
||||
$(if $(ACCOUNT_IMPORT_DRY_RUN),--dry-run) \
|
||||
$(foreach UUID,$(ACCOUNT_IMPORT_MERGE_ALLOWLIST),--merge-allowlist $(UUID)) \
|
||||
$(ACCOUNT_IMPORT_EXTRA_FLAGS)
|
||||
|
||||
create-super-admin:
|
||||
@[ -n "$(SUPERADMIN_USERNAME)" ] && [ -n "$(SUPERADMIN_PASSWORD)" ] || (echo "❌ 请指定用户名与密码"; exit 1)
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@ -250,9 +251,13 @@ func newExportCmd() *cobra.Command {
|
||||
|
||||
func newImportCmd() *cobra.Command {
|
||||
var (
|
||||
dsn string
|
||||
file string
|
||||
timeout time.Duration
|
||||
dsn string
|
||||
file string
|
||||
timeout time.Duration
|
||||
merge bool
|
||||
mergeStrategy string
|
||||
dryRun bool
|
||||
mergeAllowlist []string
|
||||
)
|
||||
|
||||
timeout = 5 * time.Minute
|
||||
@ -288,14 +293,49 @@ func newImportCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
importer := migrate.NewImporter()
|
||||
allowlist := map[string]struct{}{}
|
||||
for _, id := range mergeAllowlist {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
allowlist[id] = struct{}{}
|
||||
}
|
||||
if len(allowlist) == 0 {
|
||||
allowlist = nil
|
||||
}
|
||||
if !merge {
|
||||
if mergeStrategy != "" {
|
||||
return errors.New("--merge-strategy requires --merge")
|
||||
}
|
||||
if len(mergeAllowlist) > 0 {
|
||||
return errors.New("--merge-allowlist requires --merge")
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := importer.Import(ctx, dsn, &dump); err != nil {
|
||||
report, err := importer.Import(ctx, dsn, &dump, migrate.ImportOptions{
|
||||
Merge: merge,
|
||||
MergeStrategy: migrate.MergeStrategy(mergeStrategy),
|
||||
DryRun: dryRun,
|
||||
Allowlist: allowlist,
|
||||
LogWriter: cmd.ErrOrStderr(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintf(cmd.OutOrStdout(), "Imported %d users\n", len(dump.Users))
|
||||
summaryTarget := "applied"
|
||||
if dryRun {
|
||||
summaryTarget = "preview"
|
||||
}
|
||||
fmt.Fprintf(cmd.OutOrStdout(), "Import %s: users inserted=%d updated=%d skipped=%d\n", summaryTarget, report.UsersInserted, report.UsersUpdated, report.UsersSkipped)
|
||||
fmt.Fprintf(cmd.OutOrStdout(), "Identities inserted=%d updated=%d deleted=%d\n", report.IdentitiesInserted, report.IdentitiesUpdated, report.IdentitiesDeleted)
|
||||
fmt.Fprintf(cmd.OutOrStdout(), "Sessions inserted=%d updated=%d deleted=%d\n", report.SessionsInserted, report.SessionsUpdated, report.SessionsDeleted)
|
||||
if report.ConflictsResolved > 0 || report.ConflictsSkipped > 0 {
|
||||
fmt.Fprintf(cmd.OutOrStdout(), "Conflicts resolved=%d skipped=%d\n", report.ConflictsResolved, report.ConflictsSkipped)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
@ -303,6 +343,10 @@ func newImportCmd() *cobra.Command {
|
||||
cmd.Flags().StringVar(&dsn, "dsn", "", "PostgreSQL connection string")
|
||||
cmd.Flags().StringVar(&file, "file", "", "YAML file path or '-' for stdin")
|
||||
cmd.Flags().DurationVar(&timeout, "timeout", timeout, "Import operation timeout")
|
||||
cmd.Flags().BoolVar(&merge, "merge", false, "Enable additive merge behaviour")
|
||||
cmd.Flags().StringVar(&mergeStrategy, "merge-strategy", "", "Merge strategy (replace, append, timestamp)")
|
||||
cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Preview the import without applying changes")
|
||||
cmd.Flags().StringSliceVar(&mergeAllowlist, "merge-allowlist", nil, "User UUIDs allowed to merge (comma-separated or repeated)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
33
account/internal/migrate/snapshot.go
Normal file
33
account/internal/migrate/snapshot.go
Normal file
@ -0,0 +1,33 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
accountschema "xcontrol/account/sql"
|
||||
)
|
||||
|
||||
// SnapshotVersion identifies the canonical format of exported account snapshots.
|
||||
const SnapshotVersion = "v1"
|
||||
|
||||
// SnapshotMetadata captures provenance information for account snapshots.
|
||||
type SnapshotMetadata struct {
|
||||
Version string `yaml:"version"`
|
||||
SchemaHash string `yaml:"schemaHash"`
|
||||
ExportedAt time.Time `yaml:"exportedAt"`
|
||||
}
|
||||
|
||||
// validateSnapshotMetadata ensures the provided metadata matches the expected
|
||||
// snapshot format and schema hash.
|
||||
func validateSnapshotMetadata(meta *SnapshotMetadata) error {
|
||||
if meta == nil {
|
||||
return errors.New("snapshot metadata missing (expected version and schema hash)")
|
||||
}
|
||||
if meta.Version != SnapshotVersion {
|
||||
return errors.New("snapshot version mismatch")
|
||||
}
|
||||
if meta.SchemaHash != accountschema.Hash() {
|
||||
return errors.New("snapshot schema hash mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -6,16 +6,21 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
accountschema "xcontrol/account/sql"
|
||||
)
|
||||
|
||||
// AccountDump represents the serialized snapshot of account-related tables.
|
||||
type AccountDump struct {
|
||||
Users []UserRecord `yaml:"users"`
|
||||
Identities []IdentityRecord `yaml:"identities,omitempty"`
|
||||
Sessions []SessionRecord `yaml:"sessions,omitempty"`
|
||||
Metadata *SnapshotMetadata `yaml:"metadata,omitempty"`
|
||||
Users []UserRecord `yaml:"users"`
|
||||
Identities []IdentityRecord `yaml:"identities,omitempty"`
|
||||
Sessions []SessionRecord `yaml:"sessions,omitempty"`
|
||||
}
|
||||
|
||||
// UserRecord captures the exported representation of a user row.
|
||||
@ -58,6 +63,48 @@ type SessionRecord struct {
|
||||
UpdatedAt *time.Time `yaml:"updatedAt,omitempty"`
|
||||
}
|
||||
|
||||
// MergeStrategy defines how snapshot data should be reconciled with the target database.
|
||||
type MergeStrategy string
|
||||
|
||||
const (
|
||||
// MergeStrategyReplace preserves the legacy behaviour where incoming records
|
||||
// fully replace existing ones.
|
||||
MergeStrategyReplace MergeStrategy = "replace"
|
||||
// MergeStrategyAppend performs additive merges, keeping existing data that is
|
||||
// absent from the snapshot.
|
||||
MergeStrategyAppend MergeStrategy = "append"
|
||||
// MergeStrategyTimestamp resolves conflicts by preferring rows with the newest
|
||||
// updated_at timestamp.
|
||||
MergeStrategyTimestamp MergeStrategy = "timestamp"
|
||||
)
|
||||
|
||||
// ImportOptions configures how snapshot imports should be applied.
|
||||
type ImportOptions struct {
|
||||
Merge bool
|
||||
MergeStrategy MergeStrategy
|
||||
DryRun bool
|
||||
Allowlist map[string]struct{}
|
||||
LogWriter io.Writer
|
||||
}
|
||||
|
||||
// ImportReport captures the outcome of an import (or dry-run) execution.
|
||||
type ImportReport struct {
|
||||
UsersInserted int
|
||||
UsersUpdated int
|
||||
UsersSkipped int
|
||||
|
||||
IdentitiesInserted int
|
||||
IdentitiesUpdated int
|
||||
IdentitiesDeleted int
|
||||
|
||||
SessionsInserted int
|
||||
SessionsUpdated int
|
||||
SessionsDeleted int
|
||||
|
||||
ConflictsResolved int
|
||||
ConflictsSkipped int
|
||||
}
|
||||
|
||||
// Exporter reads account data from a PostgreSQL database.
|
||||
type Exporter struct{}
|
||||
|
||||
@ -75,7 +122,13 @@ func (e *Exporter) Export(ctx context.Context, dsn, emailKeyword string) (*Accou
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
dump := &AccountDump{}
|
||||
dump := &AccountDump{
|
||||
Metadata: &SnapshotMetadata{
|
||||
Version: SnapshotVersion,
|
||||
SchemaHash: accountschema.Hash(),
|
||||
ExportedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
users, err := loadUsers(ctx, db, emailKeyword)
|
||||
if err != nil {
|
||||
@ -115,69 +168,288 @@ func NewImporter() *Importer {
|
||||
return &Importer{}
|
||||
}
|
||||
|
||||
// Import restores account data from a dump into the target database. Existing
|
||||
// rows are replaced on conflict and related identities/sessions are refreshed.
|
||||
func (i *Importer) Import(ctx context.Context, dsn string, dump *AccountDump) error {
|
||||
// Import restores account data from a dump into the target database using the
|
||||
// provided options. When merge mode is disabled the behaviour mirrors the
|
||||
// legacy implementation.
|
||||
func (i *Importer) Import(ctx context.Context, dsn string, dump *AccountDump, opts ImportOptions) (*ImportReport, error) {
|
||||
if dump == nil {
|
||||
return errors.New("dump is nil")
|
||||
return nil, errors.New("dump is nil")
|
||||
}
|
||||
if err := validateSnapshotMetadata(dump.Metadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logWriter := opts.LogWriter
|
||||
if logWriter == nil {
|
||||
logWriter = io.Discard
|
||||
}
|
||||
logf := func(format string, args ...any) {
|
||||
fmt.Fprintf(logWriter, format, args...)
|
||||
}
|
||||
|
||||
strategy := opts.MergeStrategy
|
||||
if strategy == "" {
|
||||
if opts.Merge {
|
||||
strategy = MergeStrategyAppend
|
||||
} else {
|
||||
strategy = MergeStrategyReplace
|
||||
}
|
||||
}
|
||||
switch strategy {
|
||||
case MergeStrategyReplace, MergeStrategyAppend, MergeStrategyTimestamp:
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported merge strategy %q", strategy)
|
||||
}
|
||||
if !opts.Merge {
|
||||
strategy = MergeStrategyReplace
|
||||
}
|
||||
db, err := openDB(ctx, dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
identityCaps, err := tableColumnCaps(ctx, db, "identities")
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
sessionCaps, err := tableColumnCaps(ctx, db, "sessions")
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userUUIDs := make([]string, 0, len(dump.Users))
|
||||
for _, user := range dump.Users {
|
||||
userUUIDs = append(userUUIDs, user.UUID)
|
||||
}
|
||||
|
||||
existingUsers, err := loadUsersByUUIDs(ctx, db, userUUIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingIdentitiesSlice, err := loadIdentities(ctx, db, userUUIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
existingSessionsSlice, err := loadSessions(ctx, db, userUUIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingIdentitiesByUUID := make(map[string]IdentityRecord, len(existingIdentitiesSlice))
|
||||
existingIdentitiesByUser := make(map[string][]IdentityRecord)
|
||||
for _, identity := range existingIdentitiesSlice {
|
||||
existingIdentitiesByUUID[identity.UUID] = identity
|
||||
existingIdentitiesByUser[identity.UserUUID] = append(existingIdentitiesByUser[identity.UserUUID], identity)
|
||||
}
|
||||
|
||||
existingSessionsByUUID := make(map[string]SessionRecord, len(existingSessionsSlice))
|
||||
existingSessionsByUser := make(map[string][]SessionRecord)
|
||||
for _, session := range existingSessionsSlice {
|
||||
existingSessionsByUUID[session.UUID] = session
|
||||
existingSessionsByUser[session.UserUUID] = append(existingSessionsByUser[session.UserUUID], session)
|
||||
}
|
||||
|
||||
incomingIdentitiesByUser := make(map[string][]IdentityRecord)
|
||||
for _, identity := range dump.Identities {
|
||||
incomingIdentitiesByUser[identity.UserUUID] = append(incomingIdentitiesByUser[identity.UserUUID], identity)
|
||||
}
|
||||
|
||||
incomingSessionsByUser := make(map[string][]SessionRecord)
|
||||
for _, session := range dump.Sessions {
|
||||
incomingSessionsByUser[session.UserUUID] = append(incomingSessionsByUser[session.UserUUID], session)
|
||||
}
|
||||
|
||||
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
committed := false
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if !committed {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
cleared := make(map[string]struct{})
|
||||
report := &ImportReport{}
|
||||
|
||||
allowlist := opts.Allowlist
|
||||
allowlistEnabled := opts.Merge && len(allowlist) > 0
|
||||
|
||||
for _, user := range dump.Users {
|
||||
if err = upsertUser(ctx, tx, &user); err != nil {
|
||||
return err
|
||||
if allowlistEnabled {
|
||||
if _, ok := allowlist[user.UUID]; !ok {
|
||||
report.UsersSkipped++
|
||||
logf("skip user %s: not present in merge allowlist\n", user.UUID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
cleared[user.UUID] = struct{}{}
|
||||
}
|
||||
|
||||
for uuid := range cleared {
|
||||
if _, err = tx.ExecContext(ctx, `DELETE FROM identities WHERE user_uuid = $1`, uuid); err != nil {
|
||||
return err
|
||||
existing, hasExisting := existingUsers[user.UUID]
|
||||
|
||||
if opts.Merge && hasExisting && strategy == MergeStrategyTimestamp && existing.UpdatedAt.After(user.UpdatedAt) {
|
||||
report.UsersSkipped++
|
||||
report.ConflictsSkipped++
|
||||
logf("skip user %s: existing updated_at %s newer than snapshot %s\n", user.UUID, existing.UpdatedAt.Format(time.RFC3339), user.UpdatedAt.Format(time.RFC3339))
|
||||
continue
|
||||
}
|
||||
if _, err = tx.ExecContext(ctx, `DELETE FROM sessions WHERE user_uuid = $1`, uuid); err != nil {
|
||||
return err
|
||||
|
||||
mergedUser, changed := mergeUserRecord(user, existing, opts.Merge, hasExisting)
|
||||
|
||||
if !hasExisting {
|
||||
report.UsersInserted++
|
||||
} else if changed {
|
||||
report.UsersUpdated++
|
||||
if opts.Merge && strategy == MergeStrategyTimestamp {
|
||||
report.ConflictsResolved++
|
||||
}
|
||||
} else {
|
||||
report.UsersSkipped++
|
||||
}
|
||||
|
||||
if changed && !opts.DryRun {
|
||||
if err := upsertUser(ctx, tx, &mergedUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
existingUsers[user.UUID] = mergedUser
|
||||
|
||||
incomingIdentities := incomingIdentitiesByUser[user.UUID]
|
||||
incomingSessions := incomingSessionsByUser[user.UUID]
|
||||
|
||||
if !opts.Merge || strategy == MergeStrategyReplace {
|
||||
if existingCount := len(existingIdentitiesByUser[user.UUID]); existingCount > 0 {
|
||||
report.IdentitiesDeleted += existingCount
|
||||
if !opts.DryRun {
|
||||
if _, err := tx.ExecContext(ctx, `DELETE FROM identities WHERE user_uuid = $1`, user.UUID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if existingCount := len(existingSessionsByUser[user.UUID]); existingCount > 0 {
|
||||
report.SessionsDeleted += existingCount
|
||||
if !opts.DryRun {
|
||||
if _, err := tx.ExecContext(ctx, `DELETE FROM sessions WHERE user_uuid = $1`, user.UUID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, identity := range incomingIdentities {
|
||||
if _, ok := existingIdentitiesByUUID[identity.UUID]; ok {
|
||||
report.IdentitiesUpdated++
|
||||
} else {
|
||||
report.IdentitiesInserted++
|
||||
}
|
||||
if !opts.DryRun {
|
||||
if err := upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, session := range incomingSessions {
|
||||
if _, ok := existingSessionsByUUID[session.UUID]; ok {
|
||||
report.SessionsUpdated++
|
||||
} else {
|
||||
report.SessionsInserted++
|
||||
}
|
||||
if !opts.DryRun {
|
||||
if err := upsertSession(ctx, tx, &session, sessionCaps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Merge mode (append/timestamp) for identities.
|
||||
for _, identity := range incomingIdentities {
|
||||
existingIdentity, ok := existingIdentitiesByUUID[identity.UUID]
|
||||
if !ok {
|
||||
report.IdentitiesInserted++
|
||||
if !opts.DryRun {
|
||||
if err := upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strategy == MergeStrategyTimestamp && preferExistingIdentity(existingIdentity, identity) {
|
||||
report.ConflictsSkipped++
|
||||
logf("retain identity %s for user %s: existing updated_at preferred\n", identity.UUID, identity.UserUUID)
|
||||
continue
|
||||
}
|
||||
|
||||
if identityDiffers(identity, existingIdentity) {
|
||||
report.IdentitiesUpdated++
|
||||
if strategy == MergeStrategyTimestamp {
|
||||
report.ConflictsResolved++
|
||||
}
|
||||
if !opts.DryRun {
|
||||
if err := upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, session := range incomingSessions {
|
||||
existingSession, ok := existingSessionsByUUID[session.UUID]
|
||||
if !ok {
|
||||
report.SessionsInserted++
|
||||
if !opts.DryRun {
|
||||
if err := upsertSession(ctx, tx, &session, sessionCaps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strategy == MergeStrategyTimestamp && preferExistingSession(existingSession, session) {
|
||||
report.ConflictsSkipped++
|
||||
logf("retain session %s for user %s: existing updated_at preferred\n", session.UUID, session.UserUUID)
|
||||
continue
|
||||
}
|
||||
|
||||
if sessionDiffers(session, existingSession) {
|
||||
report.SessionsUpdated++
|
||||
if strategy == MergeStrategyTimestamp {
|
||||
report.ConflictsResolved++
|
||||
}
|
||||
if !opts.DryRun {
|
||||
if err := upsertSession(ctx, tx, &session, sessionCaps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, identity := range dump.Identities {
|
||||
if err = upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
|
||||
return err
|
||||
if opts.DryRun {
|
||||
if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
|
||||
return nil, err
|
||||
}
|
||||
committed = true
|
||||
logf("dry-run complete: no changes applied\n")
|
||||
return report, nil
|
||||
}
|
||||
|
||||
for _, session := range dump.Sessions {
|
||||
if err = upsertSession(ctx, tx, &session, sessionCaps); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
committed = true
|
||||
return report, nil
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
const userSelectColumns = `uuid, username, password, email, email_verified, email_verified_at, level, role, groups, permissions, created_at, updated_at, mfa_totp_secret, mfa_enabled, mfa_secret_issued_at, mfa_confirmed_at`
|
||||
|
||||
type rowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserRecord, error) {
|
||||
@ -186,7 +458,9 @@ func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserReco
|
||||
args []any
|
||||
)
|
||||
|
||||
query.WriteString(`SELECT uuid, username, password, email, email_verified, email_verified_at, level, role, groups, permissions, created_at, updated_at, mfa_totp_secret, mfa_enabled, mfa_secret_issued_at, mfa_confirmed_at FROM users`)
|
||||
query.WriteString("SELECT ")
|
||||
query.WriteString(userSelectColumns)
|
||||
query.WriteString(" FROM users")
|
||||
if keyword := strings.TrimSpace(emailKeyword); keyword != "" {
|
||||
query.WriteString(` WHERE email ILIKE $1`)
|
||||
args = append(args, "%"+keyword+"%")
|
||||
@ -201,91 +475,10 @@ func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserReco
|
||||
|
||||
var users []UserRecord
|
||||
for rows.Next() {
|
||||
var (
|
||||
email sql.NullString
|
||||
emailVerified bool
|
||||
emailVerifiedAt sql.NullTime
|
||||
level sql.NullInt64
|
||||
role sql.NullString
|
||||
groupsRaw []byte
|
||||
permissionsRaw []byte
|
||||
createdAt time.Time
|
||||
updatedAt time.Time
|
||||
mfaSecret sql.NullString
|
||||
mfaEnabled sql.NullBool
|
||||
mfaIssuedAt sql.NullTime
|
||||
mfaConfirmedAt sql.NullTime
|
||||
user UserRecord
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&user.UUID,
|
||||
&user.Username,
|
||||
&user.PasswordHash,
|
||||
&email,
|
||||
&emailVerified,
|
||||
&emailVerifiedAt,
|
||||
&level,
|
||||
&role,
|
||||
&groupsRaw,
|
||||
&permissionsRaw,
|
||||
&createdAt,
|
||||
&updatedAt,
|
||||
&mfaSecret,
|
||||
&mfaEnabled,
|
||||
&mfaIssuedAt,
|
||||
&mfaConfirmedAt,
|
||||
); err != nil {
|
||||
user, err := scanUserRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if email.Valid {
|
||||
user.Email = email.String
|
||||
}
|
||||
user.EmailVerified = emailVerified
|
||||
if emailVerifiedAt.Valid {
|
||||
ts := emailVerifiedAt.Time
|
||||
user.EmailVerifiedAt = &ts
|
||||
}
|
||||
if level.Valid {
|
||||
user.Level = int(level.Int64)
|
||||
}
|
||||
if role.Valid {
|
||||
user.Role = role.String
|
||||
}
|
||||
if len(groupsRaw) > 0 {
|
||||
if err := json.Unmarshal(groupsRaw, &user.Groups); err != nil {
|
||||
return nil, fmt.Errorf("decode groups for user %s: %w", user.UUID, err)
|
||||
}
|
||||
}
|
||||
if len(permissionsRaw) > 0 {
|
||||
if err := json.Unmarshal(permissionsRaw, &user.Permissions); err != nil {
|
||||
return nil, fmt.Errorf("decode permissions for user %s: %w", user.UUID, err)
|
||||
}
|
||||
}
|
||||
user.CreatedAt = createdAt
|
||||
user.UpdatedAt = updatedAt
|
||||
if mfaSecret.Valid {
|
||||
user.MFATOTPSecret = mfaSecret.String
|
||||
}
|
||||
user.MFAEnabled = mfaEnabled.Bool
|
||||
if mfaIssuedAt.Valid {
|
||||
ts := mfaIssuedAt.Time
|
||||
user.MFASecretIssuedAt = &ts
|
||||
}
|
||||
if mfaConfirmedAt.Valid {
|
||||
ts := mfaConfirmedAt.Time
|
||||
user.MFAConfirmedAt = &ts
|
||||
}
|
||||
if user.Groups == nil {
|
||||
user.Groups = []string{}
|
||||
}
|
||||
if user.Permissions == nil {
|
||||
user.Permissions = []string{}
|
||||
}
|
||||
if user.Role == "" {
|
||||
user.Role = "user"
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
@ -296,6 +489,320 @@ func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserReco
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func loadUsersByUUIDs(ctx context.Context, db *sql.DB, uuids []string) (map[string]UserRecord, error) {
|
||||
users := make(map[string]UserRecord, len(uuids))
|
||||
if len(uuids) == 0 {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
queryTemplate := fmt.Sprintf("SELECT %s FROM users WHERE uuid IN (%%s)", userSelectColumns)
|
||||
query, args := buildInQuery(queryTemplate, uuids)
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
user, err := scanUserRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users[user.UUID] = user
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func scanUserRow(scanner rowScanner) (UserRecord, error) {
|
||||
var (
|
||||
email sql.NullString
|
||||
emailVerified bool
|
||||
emailVerifiedAt sql.NullTime
|
||||
level sql.NullInt64
|
||||
role sql.NullString
|
||||
groupsRaw []byte
|
||||
permissionsRaw []byte
|
||||
createdAt time.Time
|
||||
updatedAt time.Time
|
||||
mfaSecret sql.NullString
|
||||
mfaEnabled sql.NullBool
|
||||
mfaIssuedAt sql.NullTime
|
||||
mfaConfirmedAt sql.NullTime
|
||||
user UserRecord
|
||||
)
|
||||
|
||||
if err := scanner.Scan(
|
||||
&user.UUID,
|
||||
&user.Username,
|
||||
&user.PasswordHash,
|
||||
&email,
|
||||
&emailVerified,
|
||||
&emailVerifiedAt,
|
||||
&level,
|
||||
&role,
|
||||
&groupsRaw,
|
||||
&permissionsRaw,
|
||||
&createdAt,
|
||||
&updatedAt,
|
||||
&mfaSecret,
|
||||
&mfaEnabled,
|
||||
&mfaIssuedAt,
|
||||
&mfaConfirmedAt,
|
||||
); err != nil {
|
||||
return UserRecord{}, err
|
||||
}
|
||||
|
||||
if email.Valid {
|
||||
user.Email = email.String
|
||||
}
|
||||
user.EmailVerified = emailVerified
|
||||
if emailVerifiedAt.Valid {
|
||||
ts := emailVerifiedAt.Time
|
||||
user.EmailVerifiedAt = &ts
|
||||
}
|
||||
if level.Valid {
|
||||
user.Level = int(level.Int64)
|
||||
}
|
||||
if role.Valid {
|
||||
user.Role = role.String
|
||||
}
|
||||
if len(groupsRaw) > 0 {
|
||||
if err := json.Unmarshal(groupsRaw, &user.Groups); err != nil {
|
||||
return UserRecord{}, fmt.Errorf("decode groups for user %s: %w", user.UUID, err)
|
||||
}
|
||||
}
|
||||
if len(permissionsRaw) > 0 {
|
||||
if err := json.Unmarshal(permissionsRaw, &user.Permissions); err != nil {
|
||||
return UserRecord{}, fmt.Errorf("decode permissions for user %s: %w", user.UUID, err)
|
||||
}
|
||||
}
|
||||
user.CreatedAt = createdAt
|
||||
user.UpdatedAt = updatedAt
|
||||
if mfaSecret.Valid {
|
||||
user.MFATOTPSecret = mfaSecret.String
|
||||
}
|
||||
user.MFAEnabled = mfaEnabled.Bool
|
||||
if mfaIssuedAt.Valid {
|
||||
ts := mfaIssuedAt.Time
|
||||
user.MFASecretIssuedAt = &ts
|
||||
}
|
||||
if mfaConfirmedAt.Valid {
|
||||
ts := mfaConfirmedAt.Time
|
||||
user.MFAConfirmedAt = &ts
|
||||
}
|
||||
|
||||
ensureUserDefaults(&user)
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func ensureUserDefaults(user *UserRecord) {
|
||||
if user.Groups == nil {
|
||||
user.Groups = []string{}
|
||||
}
|
||||
if user.Permissions == nil {
|
||||
user.Permissions = []string{}
|
||||
}
|
||||
if user.Role == "" {
|
||||
user.Role = "user"
|
||||
}
|
||||
}
|
||||
|
||||
func mergeUserRecord(incoming UserRecord, existing UserRecord, merge bool, hasExisting bool) (UserRecord, bool) {
|
||||
ensureUserDefaults(&incoming)
|
||||
|
||||
if !hasExisting {
|
||||
return incoming, true
|
||||
}
|
||||
|
||||
if merge {
|
||||
if incoming.Email == "" {
|
||||
incoming.Email = existing.Email
|
||||
}
|
||||
if incoming.EmailVerifiedAt == nil {
|
||||
incoming.EmailVerifiedAt = cloneTimePtr(existing.EmailVerifiedAt)
|
||||
}
|
||||
if len(incoming.Groups) == 0 && len(existing.Groups) > 0 {
|
||||
incoming.Groups = append([]string(nil), existing.Groups...)
|
||||
}
|
||||
if len(incoming.Permissions) == 0 && len(existing.Permissions) > 0 {
|
||||
incoming.Permissions = append([]string(nil), existing.Permissions...)
|
||||
}
|
||||
if incoming.Role == "" {
|
||||
incoming.Role = existing.Role
|
||||
}
|
||||
if incoming.MFATOTPSecret == "" {
|
||||
incoming.MFATOTPSecret = existing.MFATOTPSecret
|
||||
}
|
||||
if incoming.MFASecretIssuedAt == nil {
|
||||
incoming.MFASecretIssuedAt = cloneTimePtr(existing.MFASecretIssuedAt)
|
||||
}
|
||||
if incoming.MFAConfirmedAt == nil {
|
||||
incoming.MFAConfirmedAt = cloneTimePtr(existing.MFAConfirmedAt)
|
||||
}
|
||||
}
|
||||
|
||||
changed := userDiffers(incoming, existing)
|
||||
return incoming, changed
|
||||
}
|
||||
|
||||
func userDiffers(a, b UserRecord) bool {
|
||||
if a.Username != b.Username {
|
||||
return true
|
||||
}
|
||||
if a.PasswordHash != b.PasswordHash {
|
||||
return true
|
||||
}
|
||||
if a.Email != b.Email {
|
||||
return true
|
||||
}
|
||||
if a.EmailVerified != b.EmailVerified {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.EmailVerifiedAt, b.EmailVerifiedAt) {
|
||||
return true
|
||||
}
|
||||
if a.Level != b.Level {
|
||||
return true
|
||||
}
|
||||
if a.Role != b.Role {
|
||||
return true
|
||||
}
|
||||
if !slices.Equal(a.Groups, b.Groups) {
|
||||
return true
|
||||
}
|
||||
if !slices.Equal(a.Permissions, b.Permissions) {
|
||||
return true
|
||||
}
|
||||
if !a.CreatedAt.Equal(b.CreatedAt) {
|
||||
return true
|
||||
}
|
||||
if !a.UpdatedAt.Equal(b.UpdatedAt) {
|
||||
return true
|
||||
}
|
||||
if a.MFATOTPSecret != b.MFATOTPSecret {
|
||||
return true
|
||||
}
|
||||
if a.MFAEnabled != b.MFAEnabled {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.MFASecretIssuedAt, b.MFASecretIssuedAt) {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.MFAConfirmedAt, b.MFAConfirmedAt) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func identityDiffers(a, b IdentityRecord) bool {
|
||||
if a.Provider != b.Provider {
|
||||
return true
|
||||
}
|
||||
if a.ExternalID != b.ExternalID {
|
||||
return true
|
||||
}
|
||||
if a.UserUUID != b.UserUUID {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.CreatedAt, b.CreatedAt) {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.UpdatedAt, b.UpdatedAt) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func preferExistingIdentity(existing, incoming IdentityRecord) bool {
|
||||
switch {
|
||||
case existing.UpdatedAt != nil && incoming.UpdatedAt != nil:
|
||||
if existing.UpdatedAt.Equal(*incoming.UpdatedAt) {
|
||||
return false
|
||||
}
|
||||
return existing.UpdatedAt.After(*incoming.UpdatedAt)
|
||||
case existing.UpdatedAt != nil:
|
||||
return true
|
||||
case incoming.UpdatedAt != nil:
|
||||
return false
|
||||
}
|
||||
|
||||
if existing.CreatedAt != nil && incoming.CreatedAt != nil {
|
||||
if existing.CreatedAt.Equal(*incoming.CreatedAt) {
|
||||
return false
|
||||
}
|
||||
return existing.CreatedAt.After(*incoming.CreatedAt)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func sessionDiffers(a, b SessionRecord) bool {
|
||||
if a.Token != b.Token {
|
||||
return true
|
||||
}
|
||||
if !a.ExpiresAt.Equal(b.ExpiresAt) {
|
||||
return true
|
||||
}
|
||||
if a.UserUUID != b.UserUUID {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.CreatedAt, b.CreatedAt) {
|
||||
return true
|
||||
}
|
||||
if !timePtrEqual(a.UpdatedAt, b.UpdatedAt) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func preferExistingSession(existing, incoming SessionRecord) bool {
|
||||
switch {
|
||||
case existing.UpdatedAt != nil && incoming.UpdatedAt != nil:
|
||||
if existing.UpdatedAt.Equal(*incoming.UpdatedAt) {
|
||||
return false
|
||||
}
|
||||
return existing.UpdatedAt.After(*incoming.UpdatedAt)
|
||||
case existing.UpdatedAt != nil:
|
||||
return true
|
||||
case incoming.UpdatedAt != nil:
|
||||
return false
|
||||
}
|
||||
|
||||
if existing.CreatedAt != nil && incoming.CreatedAt != nil {
|
||||
if existing.CreatedAt.Equal(*incoming.CreatedAt) {
|
||||
return false
|
||||
}
|
||||
return existing.CreatedAt.After(*incoming.CreatedAt)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func cloneTimePtr(ts *time.Time) *time.Time {
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *ts
|
||||
return &clone
|
||||
}
|
||||
|
||||
func timePtrEqual(a, b *time.Time) bool {
|
||||
switch {
|
||||
case a == nil && b == nil:
|
||||
return true
|
||||
case a == nil || b == nil:
|
||||
return false
|
||||
default:
|
||||
return a.Equal(*b)
|
||||
}
|
||||
}
|
||||
|
||||
func loadIdentities(ctx context.Context, db *sql.DB, uuids []string) ([]IdentityRecord, error) {
|
||||
if len(uuids) == 0 {
|
||||
return nil, nil
|
||||
|
||||
25
account/sql/embed.go
Normal file
25
account/sql/embed.go
Normal file
@ -0,0 +1,25 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"embed"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
var schemaFile []byte
|
||||
|
||||
var (
|
||||
hashOnce sync.Once
|
||||
hash string
|
||||
)
|
||||
|
||||
// Hash returns the SHA-256 hash of the canonical schema.sql file.
|
||||
func Hash() string {
|
||||
hashOnce.Do(func() {
|
||||
sum := sha256.Sum256(schemaFile)
|
||||
hash = hex.EncodeToString(sum[:])
|
||||
})
|
||||
return hash
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user