diff --git a/account/internal/store/postgres.go b/account/internal/store/postgres.go index 8926443..3db875f 100644 --- a/account/internal/store/postgres.go +++ b/account/internal/store/postgres.go @@ -109,18 +109,24 @@ func (s *postgresStore) CreateUser(ctx context.Context, user *User) error { } if normalizedEmail != "" { - if _, err := s.GetUserByEmail(ctx, normalizedEmail); err == nil { - return ErrEmailExists - } else if !errors.Is(err, ErrUserNotFound) { + const emailExistsQuery = "SELECT EXISTS(SELECT 1 FROM users WHERE lower(email) = $1)" + emailExists, err := s.userExists(ctx, emailExistsQuery, normalizedEmail) + if err != nil { return err } + if emailExists { + return ErrEmailExists + } } - if _, err := s.GetUserByName(ctx, normalizedName); err == nil { - return ErrNameExists - } else if !errors.Is(err, ErrUserNotFound) { + const nameExistsQuery = "SELECT EXISTS(SELECT 1 FROM users WHERE lower(username) = lower($1))" + nameExists, err := s.userExists(ctx, nameExistsQuery, normalizedName) + if err != nil { return err } + if nameExists { + return ErrNameExists + } query := `INSERT INTO users (username, password, email, email_verified_at) VALUES ($1, $2, $3, $4) @@ -163,6 +169,31 @@ func (s *postgresStore) CreateUser(ctx context.Context, user *User) error { return nil } +func (s *postgresStore) userExists(ctx context.Context, query string, arg any) (bool, error) { + var exists bool + err := s.db.QueryRowContext(ctx, query, arg).Scan(&exists) + if err != nil { + if isDatabaseEmptyError(err) { + return false, nil + } + return false, err + } + return exists, nil +} + +func isDatabaseEmptyError(err error) bool { + if errors.Is(err, sql.ErrNoRows) { + return true + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + if pgErr.Code == "42P01" { // undefined_table + return true + } + } + return false +} + func (s *postgresStore) GetUserByEmail(ctx context.Context, email string) (*User, error) { normalized := strings.ToLower(strings.TrimSpace(email)) if normalized == "" {