307 lines
8.4 KiB
Go
307 lines
8.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pquerna/otp"
|
|
"github.com/pquerna/otp/totp"
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
"account/internal/store"
|
|
)
|
|
|
|
func main() {
|
|
var (
|
|
driver = flag.String("driver", "postgres", "database driver (postgres, memory)")
|
|
dsn = flag.String("dsn", "", "database connection string")
|
|
username = flag.String("username", "", "root username")
|
|
password = flag.String("password", "", "root password")
|
|
email = flag.String("email", store.RootAdminEmail, "root email (must be admin@svc.plus)")
|
|
groups = flag.String("groups", "", "comma separated list of groups to assign (optional)")
|
|
permissions = flag.String("permissions", "", "comma separated list of permissions to assign (optional)")
|
|
currentPassword = flag.String("current-password", "", "current super administrator password (required when updating)")
|
|
mfaCode = flag.String("mfa", "", "MFA TOTP code for the current super administrator (required when MFA is enabled)")
|
|
)
|
|
flag.Parse()
|
|
|
|
if err := run(*driver, *dsn, *username, *password, *email, *groups, *permissions, *currentPassword, *mfaCode); err != nil {
|
|
log.Fatalf("failed to create super administrator: %v", err)
|
|
}
|
|
}
|
|
|
|
func run(driver, dsn, username, password, email, groups, permissions, currentPassword, mfaCode string) error {
|
|
driver = strings.TrimSpace(driver)
|
|
dsn = strings.TrimSpace(dsn)
|
|
username = strings.TrimSpace(username)
|
|
password = strings.TrimSpace(password)
|
|
email = strings.TrimSpace(email)
|
|
groups = strings.TrimSpace(groups)
|
|
permissions = strings.TrimSpace(permissions)
|
|
currentPassword = strings.TrimSpace(currentPassword)
|
|
mfaCode = strings.TrimSpace(mfaCode)
|
|
|
|
if username == "" {
|
|
return errors.New("username is required")
|
|
}
|
|
if !strings.EqualFold(email, store.RootAdminEmail) {
|
|
return fmt.Errorf("root email must be %q", store.RootAdminEmail)
|
|
}
|
|
if dsn == "" && !strings.EqualFold(driver, "memory") {
|
|
return errors.New("dsn is required")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
storeConfig := store.Config{
|
|
Driver: driver,
|
|
DSN: dsn,
|
|
AllowSuperAdminCounting: true,
|
|
}
|
|
|
|
s, cleanup, err := store.New(ctx, storeConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = cleanup(context.Background())
|
|
}()
|
|
|
|
configuredGroups := parseCSV(groups)
|
|
configuredPermissions := parseCSV(permissions)
|
|
|
|
user, err := s.GetUserByEmail(ctx, store.RootAdminEmail)
|
|
if err != nil && !errors.Is(err, store.ErrUserNotFound) {
|
|
return err
|
|
}
|
|
if errors.Is(err, store.ErrUserNotFound) {
|
|
user, err = s.GetUserByName(ctx, username)
|
|
if err != nil && !errors.Is(err, store.ErrUserNotFound) {
|
|
return err
|
|
}
|
|
}
|
|
|
|
superAdminCount, err := countSuperAdmins(ctx, s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if user == nil {
|
|
if superAdminCount > 0 {
|
|
return errors.New("root administrator already exists")
|
|
}
|
|
if password == "" {
|
|
return errors.New("password is required")
|
|
}
|
|
|
|
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("hash password: %w", err)
|
|
}
|
|
|
|
newUser := &store.User{
|
|
Name: username,
|
|
Email: email,
|
|
PasswordHash: string(hashed),
|
|
Level: store.LevelAdmin,
|
|
Role: store.RoleRoot,
|
|
Groups: ensureSuperAdminGroups(configuredGroups, nil),
|
|
Permissions: ensureSuperAdminPermissions(configuredPermissions, nil),
|
|
EmailVerified: true,
|
|
}
|
|
|
|
if err := s.CreateUser(ctx, newUser); err != nil {
|
|
if errors.Is(err, store.ErrEmailExists) {
|
|
return fmt.Errorf("email already exists: %w", err)
|
|
}
|
|
if errors.Is(err, store.ErrNameExists) {
|
|
return fmt.Errorf("username already exists: %w", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
fmt.Fprintf(os.Stdout, "Created super administrator %s (id=%s)\n", newUser.Name, newUser.ID)
|
|
return nil
|
|
}
|
|
|
|
if superAdminCount > 1 {
|
|
return errors.New("multiple root administrators detected; resolve manually before continuing")
|
|
}
|
|
|
|
if user.PasswordHash != "" {
|
|
if currentPassword == "" {
|
|
return errors.New("current password is required to update the super administrator")
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(currentPassword)); err != nil {
|
|
return errors.New("current password verification failed")
|
|
}
|
|
}
|
|
|
|
if user.MFAEnabled {
|
|
if mfaCode == "" {
|
|
return errors.New("mfa code is required for this super administrator")
|
|
}
|
|
valid, err := totp.ValidateCustom(mfaCode, user.MFATOTPSecret, time.Now().UTC(), totp.ValidateOpts{
|
|
Period: 30,
|
|
Skew: 1,
|
|
Digits: otp.DigitsSix,
|
|
Algorithm: otp.AlgorithmSHA1,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("validate mfa code: %w", err)
|
|
}
|
|
if !valid {
|
|
return errors.New("invalid mfa code provided")
|
|
}
|
|
}
|
|
|
|
updated := *user
|
|
updated.Email = store.RootAdminEmail
|
|
if password != "" {
|
|
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("hash password: %w", err)
|
|
}
|
|
updated.PasswordHash = string(hashed)
|
|
}
|
|
|
|
updated.Groups = ensureSuperAdminGroups(configuredGroups, user.Groups)
|
|
updated.Permissions = ensureSuperAdminPermissions(configuredPermissions, user.Permissions)
|
|
updated.EmailVerified = updated.Email != ""
|
|
updated.Role = store.RoleRoot
|
|
updated.Level = store.LevelAdmin
|
|
updated.UpdatedAt = time.Now().UTC()
|
|
|
|
if err := s.UpdateUser(ctx, &updated); err != nil {
|
|
if errors.Is(err, store.ErrEmailExists) {
|
|
return fmt.Errorf("email already exists: %w", err)
|
|
}
|
|
if errors.Is(err, store.ErrNameExists) {
|
|
return fmt.Errorf("username already exists: %w", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
fmt.Fprintf(os.Stdout, "Updated super administrator %s (id=%s)\n", updated.Name, updated.ID)
|
|
return nil
|
|
}
|
|
|
|
func countSuperAdmins(ctx context.Context, s store.Store) (int, error) {
|
|
type superAdminCounter interface {
|
|
CountSuperAdmins(ctx context.Context) (int, error)
|
|
}
|
|
|
|
if counter, ok := s.(superAdminCounter); ok {
|
|
count, err := counter.CountSuperAdmins(ctx)
|
|
if errors.Is(err, store.ErrSuperAdminCountingDisabled) {
|
|
return 0, errors.New("store does not permit super administrator counting; enable it explicitly to proceed")
|
|
}
|
|
return count, err
|
|
}
|
|
return 0, errors.New("store does not support super administrator discovery")
|
|
}
|
|
|
|
// parseCSV splits a comma-separated list into trimmed, non-empty entries.
// Entries that differ only in letter case are collapsed; the first spelling
// seen wins. The result is sorted with sort.Strings. It returns nil when the
// input is empty or yields no entries.
func parseCSV(input string) []string {
	if input == "" {
		return nil
	}

	var items []string
	seen := map[string]bool{}
	for _, raw := range strings.Split(input, ",") {
		item := strings.TrimSpace(raw)
		key := strings.ToLower(item)
		if item == "" || seen[key] {
			continue
		}
		seen[key] = true
		items = append(items, item)
	}

	if len(items) == 0 {
		return nil
	}
	sort.Strings(items)
	return items
}
|
|
|
|
func ensureSuperAdminGroups(configured, existing []string) []string {
|
|
base := mergeValues(existing, configured)
|
|
if !containsCaseInsensitive(base, "Admin") {
|
|
base = append(base, "Admin")
|
|
}
|
|
return normalizeResult(base)
|
|
}
|
|
|
|
func ensureSuperAdminPermissions(configured, existing []string) []string {
|
|
base := mergeValues(existing, configured)
|
|
if !containsExact(base, "*") {
|
|
base = append(base, "*")
|
|
}
|
|
return normalizeResult(base)
|
|
}
|
|
|
|
// mergeValues returns a freshly allocated slice holding existing followed by
// configured. Callers may append to or modify it without affecting either
// input. The result is non-nil even when both inputs are empty.
func mergeValues(existing, configured []string) []string {
	merged := make([]string, len(existing)+len(configured))
	offset := copy(merged, existing)
	copy(merged[offset:], configured)
	return merged
}
|
|
|
|
// containsCaseInsensitive reports whether any trimmed entry of values equals
// target after lower-casing both. target itself is not trimmed. An empty
// target never matches.
func containsCaseInsensitive(values []string, target string) bool {
	found := false
	if target != "" {
		needle := strings.ToLower(target)
		for i := 0; i < len(values) && !found; i++ {
			found = strings.ToLower(strings.TrimSpace(values[i])) == needle
		}
	}
	return found
}
|
|
|
|
// containsExact reports whether any entry of values, after trimming
// surrounding whitespace, equals target byte-for-byte.
func containsExact(values []string, target string) bool {
	for i := 0; i < len(values); i++ {
		if target == strings.TrimSpace(values[i]) {
			return true
		}
	}
	return false
}
|
|
|
|
// normalizeResult trims every entry, drops blanks, and removes entries that
// differ only in letter case (the first spelling seen is kept). It returns the
// survivors sorted with sort.Strings. A nil or empty input yields nil. A
// non-empty input made only of blanks yields an empty, non-nil slice.
func normalizeResult(values []string) []string {
	if len(values) == 0 {
		return nil
	}

	seen := make(map[string]bool, len(values))
	kept := make([]string, 0, len(values))
	for _, raw := range values {
		item := strings.TrimSpace(raw)
		if item == "" {
			continue
		}
		// Lower-casing leaves the "*" wildcard unchanged, so it needs no
		// special-case key.
		key := strings.ToLower(item)
		if seen[key] {
			continue
		}
		seen[key] = true
		kept = append(kept, item)
	}

	sort.Strings(kept)
	return kept
}
|