feat(account): add merge-aware importer with dry-run support (#507)

This commit is contained in:
shenlan 2025-10-13 11:24:36 +08:00 committed by GitHub
parent 1c4adf2ce6
commit e2ebbe2b19
5 changed files with 736 additions and 122 deletions

View File

@ -188,8 +188,13 @@ account-export:
@go run ./cmd/migratectl/main.go export --dsn "$(DB_URL)" --output "$(ACCOUNT_EXPORT_FILE)" $(if $(ACCOUNT_EMAIL_KEYWORD),--email "$(ACCOUNT_EMAIL_KEYWORD)")
account-import:
@[ -f "$(ACCOUNT_IMPORT_FILE)" ] || (echo "❌ 未找到文件 $(ACCOUNT_IMPORT_FILE)"; exit 1)
@go run ./cmd/migratectl/main.go import --dsn "$(DB_URL)" --file "$(ACCOUNT_IMPORT_FILE)"
@[ -f "$(ACCOUNT_IMPORT_FILE)" ] || (echo "❌ 未找到文件 $(ACCOUNT_IMPORT_FILE)"; exit 1)
@go run ./cmd/migratectl/main.go import --dsn "$(DB_URL)" --file "$(ACCOUNT_IMPORT_FILE)" \
$(if $(ACCOUNT_IMPORT_MERGE),--merge) \
$(if $(ACCOUNT_IMPORT_MERGE_STRATEGY),--merge-strategy "$(ACCOUNT_IMPORT_MERGE_STRATEGY)") \
$(if $(ACCOUNT_IMPORT_DRY_RUN),--dry-run) \
$(foreach UUID,$(ACCOUNT_IMPORT_MERGE_ALLOWLIST),--merge-allowlist $(UUID)) \
$(ACCOUNT_IMPORT_EXTRA_FLAGS)
create-super-admin:
@[ -n "$(SUPERADMIN_USERNAME)" ] && [ -n "$(SUPERADMIN_PASSWORD)" ] || (echo "❌ 请指定用户名与密码"; exit 1)

View File

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
"strings"
"time"
"github.com/spf13/cobra"
@ -250,9 +251,13 @@ func newExportCmd() *cobra.Command {
func newImportCmd() *cobra.Command {
var (
dsn string
file string
timeout time.Duration
dsn string
file string
timeout time.Duration
merge bool
mergeStrategy string
dryRun bool
mergeAllowlist []string
)
timeout = 5 * time.Minute
@ -288,14 +293,49 @@ func newImportCmd() *cobra.Command {
}
importer := migrate.NewImporter()
allowlist := map[string]struct{}{}
for _, id := range mergeAllowlist {
id = strings.TrimSpace(id)
if id == "" {
continue
}
allowlist[id] = struct{}{}
}
if len(allowlist) == 0 {
allowlist = nil
}
if !merge {
if mergeStrategy != "" {
return errors.New("--merge-strategy requires --merge")
}
if len(mergeAllowlist) > 0 {
return errors.New("--merge-allowlist requires --merge")
}
}
ctx, cancel := context.WithTimeout(cmd.Context(), timeout)
defer cancel()
if err := importer.Import(ctx, dsn, &dump); err != nil {
report, err := importer.Import(ctx, dsn, &dump, migrate.ImportOptions{
Merge: merge,
MergeStrategy: migrate.MergeStrategy(mergeStrategy),
DryRun: dryRun,
Allowlist: allowlist,
LogWriter: cmd.ErrOrStderr(),
})
if err != nil {
return err
}
fmt.Fprintf(cmd.OutOrStdout(), "Imported %d users\n", len(dump.Users))
summaryTarget := "applied"
if dryRun {
summaryTarget = "preview"
}
fmt.Fprintf(cmd.OutOrStdout(), "Import %s: users inserted=%d updated=%d skipped=%d\n", summaryTarget, report.UsersInserted, report.UsersUpdated, report.UsersSkipped)
fmt.Fprintf(cmd.OutOrStdout(), "Identities inserted=%d updated=%d deleted=%d\n", report.IdentitiesInserted, report.IdentitiesUpdated, report.IdentitiesDeleted)
fmt.Fprintf(cmd.OutOrStdout(), "Sessions inserted=%d updated=%d deleted=%d\n", report.SessionsInserted, report.SessionsUpdated, report.SessionsDeleted)
if report.ConflictsResolved > 0 || report.ConflictsSkipped > 0 {
fmt.Fprintf(cmd.OutOrStdout(), "Conflicts resolved=%d skipped=%d\n", report.ConflictsResolved, report.ConflictsSkipped)
}
return nil
},
}
@ -303,6 +343,10 @@ func newImportCmd() *cobra.Command {
cmd.Flags().StringVar(&dsn, "dsn", "", "PostgreSQL connection string")
cmd.Flags().StringVar(&file, "file", "", "YAML file path or '-' for stdin")
cmd.Flags().DurationVar(&timeout, "timeout", timeout, "Import operation timeout")
cmd.Flags().BoolVar(&merge, "merge", false, "Enable additive merge behaviour")
cmd.Flags().StringVar(&mergeStrategy, "merge-strategy", "", "Merge strategy (replace, append, timestamp)")
cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Preview the import without applying changes")
cmd.Flags().StringSliceVar(&mergeAllowlist, "merge-allowlist", nil, "User UUIDs allowed to merge (comma-separated or repeated)")
return cmd
}

View File

@ -0,0 +1,33 @@
package migrate
import (
"errors"
"time"
accountschema "xcontrol/account/sql"
)
// SnapshotVersion identifies the canonical format of exported account snapshots.
const SnapshotVersion = "v1"
// SnapshotMetadata captures provenance information for account snapshots.
type SnapshotMetadata struct {
Version string `yaml:"version"`
SchemaHash string `yaml:"schemaHash"`
ExportedAt time.Time `yaml:"exportedAt"`
}
// validateSnapshotMetadata ensures the provided metadata matches the expected
// snapshot format and schema hash.
func validateSnapshotMetadata(meta *SnapshotMetadata) error {
if meta == nil {
return errors.New("snapshot metadata missing (expected version and schema hash)")
}
if meta.Version != SnapshotVersion {
return errors.New("snapshot version mismatch")
}
if meta.SchemaHash != accountschema.Hash() {
return errors.New("snapshot schema hash mismatch")
}
return nil
}

View File

@ -6,16 +6,21 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"slices"
"sort"
"strings"
"time"
accountschema "xcontrol/account/sql"
)
// AccountDump represents the serialized snapshot of account-related tables.
type AccountDump struct {
Users []UserRecord `yaml:"users"`
Identities []IdentityRecord `yaml:"identities,omitempty"`
Sessions []SessionRecord `yaml:"sessions,omitempty"`
Metadata *SnapshotMetadata `yaml:"metadata,omitempty"`
Users []UserRecord `yaml:"users"`
Identities []IdentityRecord `yaml:"identities,omitempty"`
Sessions []SessionRecord `yaml:"sessions,omitempty"`
}
// UserRecord captures the exported representation of a user row.
@ -58,6 +63,48 @@ type SessionRecord struct {
UpdatedAt *time.Time `yaml:"updatedAt,omitempty"`
}
// MergeStrategy defines how snapshot data should be reconciled with the target database.
type MergeStrategy string
const (
// MergeStrategyReplace preserves the legacy behaviour where incoming records
// fully replace existing ones.
MergeStrategyReplace MergeStrategy = "replace"
// MergeStrategyAppend performs additive merges, keeping existing data that is
// absent from the snapshot.
MergeStrategyAppend MergeStrategy = "append"
// MergeStrategyTimestamp resolves conflicts by preferring rows with the newest
// updated_at timestamp.
MergeStrategyTimestamp MergeStrategy = "timestamp"
)
// ImportOptions configures how snapshot imports should be applied.
type ImportOptions struct {
Merge bool
MergeStrategy MergeStrategy
DryRun bool
Allowlist map[string]struct{}
LogWriter io.Writer
}
// ImportReport captures the outcome of an import (or dry-run) execution.
type ImportReport struct {
UsersInserted int
UsersUpdated int
UsersSkipped int
IdentitiesInserted int
IdentitiesUpdated int
IdentitiesDeleted int
SessionsInserted int
SessionsUpdated int
SessionsDeleted int
ConflictsResolved int
ConflictsSkipped int
}
// Exporter reads account data from a PostgreSQL database.
type Exporter struct{}
@ -75,7 +122,13 @@ func (e *Exporter) Export(ctx context.Context, dsn, emailKeyword string) (*Accou
}
defer db.Close()
dump := &AccountDump{}
dump := &AccountDump{
Metadata: &SnapshotMetadata{
Version: SnapshotVersion,
SchemaHash: accountschema.Hash(),
ExportedAt: time.Now().UTC(),
},
}
users, err := loadUsers(ctx, db, emailKeyword)
if err != nil {
@ -115,69 +168,288 @@ func NewImporter() *Importer {
return &Importer{}
}
// Import restores account data from a dump into the target database. Existing
// rows are replaced on conflict and related identities/sessions are refreshed.
func (i *Importer) Import(ctx context.Context, dsn string, dump *AccountDump) error {
// Import restores account data from a dump into the target database using the
// provided options. When merge mode is disabled the behaviour mirrors the
// legacy implementation.
func (i *Importer) Import(ctx context.Context, dsn string, dump *AccountDump, opts ImportOptions) (*ImportReport, error) {
if dump == nil {
return errors.New("dump is nil")
return nil, errors.New("dump is nil")
}
if err := validateSnapshotMetadata(dump.Metadata); err != nil {
return nil, err
}
logWriter := opts.LogWriter
if logWriter == nil {
logWriter = io.Discard
}
logf := func(format string, args ...any) {
fmt.Fprintf(logWriter, format, args...)
}
strategy := opts.MergeStrategy
if strategy == "" {
if opts.Merge {
strategy = MergeStrategyAppend
} else {
strategy = MergeStrategyReplace
}
}
switch strategy {
case MergeStrategyReplace, MergeStrategyAppend, MergeStrategyTimestamp:
default:
return nil, fmt.Errorf("unsupported merge strategy %q", strategy)
}
if !opts.Merge {
strategy = MergeStrategyReplace
}
db, err := openDB(ctx, dsn)
if err != nil {
return err
return nil, err
}
defer db.Close()
identityCaps, err := tableColumnCaps(ctx, db, "identities")
if err != nil {
return err
return nil, err
}
sessionCaps, err := tableColumnCaps(ctx, db, "sessions")
if err != nil {
return err
return nil, err
}
userUUIDs := make([]string, 0, len(dump.Users))
for _, user := range dump.Users {
userUUIDs = append(userUUIDs, user.UUID)
}
existingUsers, err := loadUsersByUUIDs(ctx, db, userUUIDs)
if err != nil {
return nil, err
}
existingIdentitiesSlice, err := loadIdentities(ctx, db, userUUIDs)
if err != nil {
return nil, err
}
existingSessionsSlice, err := loadSessions(ctx, db, userUUIDs)
if err != nil {
return nil, err
}
existingIdentitiesByUUID := make(map[string]IdentityRecord, len(existingIdentitiesSlice))
existingIdentitiesByUser := make(map[string][]IdentityRecord)
for _, identity := range existingIdentitiesSlice {
existingIdentitiesByUUID[identity.UUID] = identity
existingIdentitiesByUser[identity.UserUUID] = append(existingIdentitiesByUser[identity.UserUUID], identity)
}
existingSessionsByUUID := make(map[string]SessionRecord, len(existingSessionsSlice))
existingSessionsByUser := make(map[string][]SessionRecord)
for _, session := range existingSessionsSlice {
existingSessionsByUUID[session.UUID] = session
existingSessionsByUser[session.UserUUID] = append(existingSessionsByUser[session.UserUUID], session)
}
incomingIdentitiesByUser := make(map[string][]IdentityRecord)
for _, identity := range dump.Identities {
incomingIdentitiesByUser[identity.UserUUID] = append(incomingIdentitiesByUser[identity.UserUUID], identity)
}
incomingSessionsByUser := make(map[string][]SessionRecord)
for _, session := range dump.Sessions {
incomingSessionsByUser[session.UserUUID] = append(incomingSessionsByUser[session.UserUUID], session)
}
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
return nil, err
}
committed := false
defer func() {
if err != nil {
if !committed {
tx.Rollback()
}
}()
cleared := make(map[string]struct{})
report := &ImportReport{}
allowlist := opts.Allowlist
allowlistEnabled := opts.Merge && len(allowlist) > 0
for _, user := range dump.Users {
if err = upsertUser(ctx, tx, &user); err != nil {
return err
if allowlistEnabled {
if _, ok := allowlist[user.UUID]; !ok {
report.UsersSkipped++
logf("skip user %s: not present in merge allowlist\n", user.UUID)
continue
}
}
cleared[user.UUID] = struct{}{}
}
for uuid := range cleared {
if _, err = tx.ExecContext(ctx, `DELETE FROM identities WHERE user_uuid = $1`, uuid); err != nil {
return err
existing, hasExisting := existingUsers[user.UUID]
if opts.Merge && hasExisting && strategy == MergeStrategyTimestamp && existing.UpdatedAt.After(user.UpdatedAt) {
report.UsersSkipped++
report.ConflictsSkipped++
logf("skip user %s: existing updated_at %s newer than snapshot %s\n", user.UUID, existing.UpdatedAt.Format(time.RFC3339), user.UpdatedAt.Format(time.RFC3339))
continue
}
if _, err = tx.ExecContext(ctx, `DELETE FROM sessions WHERE user_uuid = $1`, uuid); err != nil {
return err
mergedUser, changed := mergeUserRecord(user, existing, opts.Merge, hasExisting)
if !hasExisting {
report.UsersInserted++
} else if changed {
report.UsersUpdated++
if opts.Merge && strategy == MergeStrategyTimestamp {
report.ConflictsResolved++
}
} else {
report.UsersSkipped++
}
if changed && !opts.DryRun {
if err := upsertUser(ctx, tx, &mergedUser); err != nil {
return nil, err
}
}
existingUsers[user.UUID] = mergedUser
incomingIdentities := incomingIdentitiesByUser[user.UUID]
incomingSessions := incomingSessionsByUser[user.UUID]
if !opts.Merge || strategy == MergeStrategyReplace {
if existingCount := len(existingIdentitiesByUser[user.UUID]); existingCount > 0 {
report.IdentitiesDeleted += existingCount
if !opts.DryRun {
if _, err := tx.ExecContext(ctx, `DELETE FROM identities WHERE user_uuid = $1`, user.UUID); err != nil {
return nil, err
}
}
}
if existingCount := len(existingSessionsByUser[user.UUID]); existingCount > 0 {
report.SessionsDeleted += existingCount
if !opts.DryRun {
if _, err := tx.ExecContext(ctx, `DELETE FROM sessions WHERE user_uuid = $1`, user.UUID); err != nil {
return nil, err
}
}
}
for _, identity := range incomingIdentities {
if _, ok := existingIdentitiesByUUID[identity.UUID]; ok {
report.IdentitiesUpdated++
} else {
report.IdentitiesInserted++
}
if !opts.DryRun {
if err := upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
return nil, err
}
}
}
for _, session := range incomingSessions {
if _, ok := existingSessionsByUUID[session.UUID]; ok {
report.SessionsUpdated++
} else {
report.SessionsInserted++
}
if !opts.DryRun {
if err := upsertSession(ctx, tx, &session, sessionCaps); err != nil {
return nil, err
}
}
}
continue
}
// Merge mode (append/timestamp) for identities.
for _, identity := range incomingIdentities {
existingIdentity, ok := existingIdentitiesByUUID[identity.UUID]
if !ok {
report.IdentitiesInserted++
if !opts.DryRun {
if err := upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
return nil, err
}
}
continue
}
if strategy == MergeStrategyTimestamp && preferExistingIdentity(existingIdentity, identity) {
report.ConflictsSkipped++
logf("retain identity %s for user %s: existing updated_at preferred\n", identity.UUID, identity.UserUUID)
continue
}
if identityDiffers(identity, existingIdentity) {
report.IdentitiesUpdated++
if strategy == MergeStrategyTimestamp {
report.ConflictsResolved++
}
if !opts.DryRun {
if err := upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
return nil, err
}
}
}
}
for _, session := range incomingSessions {
existingSession, ok := existingSessionsByUUID[session.UUID]
if !ok {
report.SessionsInserted++
if !opts.DryRun {
if err := upsertSession(ctx, tx, &session, sessionCaps); err != nil {
return nil, err
}
}
continue
}
if strategy == MergeStrategyTimestamp && preferExistingSession(existingSession, session) {
report.ConflictsSkipped++
logf("retain session %s for user %s: existing updated_at preferred\n", session.UUID, session.UserUUID)
continue
}
if sessionDiffers(session, existingSession) {
report.SessionsUpdated++
if strategy == MergeStrategyTimestamp {
report.ConflictsResolved++
}
if !opts.DryRun {
if err := upsertSession(ctx, tx, &session, sessionCaps); err != nil {
return nil, err
}
}
}
}
}
for _, identity := range dump.Identities {
if err = upsertIdentity(ctx, tx, &identity, identityCaps); err != nil {
return err
if opts.DryRun {
if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
return nil, err
}
committed = true
logf("dry-run complete: no changes applied\n")
return report, nil
}
for _, session := range dump.Sessions {
if err = upsertSession(ctx, tx, &session, sessionCaps); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return nil, err
}
committed = true
return report, nil
}
return tx.Commit()
const userSelectColumns = `uuid, username, password, email, email_verified, email_verified_at, level, role, groups, permissions, created_at, updated_at, mfa_totp_secret, mfa_enabled, mfa_secret_issued_at, mfa_confirmed_at`
type rowScanner interface {
Scan(dest ...any) error
}
func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserRecord, error) {
@ -186,7 +458,9 @@ func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserReco
args []any
)
query.WriteString(`SELECT uuid, username, password, email, email_verified, email_verified_at, level, role, groups, permissions, created_at, updated_at, mfa_totp_secret, mfa_enabled, mfa_secret_issued_at, mfa_confirmed_at FROM users`)
query.WriteString("SELECT ")
query.WriteString(userSelectColumns)
query.WriteString(" FROM users")
if keyword := strings.TrimSpace(emailKeyword); keyword != "" {
query.WriteString(` WHERE email ILIKE $1`)
args = append(args, "%"+keyword+"%")
@ -201,91 +475,10 @@ func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserReco
var users []UserRecord
for rows.Next() {
var (
email sql.NullString
emailVerified bool
emailVerifiedAt sql.NullTime
level sql.NullInt64
role sql.NullString
groupsRaw []byte
permissionsRaw []byte
createdAt time.Time
updatedAt time.Time
mfaSecret sql.NullString
mfaEnabled sql.NullBool
mfaIssuedAt sql.NullTime
mfaConfirmedAt sql.NullTime
user UserRecord
)
if err := rows.Scan(
&user.UUID,
&user.Username,
&user.PasswordHash,
&email,
&emailVerified,
&emailVerifiedAt,
&level,
&role,
&groupsRaw,
&permissionsRaw,
&createdAt,
&updatedAt,
&mfaSecret,
&mfaEnabled,
&mfaIssuedAt,
&mfaConfirmedAt,
); err != nil {
user, err := scanUserRow(rows)
if err != nil {
return nil, err
}
if email.Valid {
user.Email = email.String
}
user.EmailVerified = emailVerified
if emailVerifiedAt.Valid {
ts := emailVerifiedAt.Time
user.EmailVerifiedAt = &ts
}
if level.Valid {
user.Level = int(level.Int64)
}
if role.Valid {
user.Role = role.String
}
if len(groupsRaw) > 0 {
if err := json.Unmarshal(groupsRaw, &user.Groups); err != nil {
return nil, fmt.Errorf("decode groups for user %s: %w", user.UUID, err)
}
}
if len(permissionsRaw) > 0 {
if err := json.Unmarshal(permissionsRaw, &user.Permissions); err != nil {
return nil, fmt.Errorf("decode permissions for user %s: %w", user.UUID, err)
}
}
user.CreatedAt = createdAt
user.UpdatedAt = updatedAt
if mfaSecret.Valid {
user.MFATOTPSecret = mfaSecret.String
}
user.MFAEnabled = mfaEnabled.Bool
if mfaIssuedAt.Valid {
ts := mfaIssuedAt.Time
user.MFASecretIssuedAt = &ts
}
if mfaConfirmedAt.Valid {
ts := mfaConfirmedAt.Time
user.MFAConfirmedAt = &ts
}
if user.Groups == nil {
user.Groups = []string{}
}
if user.Permissions == nil {
user.Permissions = []string{}
}
if user.Role == "" {
user.Role = "user"
}
users = append(users, user)
}
@ -296,6 +489,320 @@ func loadUsers(ctx context.Context, db *sql.DB, emailKeyword string) ([]UserReco
return users, nil
}
func loadUsersByUUIDs(ctx context.Context, db *sql.DB, uuids []string) (map[string]UserRecord, error) {
users := make(map[string]UserRecord, len(uuids))
if len(uuids) == 0 {
return users, nil
}
queryTemplate := fmt.Sprintf("SELECT %s FROM users WHERE uuid IN (%%s)", userSelectColumns)
query, args := buildInQuery(queryTemplate, uuids)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
user, err := scanUserRow(rows)
if err != nil {
return nil, err
}
users[user.UUID] = user
}
if err := rows.Err(); err != nil {
return nil, err
}
return users, nil
}
func scanUserRow(scanner rowScanner) (UserRecord, error) {
var (
email sql.NullString
emailVerified bool
emailVerifiedAt sql.NullTime
level sql.NullInt64
role sql.NullString
groupsRaw []byte
permissionsRaw []byte
createdAt time.Time
updatedAt time.Time
mfaSecret sql.NullString
mfaEnabled sql.NullBool
mfaIssuedAt sql.NullTime
mfaConfirmedAt sql.NullTime
user UserRecord
)
if err := scanner.Scan(
&user.UUID,
&user.Username,
&user.PasswordHash,
&email,
&emailVerified,
&emailVerifiedAt,
&level,
&role,
&groupsRaw,
&permissionsRaw,
&createdAt,
&updatedAt,
&mfaSecret,
&mfaEnabled,
&mfaIssuedAt,
&mfaConfirmedAt,
); err != nil {
return UserRecord{}, err
}
if email.Valid {
user.Email = email.String
}
user.EmailVerified = emailVerified
if emailVerifiedAt.Valid {
ts := emailVerifiedAt.Time
user.EmailVerifiedAt = &ts
}
if level.Valid {
user.Level = int(level.Int64)
}
if role.Valid {
user.Role = role.String
}
if len(groupsRaw) > 0 {
if err := json.Unmarshal(groupsRaw, &user.Groups); err != nil {
return UserRecord{}, fmt.Errorf("decode groups for user %s: %w", user.UUID, err)
}
}
if len(permissionsRaw) > 0 {
if err := json.Unmarshal(permissionsRaw, &user.Permissions); err != nil {
return UserRecord{}, fmt.Errorf("decode permissions for user %s: %w", user.UUID, err)
}
}
user.CreatedAt = createdAt
user.UpdatedAt = updatedAt
if mfaSecret.Valid {
user.MFATOTPSecret = mfaSecret.String
}
user.MFAEnabled = mfaEnabled.Bool
if mfaIssuedAt.Valid {
ts := mfaIssuedAt.Time
user.MFASecretIssuedAt = &ts
}
if mfaConfirmedAt.Valid {
ts := mfaConfirmedAt.Time
user.MFAConfirmedAt = &ts
}
ensureUserDefaults(&user)
return user, nil
}
func ensureUserDefaults(user *UserRecord) {
if user.Groups == nil {
user.Groups = []string{}
}
if user.Permissions == nil {
user.Permissions = []string{}
}
if user.Role == "" {
user.Role = "user"
}
}
func mergeUserRecord(incoming UserRecord, existing UserRecord, merge bool, hasExisting bool) (UserRecord, bool) {
ensureUserDefaults(&incoming)
if !hasExisting {
return incoming, true
}
if merge {
if incoming.Email == "" {
incoming.Email = existing.Email
}
if incoming.EmailVerifiedAt == nil {
incoming.EmailVerifiedAt = cloneTimePtr(existing.EmailVerifiedAt)
}
if len(incoming.Groups) == 0 && len(existing.Groups) > 0 {
incoming.Groups = append([]string(nil), existing.Groups...)
}
if len(incoming.Permissions) == 0 && len(existing.Permissions) > 0 {
incoming.Permissions = append([]string(nil), existing.Permissions...)
}
if incoming.Role == "" {
incoming.Role = existing.Role
}
if incoming.MFATOTPSecret == "" {
incoming.MFATOTPSecret = existing.MFATOTPSecret
}
if incoming.MFASecretIssuedAt == nil {
incoming.MFASecretIssuedAt = cloneTimePtr(existing.MFASecretIssuedAt)
}
if incoming.MFAConfirmedAt == nil {
incoming.MFAConfirmedAt = cloneTimePtr(existing.MFAConfirmedAt)
}
}
changed := userDiffers(incoming, existing)
return incoming, changed
}
func userDiffers(a, b UserRecord) bool {
if a.Username != b.Username {
return true
}
if a.PasswordHash != b.PasswordHash {
return true
}
if a.Email != b.Email {
return true
}
if a.EmailVerified != b.EmailVerified {
return true
}
if !timePtrEqual(a.EmailVerifiedAt, b.EmailVerifiedAt) {
return true
}
if a.Level != b.Level {
return true
}
if a.Role != b.Role {
return true
}
if !slices.Equal(a.Groups, b.Groups) {
return true
}
if !slices.Equal(a.Permissions, b.Permissions) {
return true
}
if !a.CreatedAt.Equal(b.CreatedAt) {
return true
}
if !a.UpdatedAt.Equal(b.UpdatedAt) {
return true
}
if a.MFATOTPSecret != b.MFATOTPSecret {
return true
}
if a.MFAEnabled != b.MFAEnabled {
return true
}
if !timePtrEqual(a.MFASecretIssuedAt, b.MFASecretIssuedAt) {
return true
}
if !timePtrEqual(a.MFAConfirmedAt, b.MFAConfirmedAt) {
return true
}
return false
}
func identityDiffers(a, b IdentityRecord) bool {
if a.Provider != b.Provider {
return true
}
if a.ExternalID != b.ExternalID {
return true
}
if a.UserUUID != b.UserUUID {
return true
}
if !timePtrEqual(a.CreatedAt, b.CreatedAt) {
return true
}
if !timePtrEqual(a.UpdatedAt, b.UpdatedAt) {
return true
}
return false
}
func preferExistingIdentity(existing, incoming IdentityRecord) bool {
switch {
case existing.UpdatedAt != nil && incoming.UpdatedAt != nil:
if existing.UpdatedAt.Equal(*incoming.UpdatedAt) {
return false
}
return existing.UpdatedAt.After(*incoming.UpdatedAt)
case existing.UpdatedAt != nil:
return true
case incoming.UpdatedAt != nil:
return false
}
if existing.CreatedAt != nil && incoming.CreatedAt != nil {
if existing.CreatedAt.Equal(*incoming.CreatedAt) {
return false
}
return existing.CreatedAt.After(*incoming.CreatedAt)
}
return false
}
func sessionDiffers(a, b SessionRecord) bool {
if a.Token != b.Token {
return true
}
if !a.ExpiresAt.Equal(b.ExpiresAt) {
return true
}
if a.UserUUID != b.UserUUID {
return true
}
if !timePtrEqual(a.CreatedAt, b.CreatedAt) {
return true
}
if !timePtrEqual(a.UpdatedAt, b.UpdatedAt) {
return true
}
return false
}
func preferExistingSession(existing, incoming SessionRecord) bool {
switch {
case existing.UpdatedAt != nil && incoming.UpdatedAt != nil:
if existing.UpdatedAt.Equal(*incoming.UpdatedAt) {
return false
}
return existing.UpdatedAt.After(*incoming.UpdatedAt)
case existing.UpdatedAt != nil:
return true
case incoming.UpdatedAt != nil:
return false
}
if existing.CreatedAt != nil && incoming.CreatedAt != nil {
if existing.CreatedAt.Equal(*incoming.CreatedAt) {
return false
}
return existing.CreatedAt.After(*incoming.CreatedAt)
}
return false
}
func cloneTimePtr(ts *time.Time) *time.Time {
if ts == nil {
return nil
}
clone := *ts
return &clone
}
func timePtrEqual(a, b *time.Time) bool {
switch {
case a == nil && b == nil:
return true
case a == nil || b == nil:
return false
default:
return a.Equal(*b)
}
}
func loadIdentities(ctx context.Context, db *sql.DB, uuids []string) ([]IdentityRecord, error) {
if len(uuids) == 0 {
return nil, nil

25
account/sql/embed.go Normal file
View File

@ -0,0 +1,25 @@
package schema
import (
"crypto/sha256"
"embed"
"encoding/hex"
"sync"
)
//go:embed schema.sql
var schemaFile []byte
var (
hashOnce sync.Once
hash string
)
// Hash returns the SHA-256 hash of the canonical schema.sql file.
func Hash() string {
hashOnce.Do(func() {
sum := sha256.Sum256(schemaFile)
hash = hex.EncodeToString(sum[:])
})
return hash
}