diff --git a/api/api.go b/api/api.go index 1402468..98df9b8 100644 --- a/api/api.go +++ b/api/api.go @@ -346,6 +346,8 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { authProtected.GET("/session", h.session) authProtected.DELETE("/session", h.deleteSession) + authProtected.GET("/xworkmate/profile", h.getXWorkmateProfile) + authProtected.PUT("/xworkmate/profile", h.updateXWorkmateProfile) authProtected.POST("/mfa/totp/provision", h.provisionTOTP) authProtected.POST("/mfa/totp/verify", h.verifyTOTP) @@ -374,6 +376,7 @@ func RegisterRoutes(r *gin.Engine, opts ...Option) { authProtected.POST("/admin/users/:userId/resume", h.resumeUser) authProtected.DELETE("/admin/users/:userId", h.deleteUser) authProtected.POST("/admin/users/:userId/renew-uuid", h.renewProxyUUID) + authProtected.POST("/admin/tenants/bootstrap", h.bootstrapTenant) authProtected.GET("/admin/blacklist", h.listBlacklist) authProtected.POST("/admin/blacklist", h.addToBlacklist) authProtected.DELETE("/admin/blacklist/:email", h.removeFromBlacklist) @@ -1416,7 +1419,17 @@ func (h *handler) session(c *gin.Context) { slog.Warn("failed to rotate sandbox proxy uuid", "err", err, "userID", user.ID) } - c.JSON(http.StatusOK, gin.H{"user": sanitizeUser(user, nil)}) + sanitized, err := h.buildSessionUser(c.Request.Context(), h.resolveTenantHost(c), user) + if err != nil { + if errors.Is(err, store.ErrTenantNotFound) { + c.JSON(http.StatusOK, gin.H{"user": sanitizeUser(user, nil)}) + return + } + respondError(c, http.StatusInternalServerError, "session_tenant_resolution_failed", "failed to resolve tenant session context") + return + } + + c.JSON(http.StatusOK, gin.H{"user": sanitized}) } func (h *handler) deleteSession(c *gin.Context) { @@ -2699,7 +2712,7 @@ func (h *handler) oauthLogin(c *gin.Context) { return } - state := h.generateState() + state := buildOAuthState(h.resolveFrontendURL(c)) // In a real app, we should store state in a secure cookie or session. // For now, we'll just redirect. c.Redirect(http.StatusTemporaryRedirect, provider.AuthCodeURL(state)) @@ -2814,7 +2827,13 @@ func (h *handler) oauthCallback(c *gin.Context) { return } - frontendURL := h.oauthFrontendURL + frontendURL := h.validateFrontendURL(parseOAuthStateFrontendURL(c.Query("state"))) + if frontendURL == "" { + frontendURL = h.resolveFrontendURL(c) + } + if frontendURL == "" { + frontendURL = h.oauthFrontendURL + } if frontendURL == "" { frontendURL = "http://localhost:3000" } diff --git a/api/xworkmate.go b/api/xworkmate.go new file mode 100644 index 0000000..285ddf3 --- /dev/null +++ b/api/xworkmate.go @@ -0,0 +1,512 @@ +package api + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" + + "account/internal/auth" + "account/internal/store" +) + +type xworkmateAccessContext struct { + Tenant *store.Tenant + Domain string + MembershipRole string + ProfileScope string + CanEditIntegrations bool + CanManageTenant bool +} + +type xworkmateProfilePayload struct { + OpenclawURL string `json:"openclawUrl"` + OpenclawOrigin string `json:"openclawOrigin"` + VaultURL string `json:"vaultUrl"` + VaultNamespace string `json:"vaultNamespace"` + VaultSecretPath string `json:"vaultSecretPath"` + VaultSecretKey string `json:"vaultSecretKey"` + ApisixURL string `json:"apisixUrl"` +} + +func (h *handler) ensureSharedXWorkmateTenant(ctx context.Context) error { + tenant := &store.Tenant{ + ID: store.SharedXWorkmateTenantID, + Name: store.SharedXWorkmateTenantName, + Edition: store.SharedPublicTenantEdition, + } + if err := h.store.EnsureTenant(ctx, tenant); err != nil { + return err + } + + return h.store.EnsureTenantDomain(ctx, &store.TenantDomain{ + TenantID: tenant.ID, + Domain: store.SharedXWorkmateDomain, + Kind: store.TenantDomainKindGenerated, + IsPrimary: true, + Status: store.TenantDomainStatusVerified, + }) +} + +func (h *handler) resolveTenantHost(c *gin.Context) string { + for _, headerName := range []string{"X-Forwarded-Host", "X-Original-Host", "X-Host"} { + if candidate := store.NormalizeHostname(c.GetHeader(headerName)); candidate != "" { + return candidate + } + } + if candidate := store.NormalizeHostname(c.Request.Host); candidate != "" { + return candidate + } + return store.SharedXWorkmateDomain +} + +func (h *handler) resolveFrontendURL(c *gin.Context) string { + candidates := []string{ + strings.TrimSpace(c.Query("frontend_url")), + strings.TrimSpace(c.GetHeader("X-Frontend-Url")), + strings.TrimSpace(c.GetHeader("Origin")), + } + if referer := strings.TrimSpace(c.GetHeader("Referer")); referer != "" { + if parsed, err := url.Parse(referer); err == nil && parsed.Host != "" { + candidates = append(candidates, fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)) + } + } + candidates = append(candidates, strings.TrimSpace(h.oauthFrontendURL)) + + for _, candidate := range candidates { + if validated := h.validateFrontendURL(candidate); validated != "" { + return validated + } + } + return "" +} + +func (h *handler) validateFrontendURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Host == "" { + return "" + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "" + } + + host := store.NormalizeHostname(parsed.Host) + if host == "" { + return "" + } + if store.IsSharedTenantHost(host) { + return fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host) + } + if _, _, err := h.store.ResolveTenantByHost(context.Background(), host); err == nil { + return fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host) + } + return "" +} + +func buildOAuthState(frontendURL string) string { + nonce := generateRandomState() + trimmed := strings.TrimSpace(frontendURL) + if trimmed == "" { + return nonce + } + encoded := base64.RawURLEncoding.EncodeToString([]byte(trimmed)) + return nonce + "." + encoded +} + +func parseOAuthStateFrontendURL(state string) string { + parts := strings.SplitN(strings.TrimSpace(state), ".", 2) + if len(parts) != 2 { + return "" + } + decoded, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "" + } + return strings.TrimSpace(string(decoded)) +} + +func generateRandomState() string { + return (&handler{}).generateState() +} + +func (h *handler) currentAuthenticatedUser(c *gin.Context) (*store.User, bool) { + userID := strings.TrimSpace(auth.GetUserID(c)) + if userID == "" || userID == "system" { + respondError(c, http.StatusUnauthorized, "session_token_required", "session token is required") + return nil, false + } + + user, err := h.store.GetUserByID(c.Request.Context(), userID) + if err != nil { + respondError(c, http.StatusUnauthorized, "session_user_lookup_failed", "failed to load session user") + return nil, false + } + if !user.Active { + respondError(c, http.StatusForbidden, "account_suspended", "your account has been suspended") + return nil, false + } + return user, true +} + +func (h *handler) ensureSharedTenantMembership(ctx context.Context, user *store.User) (string, error) { + role := store.TenantMembershipRoleUser + if h.isRootAccount(user) || strings.EqualFold(strings.TrimSpace(user.Role), store.RoleAdmin) { + role = store.TenantMembershipRoleAdmin + } + return role, h.store.UpsertTenantMembership(ctx, &store.TenantMembership{ + TenantID: store.SharedXWorkmateTenantID, + UserID: user.ID, + Role: role, + }) +} + +func (h *handler) resolveXWorkmateAccess(ctx context.Context, host string, user *store.User) (*xworkmateAccessContext, error) { + normalizedHost := store.NormalizeHostname(host) + if store.IsSharedTenantHost(normalizedHost) { + if err := h.ensureSharedXWorkmateTenant(ctx); err != nil { + return nil, err + } + } + + tenant, domain, err := h.store.ResolveTenantByHost(ctx, normalizedHost) + if err != nil { + return nil, err + } + + access := &xworkmateAccessContext{ + Tenant: tenant, + Domain: store.SharedXWorkmateDomain, + } + if domain != nil && strings.TrimSpace(domain.Domain) != "" { + access.Domain = domain.Domain + } + + if tenant.Edition == store.SharedPublicTenantEdition { + role, err := h.ensureSharedTenantMembership(ctx, user) + if err != nil { + return nil, err + } + access.MembershipRole = role + access.ProfileScope = store.XWorkmateProfileScopeTenantShared + access.CanEditIntegrations = role == store.TenantMembershipRoleAdmin + access.CanManageTenant = access.CanEditIntegrations + return access, nil + } + + membership, err := h.store.GetTenantMembership(ctx, tenant.ID, user.ID) + if err != nil { + return nil, err + } + + access.MembershipRole = membership.Role + access.ProfileScope = store.XWorkmateProfileScopeUserPrivate + access.CanEditIntegrations = true + access.CanManageTenant = membership.Role == store.TenantMembershipRoleAdmin + return access, nil +} + +func buildSessionTenantEntries(memberships []store.TenantMembership) []gin.H { + if len(memberships) == 0 { + return []gin.H{} + } + + result := make([]gin.H, 0, len(memberships)) + for _, membership := range memberships { + entry := gin.H{ + "id": membership.TenantID, + "role": membership.Role, + } + if strings.TrimSpace(membership.TenantName) != "" { + entry["name"] = membership.TenantName + } + result = append(result, entry) + } + return result +} + +func (h *handler) buildSessionUser(ctx context.Context, host string, user *store.User) (gin.H, error) { + access, err := h.resolveXWorkmateAccess(ctx, host, user) + if err != nil { + return nil, err + } + + memberships, err := h.store.ListTenantMembershipsByUser(ctx, user.ID) + if err != nil { + return nil, err + } + + payload := sanitizeUser(user, nil) + payload["tenantId"] = access.Tenant.ID + payload["tenants"] = buildSessionTenantEntries(memberships) + return payload, nil +} + +func buildXWorkmateProfileResponse(access *xworkmateAccessContext, profile *store.XWorkmateProfile) gin.H { + resolvedProfile := gin.H{ + "openclawUrl": "", + "openclawOrigin": "", + "vaultUrl": "", + "vaultNamespace": "", + "vaultSecretPath": "", + "vaultSecretKey": "", + "apisixUrl": "", + } + openclawConfigured := false + + if profile != nil { + resolvedProfile["openclawUrl"] = profile.OpenclawURL + resolvedProfile["openclawOrigin"] = profile.OpenclawOrigin + resolvedProfile["vaultUrl"] = profile.VaultURL + resolvedProfile["vaultNamespace"] = profile.VaultNamespace + resolvedProfile["vaultSecretPath"] = profile.VaultSecretPath + resolvedProfile["vaultSecretKey"] = profile.VaultSecretKey + resolvedProfile["apisixUrl"] = profile.ApisixURL + openclawConfigured = strings.TrimSpace(profile.VaultSecretPath) != "" + } + + return gin.H{ + "edition": access.Tenant.Edition, + "tenant": gin.H{ + "id": access.Tenant.ID, + "name": access.Tenant.Name, + "domain": access.Domain, + }, + "membershipRole": access.MembershipRole, + "profileScope": access.ProfileScope, + "canEditIntegrations": access.CanEditIntegrations, + "canManageTenant": access.CanManageTenant, + "profile": resolvedProfile, + "tokenConfigured": gin.H{ + "openclaw": openclawConfigured, + "vault": false, + "apisix": false, + }, + } +} + +func (h *handler) getXWorkmateProfile(c *gin.Context) { + user, ok := h.currentAuthenticatedUser(c) + if !ok { + return + } + + access, err := h.resolveXWorkmateAccess(c.Request.Context(), h.resolveTenantHost(c), user) + if err != nil { + if errors.Is(err, store.ErrTenantMembershipNotFound) { + respondError(c, http.StatusForbidden, "tenant_membership_required", "tenant membership is required") + return + } + if errors.Is(err, store.ErrTenantNotFound) { + respondError(c, http.StatusNotFound, "tenant_not_found", "tenant was not found") + return + } + respondError(c, http.StatusInternalServerError, "xworkmate_context_failed", "failed to resolve xworkmate context") + return + } + + profile, err := h.store.GetXWorkmateProfile(c.Request.Context(), access.Tenant.ID, user.ID, access.ProfileScope) + if err != nil && !errors.Is(err, store.ErrXWorkmateProfileNotFound) { + respondError(c, http.StatusInternalServerError, "xworkmate_profile_read_failed", "failed to load xworkmate profile") + return + } + if errors.Is(err, store.ErrXWorkmateProfileNotFound) { + profile = nil + } + if access.ProfileScope == store.XWorkmateProfileScopeTenantShared && profile == nil { + profile, err = h.store.GetXWorkmateProfile(c.Request.Context(), access.Tenant.ID, "", access.ProfileScope) + if err != nil && !errors.Is(err, store.ErrXWorkmateProfileNotFound) { + respondError(c, http.StatusInternalServerError, "xworkmate_profile_read_failed", "failed to load xworkmate profile") + return + } + if errors.Is(err, store.ErrXWorkmateProfileNotFound) { + profile = nil + } + } + + c.JSON(http.StatusOK, buildXWorkmateProfileResponse(access, profile)) +} + +func (h *handler) updateXWorkmateProfile(c *gin.Context) { + user, ok := h.currentAuthenticatedUser(c) + if !ok { + return + } + + access, err := h.resolveXWorkmateAccess(c.Request.Context(), h.resolveTenantHost(c), user) + if err != nil { + if errors.Is(err, store.ErrTenantMembershipNotFound) { + respondError(c, http.StatusForbidden, "tenant_membership_required", "tenant membership is required") + return + } + if errors.Is(err, store.ErrTenantNotFound) { + respondError(c, http.StatusNotFound, "tenant_not_found", "tenant was not found") + return + } + respondError(c, http.StatusInternalServerError, "xworkmate_context_failed", "failed to resolve xworkmate context") + return + } + + if !access.CanEditIntegrations { + respondError(c, http.StatusForbidden, "xworkmate_profile_forbidden", "you are not allowed to update integrations for this tenant") + return + } + if h.isReadOnlyAccount(user) { + respondError(c, http.StatusForbidden, "read_only_account", "demo account is read-only") + return + } + + var raw map[string]any + if err := c.ShouldBindJSON(&raw); err != nil { + respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload") + return + } + + for _, forbiddenField := range []string{"openclawToken", "gatewayToken", "vaultToken", "apisixToken"} { + if _, ok := raw[forbiddenField]; ok { + respondError(c, http.StatusBadRequest, "token_persistence_forbidden", "raw token fields cannot be persisted") + return + } + } + + profileValue, ok := raw["profile"] + if !ok { + profileValue = raw + } + + encodedProfile, err := json.Marshal(profileValue) + if err != nil { + respondError(c, http.StatusBadRequest, "invalid_request", "invalid profile payload") + return + } + + var payload xworkmateProfilePayload + if err := json.Unmarshal(encodedProfile, &payload); err != nil { + respondError(c, http.StatusBadRequest, "invalid_request", "invalid profile payload") + return + } + + profileUserID := user.ID + if access.ProfileScope == store.XWorkmateProfileScopeTenantShared { + profileUserID = "" + } + + profile := &store.XWorkmateProfile{ + TenantID: access.Tenant.ID, + UserID: profileUserID, + Scope: access.ProfileScope, + OpenclawURL: payload.OpenclawURL, + OpenclawOrigin: payload.OpenclawOrigin, + VaultURL: payload.VaultURL, + VaultNamespace: payload.VaultNamespace, + VaultSecretPath: payload.VaultSecretPath, + VaultSecretKey: payload.VaultSecretKey, + ApisixURL: payload.ApisixURL, + } + if err := h.store.UpsertXWorkmateProfile(c.Request.Context(), profile); err != nil { + respondError(c, http.StatusInternalServerError, "xworkmate_profile_write_failed", "failed to save xworkmate profile") + return + } + + c.JSON(http.StatusOK, buildXWorkmateProfileResponse(access, profile)) +} + +func (h *handler) bootstrapTenant(c *gin.Context) { + adminUser, ok := h.requireAdminPermission(c, permissionAdminSettingsWrite) + if !ok { + return + } + if !h.isRootAccount(adminUser) { + respondError(c, http.StatusForbidden, "root_only", "root only") + return + } + + var payload struct { + Name string `json:"name"` + AdminUserID string `json:"adminUserId"` + AdminEmail string `json:"adminEmail"` + } + if err := c.ShouldBindJSON(&payload); err != nil { + respondError(c, http.StatusBadRequest, "invalid_request", "invalid request payload") + return + } + + var member *store.User + var err error + switch { + case strings.TrimSpace(payload.AdminUserID) != "": + member, err = h.store.GetUserByID(c.Request.Context(), strings.TrimSpace(payload.AdminUserID)) + case strings.TrimSpace(payload.AdminEmail) != "": + member, err = h.store.GetUserByEmail(c.Request.Context(), strings.TrimSpace(payload.AdminEmail)) + default: + respondError(c, http.StatusBadRequest, "admin_user_required", "adminUserId or adminEmail is required") + return + } + if err != nil { + respondError(c, http.StatusNotFound, "admin_user_not_found", "admin user not found") + return + } + + domain, err := store.GenerateRandomTenantDomain() + if err != nil { + respondError(c, http.StatusInternalServerError, "tenant_domain_generation_failed", "failed to generate tenant domain") + return + } + + tenant := &store.Tenant{ + Name: strings.TrimSpace(payload.Name), + Edition: store.TenantPrivateEdition, + } + if tenant.Name == "" { + tenant.Name = member.Name + if tenant.Name == "" { + tenant.Name = member.Email + } + } + if err := h.store.EnsureTenant(c.Request.Context(), tenant); err != nil { + respondError(c, http.StatusInternalServerError, "tenant_create_failed", "failed to create tenant") + return + } + if err := h.store.EnsureTenantDomain(c.Request.Context(), &store.TenantDomain{ + TenantID: tenant.ID, + Domain: domain, + Kind: store.TenantDomainKindGenerated, + IsPrimary: true, + Status: store.TenantDomainStatusVerified, + }); err != nil { + respondError(c, http.StatusInternalServerError, "tenant_domain_create_failed", "failed to create tenant domain") + return + } + if err := h.store.UpsertTenantMembership(c.Request.Context(), &store.TenantMembership{ + TenantID: tenant.ID, + UserID: member.ID, + Role: store.TenantMembershipRoleAdmin, + }); err != nil { + respondError(c, http.StatusInternalServerError, "tenant_membership_create_failed", "failed to create tenant membership") + return + } + + c.JSON(http.StatusCreated, gin.H{ + "tenant": gin.H{ + "id": tenant.ID, + "name": tenant.Name, + "edition": tenant.Edition, + "domain": domain, + }, + "member": gin.H{ + "id": member.ID, + "email": member.Email, + "role": store.TenantMembershipRoleAdmin, + }, + }) +} diff --git a/cmd/accountsvc/main.go b/cmd/accountsvc/main.go index 15ac690..37ed226 100644 --- a/cmd/accountsvc/main.go +++ b/cmd/accountsvc/main.go @@ -638,21 +638,6 @@ func runServer(ctx context.Context, cfg *config.Config, logger *slog.Logger) err logger = slog.Default() } - r := gin.New() - corsConfig := buildCORSConfig(logger, cfg.Server) - if corsConfig.AllowAllOrigins { - logger.Info("configured cors", "allowAllOrigins", true) - } else { - logger.Info("configured cors", "allowedOrigins", corsConfig.AllowOrigins) - } - r.Use(cors.New(corsConfig)) - r.Use(gin.Recovery()) - r.Use(func(c *gin.Context) { - start := time.Now() - c.Next() - logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start)) - }) - storeCfg := store.Config{ Driver: cfg.Store.Driver, DSN: cfg.Store.DSN, @@ -698,6 +683,37 @@ func runServer(ctx context.Context, cfg *config.Config, logger *slog.Logger) err if err := ensureReviewUser(ctx, st, cfg.ReviewAccount, logger); err != nil { logger.Warn("failed to ensure review user", "err", err) } + if err := st.EnsureTenant(ctx, &store.Tenant{ + ID: store.SharedXWorkmateTenantID, + Name: store.SharedXWorkmateTenantName, + Edition: store.SharedPublicTenantEdition, + }); err != nil { + return fmt.Errorf("ensure shared xworkmate tenant: %w", err) + } + if err := st.EnsureTenantDomain(ctx, &store.TenantDomain{ + TenantID: store.SharedXWorkmateTenantID, + Domain: store.SharedXWorkmateDomain, + Kind: store.TenantDomainKindGenerated, + IsPrimary: true, + Status: store.TenantDomainStatusVerified, + }); err != nil { + return fmt.Errorf("ensure shared xworkmate tenant domain: %w", err) + } + + r := gin.New() + corsConfig := buildCORSConfig(logger, cfg.Server, st) + if corsConfig.AllowAllOrigins { + logger.Info("configured cors", "allowAllOrigins", true) + } else { + logger.Info("configured cors", "allowedOrigins", cfg.Server.AllowedOrigins, "dynamicTenantDomains", true) + } + r.Use(cors.New(corsConfig)) + r.Use(gin.Recovery()) + r.Use(func(c *gin.Context) { + start := time.Now() + c.Next() + logger.Info("request", "method", c.Request.Method, "path", c.FullPath(), "status", c.Writer.Status(), "latency", time.Since(start)) + }) var emailSender api.EmailSender emailVerificationEnabled := true @@ -1275,7 +1291,14 @@ func openAdminSettingsDB(cfg config.Store) (*gorm.DB, func(context.Context) erro return nil, nil, fmt.Errorf("admin settings db connection failed after sidecar wait: %w", err) } - if err := db.AutoMigrate(&model.AdminSetting{}, &model.SandboxBinding{}); err != nil { + if err := db.AutoMigrate( + &model.AdminSetting{}, + &model.SandboxBinding{}, + &model.Tenant{}, + &model.TenantDomain{}, + &model.TenantMembership{}, + &model.XWorkmateProfile{}, + ); err != nil { return nil, nil, err } @@ -1321,7 +1344,7 @@ func isExampleDomain(host string) bool { return strings.HasSuffix(normalized, ".example.com") } -func buildCORSConfig(logger *slog.Logger, serverCfg config.Server) cors.Config { +func buildCORSConfig(logger *slog.Logger, serverCfg config.Server, st store.Store) cors.Config { allowOrigins, allowAll := resolveAllowedOrigins(logger, serverCfg) cfg := cors.Config{ @@ -1352,8 +1375,34 @@ func buildCORSConfig(logger *slog.Logger, serverCfg config.Server) cors.Config { cfg.AllowAllOrigins = true cfg.AllowCredentials = false } else { - cfg.AllowOrigins = allowOrigins cfg.AllowCredentials = true + allowedOriginSet := make(map[string]struct{}, len(allowOrigins)) + for _, origin := range allowOrigins { + allowedOriginSet[origin] = struct{}{} + } + cfg.AllowOriginFunc = func(origin string) bool { + normalized, err := parseOrigin(origin) + if err != nil { + return false + } + if _, ok := allowedOriginSet[normalized]; ok { + return true + } + + parsed, err := url.Parse(normalized) + if err != nil { + return false + } + host := store.NormalizeHostname(parsed.Host) + if store.IsSharedTenantHost(host) { + return true + } + if st == nil { + return false + } + _, _, err = st.ResolveTenantByHost(context.Background(), host) + return err == nil + } } return cfg diff --git a/internal/model/xworkmate_tenant.go b/internal/model/xworkmate_tenant.go new file mode 100644 index 0000000..0d7a0a9 --- /dev/null +++ b/internal/model/xworkmate_tenant.go @@ -0,0 +1,81 @@ +package model + +import ( + "strings" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +type Tenant struct { + ID string `gorm:"column:id;type:text;primaryKey"` + Name string `gorm:"column:name;type:text;not null"` + Edition string `gorm:"column:edition;type:text;not null;index"` + CreatedAt time.Time `gorm:"column:created_at;not null;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;not null;autoUpdateTime"` +} + +func (Tenant) TableName() string { return "tenants" } + +func (tenant *Tenant) BeforeCreate(tx *gorm.DB) error { + if strings.TrimSpace(tenant.ID) == "" { + tenant.ID = uuid.NewString() + } + return nil +} + +type TenantDomain struct { + ID string `gorm:"column:id;type:text;primaryKey"` + TenantID string `gorm:"column:tenant_id;type:text;not null;index"` + Domain string `gorm:"column:domain;type:text;not null;uniqueIndex"` + Kind string `gorm:"column:kind;type:text;not null"` + IsPrimary bool `gorm:"column:is_primary;not null;default:false"` + Status string `gorm:"column:status;type:text;not null;index"` + CreatedAt time.Time `gorm:"column:created_at;not null;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;not null;autoUpdateTime"` +} + +func (TenantDomain) TableName() string { return "tenant_domains" } + +func (domain *TenantDomain) BeforeCreate(tx *gorm.DB) error { + if strings.TrimSpace(domain.ID) == "" { + domain.ID = uuid.NewString() + } + return nil +} + +type TenantMembership struct { + TenantID string `gorm:"column:tenant_id;type:text;primaryKey"` + UserID string `gorm:"column:user_id;type:text;primaryKey"` + Role string `gorm:"column:role;type:text;not null;index"` + CreatedAt time.Time `gorm:"column:created_at;not null;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;not null;autoUpdateTime"` +} + +func (TenantMembership) TableName() string { return "tenant_memberships" } + +type XWorkmateProfile struct { + ID string `gorm:"column:id;type:text;primaryKey"` + TenantID string `gorm:"column:tenant_id;type:text;not null;uniqueIndex:idx_xworkmate_profiles_scope"` + UserID string `gorm:"column:user_id;type:text;not null;default:'';uniqueIndex:idx_xworkmate_profiles_scope"` + Scope string `gorm:"column:scope;type:text;not null;uniqueIndex:idx_xworkmate_profiles_scope"` + OpenclawURL string `gorm:"column:openclaw_url;type:text;not null;default:''"` + OpenclawOrigin string `gorm:"column:openclaw_origin;type:text;not null;default:''"` + VaultURL string `gorm:"column:vault_url;type:text;not null;default:''"` + VaultNamespace string `gorm:"column:vault_namespace;type:text;not null;default:''"` + VaultSecretPath string `gorm:"column:vault_secret_path;type:text;not null;default:''"` + VaultSecretKey string `gorm:"column:vault_secret_key;type:text;not null;default:''"` + ApisixURL string `gorm:"column:apisix_url;type:text;not null;default:''"` + CreatedAt time.Time `gorm:"column:created_at;not null;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;not null;autoUpdateTime"` +} + +func (XWorkmateProfile) TableName() string { return "xworkmate_profiles" } + +func (profile *XWorkmateProfile) BeforeCreate(tx *gorm.DB) error { + if strings.TrimSpace(profile.ID) == "" { + profile.ID = uuid.NewString() + } + return nil +} diff --git a/internal/store/store.go b/internal/store/store.go index 28556d3..9b3d44b 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -105,6 +105,15 @@ type Store interface { ListAgents(ctx context.Context) ([]*Agent, error) DeleteAgent(ctx context.Context, id string) error DeleteStaleAgents(ctx context.Context, staleThreshold time.Duration) (int, error) + + EnsureTenant(ctx context.Context, tenant *Tenant) error + EnsureTenantDomain(ctx context.Context, domain *TenantDomain) error + UpsertTenantMembership(ctx context.Context, membership *TenantMembership) error + ResolveTenantByHost(ctx context.Context, host string) (*Tenant, *TenantDomain, error) + ListTenantMembershipsByUser(ctx context.Context, userID string) ([]TenantMembership, error) + GetTenantMembership(ctx context.Context, tenantID, userID string) (*TenantMembership, error) + GetXWorkmateProfile(ctx context.Context, tenantID, userID, scope string) (*XWorkmateProfile, error) + UpsertXWorkmateProfile(ctx context.Context, profile *XWorkmateProfile) error } // Domain level errors returned by the store implementation. @@ -131,6 +140,10 @@ type memoryStore struct { identities map[string]*Identity agents map[string]*Agent sessions map[string]*sessionRecord + tenants map[string]*Tenant + tenantDomains map[string]*TenantDomain + tenantMemberships map[string]map[string]*TenantMembership + xworkmateProfiles map[string]*XWorkmateProfile } type sessionRecord struct { @@ -165,6 +178,10 @@ func newMemoryStore(allowSuperAdminCounting bool) Store { identities: make(map[string]*Identity), agents: make(map[string]*Agent), sessions: make(map[string]*sessionRecord), + tenants: make(map[string]*Tenant), + tenantDomains: make(map[string]*TenantDomain), + tenantMemberships: make(map[string]map[string]*TenantMembership), + xworkmateProfiles: make(map[string]*XWorkmateProfile), } } diff --git a/internal/store/xworkmate.go b/internal/store/xworkmate.go new file mode 100644 index 0000000..48557f1 --- /dev/null +++ b/internal/store/xworkmate.go @@ -0,0 +1,232 @@ +package store + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +const ( + SharedPublicTenantEdition = "shared_public" + TenantPrivateEdition = "tenant_private" + + TenantMembershipRoleAdmin = "admin" + TenantMembershipRoleUser = "user" + + TenantDomainKindGenerated = "generated" + TenantDomainKindCustom = "custom" + + TenantDomainStatusPending = "pending" + TenantDomainStatusVerified = "verified" + + XWorkmateProfileScopeTenantShared = "tenant-shared" + XWorkmateProfileScopeUserPrivate = "user-private" + + SharedXWorkmateTenantID = "svc-plus-xworkmate" + SharedXWorkmateTenantName = "svc.plus XWorkmate" + SharedXWorkmateDomain = "svc.plus" +) + +var ( + ErrTenantNotFound = errors.New("tenant not found") + ErrTenantMembershipNotFound = errors.New("tenant membership not found") + ErrXWorkmateProfileNotFound = errors.New("xworkmate profile not found") +) + +type Tenant struct { + ID string + Name string + Edition string + CreatedAt time.Time + UpdatedAt time.Time +} + +type TenantDomain struct { + ID string + TenantID string + Domain string + Kind string + IsPrimary bool + Status string + CreatedAt time.Time + UpdatedAt time.Time +} + +type TenantMembership struct { + TenantID string + UserID string + Role string + TenantName string + TenantEdition string + Domain string + CreatedAt time.Time + UpdatedAt time.Time +} + +type XWorkmateProfile struct { + ID string + TenantID string + UserID string + Scope string + OpenclawURL string + OpenclawOrigin string + VaultURL string + VaultNamespace string + VaultSecretPath string + VaultSecretKey string + ApisixURL string + CreatedAt time.Time + UpdatedAt time.Time +} + +func NormalizeTenantEdition(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case SharedPublicTenantEdition: + return SharedPublicTenantEdition + default: + return TenantPrivateEdition + } +} + +func NormalizeTenantMembershipRole(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case TenantMembershipRoleAdmin: + return TenantMembershipRoleAdmin + default: + return TenantMembershipRoleUser + } +} + +func NormalizeTenantDomainKind(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case TenantDomainKindCustom: + return TenantDomainKindCustom + default: + return TenantDomainKindGenerated + } +} + +func NormalizeTenantDomainStatus(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case TenantDomainStatusPending: + return TenantDomainStatusPending + default: + return TenantDomainStatusVerified + } +} + +func NormalizeXWorkmateProfileScope(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case XWorkmateProfileScopeTenantShared: + return XWorkmateProfileScopeTenantShared + default: + return XWorkmateProfileScopeUserPrivate + } +} + +func NormalizeHostname(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if comma := strings.Index(trimmed, ","); comma >= 0 { + trimmed = strings.TrimSpace(trimmed[:comma]) + } + + if parsed, err := url.Parse(trimmed); err == nil && parsed.Host != "" { + trimmed = parsed.Host + } + + if host, _, err := net.SplitHostPort(trimmed); err == nil { + trimmed = host + } + + return strings.Trim(strings.ToLower(trimmed), ".") +} + +func IsSharedTenantHost(host string) bool { + normalized := NormalizeHostname(host) + if normalized == "" { + return true + } + switch normalized { + case "svc.plus", "www.svc.plus", "console.svc.plus", "localhost", "127.0.0.1", "[::1]": + return true + default: + return false + } +} + +func GenerateRandomTenantDomain() (string, error) { + buffer := make([]byte, 4) + if _, err := rand.Read(buffer); err != nil { + return "", fmt.Errorf("generate tenant domain: %w", err) + } + return fmt.Sprintf("xw-%s.svc.plus", hex.EncodeToString(buffer)), nil +} + +func NormalizeTenant(tenant *Tenant) { + if tenant == nil { + return + } + tenant.ID = strings.TrimSpace(tenant.ID) + tenant.Name = strings.TrimSpace(tenant.Name) + tenant.Edition = NormalizeTenantEdition(tenant.Edition) +} + +func NormalizeTenantDomain(domain *TenantDomain) { + if domain == nil { + return + } + domain.ID = strings.TrimSpace(domain.ID) + domain.TenantID = strings.TrimSpace(domain.TenantID) + domain.Domain = NormalizeHostname(domain.Domain) + domain.Kind = NormalizeTenantDomainKind(domain.Kind) + domain.Status = NormalizeTenantDomainStatus(domain.Status) +} + +func NormalizeTenantMembership(membership *TenantMembership) { + if membership == nil { + return + } + membership.TenantID = strings.TrimSpace(membership.TenantID) + membership.UserID = strings.TrimSpace(membership.UserID) + membership.Role = NormalizeTenantMembershipRole(membership.Role) + membership.TenantName = strings.TrimSpace(membership.TenantName) + membership.TenantEdition = NormalizeTenantEdition(membership.TenantEdition) + membership.Domain = NormalizeHostname(membership.Domain) +} + +func NormalizeXWorkmateProfile(profile *XWorkmateProfile) { + if profile == nil { + return + } + profile.ID = strings.TrimSpace(profile.ID) + profile.TenantID = strings.TrimSpace(profile.TenantID) + profile.UserID = strings.TrimSpace(profile.UserID) + profile.Scope = NormalizeXWorkmateProfileScope(profile.Scope) + profile.OpenclawURL = strings.TrimSpace(profile.OpenclawURL) + profile.OpenclawOrigin = strings.TrimSpace(profile.OpenclawOrigin) + profile.VaultURL = strings.TrimSpace(profile.VaultURL) + profile.VaultNamespace = strings.TrimSpace(profile.VaultNamespace) + profile.VaultSecretPath = strings.Trim(strings.TrimSpace(profile.VaultSecretPath), "/") + profile.VaultSecretKey = strings.TrimSpace(profile.VaultSecretKey) + profile.ApisixURL = strings.TrimSpace(profile.ApisixURL) +} + +type TenantResolver interface { + EnsureTenant(ctx context.Context, tenant *Tenant) error + EnsureTenantDomain(ctx context.Context, domain *TenantDomain) error + UpsertTenantMembership(ctx context.Context, membership *TenantMembership) error + ResolveTenantByHost(ctx context.Context, host string) (*Tenant, *TenantDomain, error) + ListTenantMembershipsByUser(ctx context.Context, userID string) ([]TenantMembership, error) + GetTenantMembership(ctx context.Context, tenantID, userID string) (*TenantMembership, error) + GetXWorkmateProfile(ctx context.Context, tenantID, userID, scope string) (*XWorkmateProfile, error) + UpsertXWorkmateProfile(ctx context.Context, profile *XWorkmateProfile) error +} diff --git a/internal/store/xworkmate_memory.go b/internal/store/xworkmate_memory.go new file mode 100644 index 0000000..cca4ec3 --- /dev/null +++ b/internal/store/xworkmate_memory.go @@ -0,0 +1,323 @@ +package store + +import ( + "context" + "sort" + "strings" + "time" + + "github.com/google/uuid" +) + +func tenantProfileKey(tenantID, userID, scope string) string { + return strings.Join([]string{ + strings.TrimSpace(tenantID), + strings.TrimSpace(userID), + NormalizeXWorkmateProfileScope(scope), + }, "|") +} + +func (s *memoryStore) EnsureTenant(ctx context.Context, tenant *Tenant) error { + _ = ctx + if tenant == nil { + return ErrTenantNotFound + } + + NormalizeTenant(tenant) + if tenant.ID == "" { + tenant.ID = uuid.NewString() + } + + now := time.Now().UTC() + s.mu.Lock() + defer s.mu.Unlock() + + existing, ok := s.tenants[tenant.ID] + if ok { + existing.Name = tenant.Name + existing.Edition = tenant.Edition + existing.UpdatedAt = now + tenant.CreatedAt = existing.CreatedAt + tenant.UpdatedAt = existing.UpdatedAt + return nil + } + + stored := &Tenant{ + ID: tenant.ID, + Name: tenant.Name, + Edition: tenant.Edition, + CreatedAt: now, + UpdatedAt: now, + } + s.tenants[stored.ID] = stored + tenant.CreatedAt = stored.CreatedAt + tenant.UpdatedAt = stored.UpdatedAt + return nil +} + +func (s *memoryStore) EnsureTenantDomain(ctx context.Context, domain *TenantDomain) error { + _ = ctx + if domain == nil { + return ErrTenantNotFound + } + + NormalizeTenantDomain(domain) + if domain.Domain == "" || domain.TenantID == "" { + return ErrTenantNotFound + } + if domain.ID == "" { + domain.ID = uuid.NewString() + } + + now := time.Now().UTC() + s.mu.Lock() + defer s.mu.Unlock() + + existing, ok := s.tenantDomains[domain.Domain] + if ok { + existing.TenantID = domain.TenantID + existing.Kind = domain.Kind + existing.IsPrimary = domain.IsPrimary + existing.Status = domain.Status + existing.UpdatedAt = now + domain.CreatedAt = existing.CreatedAt + domain.UpdatedAt = existing.UpdatedAt + return nil + } + + stored := &TenantDomain{ + ID: domain.ID, + TenantID: domain.TenantID, + Domain: domain.Domain, + Kind: domain.Kind, + IsPrimary: domain.IsPrimary, + Status: domain.Status, + CreatedAt: now, + UpdatedAt: now, + } + s.tenantDomains[stored.Domain] = stored + domain.CreatedAt = stored.CreatedAt + domain.UpdatedAt = stored.UpdatedAt + return nil +} + +func (s *memoryStore) UpsertTenantMembership(ctx context.Context, membership *TenantMembership) error { + _ = ctx + if membership == nil { + return ErrTenantMembershipNotFound + } + + NormalizeTenantMembership(membership) + if membership.TenantID == "" || membership.UserID == "" { + return ErrTenantMembershipNotFound + } + + now := time.Now().UTC() + s.mu.Lock() + defer s.mu.Unlock() + + if s.tenantMemberships[membership.TenantID] == nil { + s.tenantMemberships[membership.TenantID] = make(map[string]*TenantMembership) + } + + if existing, ok := s.tenantMemberships[membership.TenantID][membership.UserID]; ok { + existing.Role = membership.Role + existing.UpdatedAt = now + membership.CreatedAt = existing.CreatedAt + membership.UpdatedAt = existing.UpdatedAt + return nil + } + + stored := &TenantMembership{ + TenantID: membership.TenantID, + UserID: membership.UserID, + Role: membership.Role, + CreatedAt: now, + UpdatedAt: now, + } + s.tenantMemberships[membership.TenantID][membership.UserID] = stored + membership.CreatedAt = stored.CreatedAt + membership.UpdatedAt = stored.UpdatedAt + return nil +} + +func (s *memoryStore) ResolveTenantByHost(ctx context.Context, host string) (*Tenant, *TenantDomain, error) { + _ = ctx + normalizedHost := NormalizeHostname(host) + + s.mu.RLock() + defer s.mu.RUnlock() + + if IsSharedTenantHost(normalizedHost) { + tenant, ok := s.tenants[SharedXWorkmateTenantID] + if !ok { + return nil, nil, ErrTenantNotFound + } + var domain *TenantDomain + if storedDomain, ok := s.tenantDomains[SharedXWorkmateDomain]; ok { + domainCopy := *storedDomain + domain = &domainCopy + } + tenantCopy := *tenant + return &tenantCopy, domain, nil + } + + domain, ok := s.tenantDomains[normalizedHost] + if !ok { + return nil, nil, ErrTenantNotFound + } + tenant, ok := s.tenants[domain.TenantID] + if !ok { + return nil, nil, ErrTenantNotFound + } + + tenantCopy := *tenant + domainCopy := *domain + return &tenantCopy, &domainCopy, nil +} + +func (s *memoryStore) ListTenantMembershipsByUser(ctx context.Context, userID string) ([]TenantMembership, error) { + _ = ctx + normalizedUserID := strings.TrimSpace(userID) + + s.mu.RLock() + defer s.mu.RUnlock() + + result := make([]TenantMembership, 0) + for tenantID, members := range s.tenantMemberships { + member, ok := members[normalizedUserID] + if !ok { + continue + } + + entry := *member + if tenant, ok := s.tenants[tenantID]; ok { + entry.TenantName = tenant.Name + entry.TenantEdition = tenant.Edition + } + for _, domain := range s.tenantDomains { + if domain.TenantID == tenantID && domain.IsPrimary { + entry.Domain = domain.Domain + break + } + } + result = append(result, entry) + } + + sort.Slice(result, func(i, j int) bool { + if result[i].TenantName == result[j].TenantName { + return result[i].TenantID < result[j].TenantID + } + return result[i].TenantName < result[j].TenantName + }) + + return result, nil +} + +func (s *memoryStore) GetTenantMembership(ctx context.Context, tenantID, userID string) (*TenantMembership, error) { + _ = ctx + normalizedTenantID := strings.TrimSpace(tenantID) + normalizedUserID := strings.TrimSpace(userID) + + s.mu.RLock() + defer s.mu.RUnlock() + + members := s.tenantMemberships[normalizedTenantID] + if members == nil { + return nil, ErrTenantMembershipNotFound + } + member, ok := members[normalizedUserID] + if !ok { + return nil, ErrTenantMembershipNotFound + } + + entry := *member + if tenant, ok := s.tenants[normalizedTenantID]; ok { + entry.TenantName = tenant.Name + entry.TenantEdition = tenant.Edition + } + for _, domain := range s.tenantDomains { + if domain.TenantID == normalizedTenantID && domain.IsPrimary { + entry.Domain = domain.Domain + break + } + } + return &entry, nil +} + +func (s *memoryStore) GetXWorkmateProfile(ctx context.Context, tenantID, userID, scope string) (*XWorkmateProfile, error) { + _ = ctx + key := tenantProfileKey(tenantID, userID, scope) + + s.mu.RLock() + defer s.mu.RUnlock() + + profile, ok := s.xworkmateProfiles[key] + if !ok { + return nil, ErrXWorkmateProfileNotFound + } + + entry := *profile + return &entry, nil +} + +func (s *memoryStore) UpsertXWorkmateProfile(ctx context.Context, profile *XWorkmateProfile) error { + _ = ctx + if profile == nil { + return ErrXWorkmateProfileNotFound + } + + NormalizeXWorkmateProfile(profile) + if profile.TenantID == "" { + return ErrXWorkmateProfileNotFound + } + if profile.Scope == XWorkmateProfileScopeUserPrivate && profile.UserID == "" { + return ErrXWorkmateProfileNotFound + } + if profile.Scope == XWorkmateProfileScopeTenantShared { + profile.UserID = "" + } + if profile.ID == "" { + profile.ID = uuid.NewString() + } + + now := time.Now().UTC() + key := tenantProfileKey(profile.TenantID, profile.UserID, profile.Scope) + + s.mu.Lock() + defer s.mu.Unlock() + + if existing, ok := s.xworkmateProfiles[key]; ok { + existing.OpenclawURL = profile.OpenclawURL + existing.OpenclawOrigin = profile.OpenclawOrigin + existing.VaultURL = profile.VaultURL + existing.VaultNamespace = profile.VaultNamespace + existing.VaultSecretPath = profile.VaultSecretPath + existing.VaultSecretKey = profile.VaultSecretKey + existing.ApisixURL = profile.ApisixURL + existing.UpdatedAt = now + profile.CreatedAt = existing.CreatedAt + profile.UpdatedAt = existing.UpdatedAt + return nil + } + + stored := &XWorkmateProfile{ + ID: profile.ID, + TenantID: profile.TenantID, + UserID: profile.UserID, + Scope: profile.Scope, + OpenclawURL: profile.OpenclawURL, + OpenclawOrigin: profile.OpenclawOrigin, + VaultURL: profile.VaultURL, + VaultNamespace: profile.VaultNamespace, + VaultSecretPath: profile.VaultSecretPath, + VaultSecretKey: profile.VaultSecretKey, + ApisixURL: profile.ApisixURL, + CreatedAt: now, + UpdatedAt: now, + } + s.xworkmateProfiles[key] = stored + profile.CreatedAt = stored.CreatedAt + profile.UpdatedAt = stored.UpdatedAt + return nil +} diff --git a/internal/store/xworkmate_postgres.go b/internal/store/xworkmate_postgres.go new file mode 100644 index 0000000..b977a12 --- /dev/null +++ b/internal/store/xworkmate_postgres.go @@ -0,0 +1,302 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "strings" + + "github.com/google/uuid" +) + +func (s *postgresStore) EnsureTenant(ctx context.Context, tenant *Tenant) error { + if tenant == nil { + return ErrTenantNotFound + } + + NormalizeTenant(tenant) + if tenant.ID == "" { + tenant.ID = uuid.NewString() + } + + query := `INSERT INTO tenants (id, name, edition, created_at, updated_at) +VALUES ($1, $2, $3, now(), now()) +ON CONFLICT (id) DO UPDATE +SET name = EXCLUDED.name, + edition = EXCLUDED.edition, + updated_at = now() +RETURNING created_at, updated_at` + + return s.db.QueryRowContext(ctx, query, tenant.ID, tenant.Name, tenant.Edition).Scan(&tenant.CreatedAt, &tenant.UpdatedAt) +} + +func (s *postgresStore) EnsureTenantDomain(ctx context.Context, domain *TenantDomain) error { + if domain == nil { + return ErrTenantNotFound + } + + NormalizeTenantDomain(domain) + if domain.ID == "" { + domain.ID = uuid.NewString() + } + + query := `INSERT INTO tenant_domains (id, tenant_id, domain, kind, is_primary, status, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, now(), now()) +ON CONFLICT (domain) DO UPDATE +SET tenant_id = EXCLUDED.tenant_id, + kind = EXCLUDED.kind, + is_primary = EXCLUDED.is_primary, + status = EXCLUDED.status, + updated_at = now() +RETURNING created_at, updated_at` + + return s.db.QueryRowContext( + ctx, + query, + domain.ID, + domain.TenantID, + domain.Domain, + domain.Kind, + domain.IsPrimary, + domain.Status, + ).Scan(&domain.CreatedAt, &domain.UpdatedAt) +} + +func (s *postgresStore) UpsertTenantMembership(ctx context.Context, membership *TenantMembership) error { + if membership == nil { + return ErrTenantMembershipNotFound + } + + NormalizeTenantMembership(membership) + query := `INSERT INTO tenant_memberships (tenant_id, user_id, role, created_at, updated_at) +VALUES ($1, $2, $3, now(), now()) +ON CONFLICT (tenant_id, user_id) DO UPDATE +SET role = EXCLUDED.role, + updated_at = now() +RETURNING created_at, updated_at` + + return s.db.QueryRowContext(ctx, query, membership.TenantID, membership.UserID, membership.Role).Scan(&membership.CreatedAt, &membership.UpdatedAt) +} + +func (s *postgresStore) ResolveTenantByHost(ctx context.Context, host string) (*Tenant, *TenantDomain, error) { + normalizedHost := NormalizeHostname(host) + + if IsSharedTenantHost(normalizedHost) { + query := `SELECT t.id, t.name, t.edition, t.created_at, t.updated_at, + COALESCE(td.id, ''), COALESCE(td.domain, ''), COALESCE(td.kind, ''), COALESCE(td.is_primary, false), COALESCE(td.status, ''), td.created_at, td.updated_at +FROM tenants t +LEFT JOIN tenant_domains td + ON td.tenant_id = t.id AND td.is_primary = TRUE +WHERE t.id = $1 +LIMIT 1` + return scanTenantResolutionRow(s.db.QueryRowContext(ctx, query, SharedXWorkmateTenantID)) + } + + query := `SELECT t.id, t.name, t.edition, t.created_at, t.updated_at, + td.id, td.domain, td.kind, td.is_primary, td.status, td.created_at, td.updated_at +FROM tenant_domains td +JOIN tenants t ON t.id = td.tenant_id +WHERE td.domain = $1 AND td.status = $2 +LIMIT 1` + return scanTenantResolutionRow(s.db.QueryRowContext(ctx, query, normalizedHost, TenantDomainStatusVerified)) +} + +func scanTenantResolutionRow(row *sql.Row) (*Tenant, *TenantDomain, error) { + tenant := &Tenant{} + var ( + domainID string + domainName string + domainKind string + domainIsPrimary bool + domainStatus string + domainCreatedAt sql.NullTime + domainUpdatedAt sql.NullTime + ) + if err := row.Scan( + &tenant.ID, + &tenant.Name, + &tenant.Edition, + &tenant.CreatedAt, + &tenant.UpdatedAt, + &domainID, + &domainName, + &domainKind, + &domainIsPrimary, + &domainStatus, + &domainCreatedAt, + &domainUpdatedAt, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil, ErrTenantNotFound + } + return nil, nil, err + } + + var domain *TenantDomain + if strings.TrimSpace(domainName) != "" { + domain = &TenantDomain{ + ID: domainID, + TenantID: tenant.ID, + Domain: domainName, + Kind: domainKind, + IsPrimary: domainIsPrimary, + Status: domainStatus, + } + if domainCreatedAt.Valid { + domain.CreatedAt = domainCreatedAt.Time + } + if domainUpdatedAt.Valid { + domain.UpdatedAt = domainUpdatedAt.Time + } + } + + return tenant, domain, nil +} + +func (s *postgresStore) ListTenantMembershipsByUser(ctx context.Context, userID string) ([]TenantMembership, error) { + query := `SELECT tm.tenant_id, tm.user_id, tm.role, tm.created_at, tm.updated_at, + COALESCE(t.name, ''), COALESCE(t.edition, ''), COALESCE(td.domain, '') +FROM tenant_memberships tm +JOIN tenants t ON t.id = tm.tenant_id +LEFT JOIN tenant_domains td ON td.tenant_id = tm.tenant_id AND td.is_primary = TRUE +WHERE tm.user_id = $1 +ORDER BY t.name ASC, tm.tenant_id ASC` + + rows, err := s.db.QueryContext(ctx, query, strings.TrimSpace(userID)) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make([]TenantMembership, 0) + for rows.Next() { + var membership TenantMembership + if err := rows.Scan( + &membership.TenantID, + &membership.UserID, + &membership.Role, + &membership.CreatedAt, + &membership.UpdatedAt, + &membership.TenantName, + &membership.TenantEdition, + &membership.Domain, + ); err != nil { + return nil, err + } + result = append(result, membership) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +func (s *postgresStore) GetTenantMembership(ctx context.Context, tenantID, userID string) (*TenantMembership, error) { + query := `SELECT tm.tenant_id, tm.user_id, tm.role, tm.created_at, tm.updated_at, + COALESCE(t.name, ''), COALESCE(t.edition, ''), COALESCE(td.domain, '') +FROM tenant_memberships tm +JOIN tenants t ON t.id = tm.tenant_id +LEFT JOIN tenant_domains td ON td.tenant_id = tm.tenant_id AND td.is_primary = TRUE +WHERE tm.tenant_id = $1 AND tm.user_id = $2 +LIMIT 1` + + membership := &TenantMembership{} + if err := s.db.QueryRowContext(ctx, query, strings.TrimSpace(tenantID), strings.TrimSpace(userID)).Scan( + &membership.TenantID, + &membership.UserID, + &membership.Role, + &membership.CreatedAt, + &membership.UpdatedAt, + &membership.TenantName, + &membership.TenantEdition, + &membership.Domain, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrTenantMembershipNotFound + } + return nil, err + } + + return membership, nil +} + +func (s *postgresStore) GetXWorkmateProfile(ctx context.Context, tenantID, userID, scope string) (*XWorkmateProfile, error) { + profile := &XWorkmateProfile{} + query := `SELECT id, tenant_id, user_id, scope, openclaw_url, openclaw_origin, vault_url, vault_namespace, vault_secret_path, vault_secret_key, apisix_url, created_at, updated_at +FROM xworkmate_profiles +WHERE tenant_id = $1 AND user_id = $2 AND scope = $3 +LIMIT 1` + + if err := s.db.QueryRowContext( + ctx, + query, + strings.TrimSpace(tenantID), + strings.TrimSpace(userID), + NormalizeXWorkmateProfileScope(scope), + ).Scan( + &profile.ID, + &profile.TenantID, + &profile.UserID, + &profile.Scope, + &profile.OpenclawURL, + &profile.OpenclawOrigin, + &profile.VaultURL, + &profile.VaultNamespace, + &profile.VaultSecretPath, + &profile.VaultSecretKey, + &profile.ApisixURL, + &profile.CreatedAt, + &profile.UpdatedAt, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrXWorkmateProfileNotFound + } + return nil, err + } + + return profile, nil +} + +func (s *postgresStore) UpsertXWorkmateProfile(ctx context.Context, profile *XWorkmateProfile) error { + if profile == nil { + return ErrXWorkmateProfileNotFound + } + + NormalizeXWorkmateProfile(profile) + if profile.ID == "" { + profile.ID = uuid.NewString() + } + + query := `INSERT INTO xworkmate_profiles ( + id, tenant_id, user_id, scope, openclaw_url, openclaw_origin, vault_url, vault_namespace, vault_secret_path, vault_secret_key, apisix_url, created_at, updated_at +) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, now(), now()) +ON CONFLICT (tenant_id, user_id, scope) DO UPDATE +SET openclaw_url = EXCLUDED.openclaw_url, + openclaw_origin = EXCLUDED.openclaw_origin, + vault_url = EXCLUDED.vault_url, + vault_namespace = EXCLUDED.vault_namespace, + vault_secret_path = EXCLUDED.vault_secret_path, + vault_secret_key = EXCLUDED.vault_secret_key, + apisix_url = EXCLUDED.apisix_url, + updated_at = now() +RETURNING created_at, updated_at` + + return s.db.QueryRowContext( + ctx, + query, + profile.ID, + profile.TenantID, + profile.UserID, + profile.Scope, + profile.OpenclawURL, + profile.OpenclawOrigin, + profile.VaultURL, + profile.VaultNamespace, + profile.VaultSecretPath, + profile.VaultSecretKey, + profile.ApisixURL, + ).Scan(&profile.CreatedAt, &profile.UpdatedAt) +} diff --git a/internal/store/xworkmate_test.go b/internal/store/xworkmate_test.go new file mode 100644 index 0000000..d5e753d --- /dev/null +++ b/internal/store/xworkmate_test.go @@ -0,0 +1,125 @@ +package store + +import ( + "context" + "strings" + "testing" +) + +func TestNormalizeHostname(t *testing.T) { + t.Parallel() + + got := NormalizeHostname("https://XW-ABCD.svc.plus:443/path?q=1") + if got != "xw-abcd.svc.plus" { + t.Fatalf("expected normalized host, got %q", got) + } +} + +func TestGenerateRandomTenantDomain(t *testing.T) { + t.Parallel() + + got, err := GenerateRandomTenantDomain() + if err != nil { + t.Fatalf("expected domain generation to succeed: %v", err) + } + if !strings.HasPrefix(got, "xw-") || !strings.HasSuffix(got, ".svc.plus") { + t.Fatalf("expected generated svc.plus tenant domain, got %q", got) + } +} + +func TestMemoryStoreResolveTenantAndProfile(t *testing.T) { + ctx := context.Background() + st := NewMemoryStore() + + if err := st.EnsureTenant(ctx, &Tenant{ + ID: SharedXWorkmateTenantID, + Name: SharedXWorkmateTenantName, + Edition: SharedPublicTenantEdition, + }); err != nil { + t.Fatalf("ensure shared tenant: %v", err) + } + if err := st.EnsureTenantDomain(ctx, &TenantDomain{ + TenantID: SharedXWorkmateTenantID, + Domain: SharedXWorkmateDomain, + Kind: TenantDomainKindGenerated, + IsPrimary: true, + Status: TenantDomainStatusVerified, + }); err != nil { + t.Fatalf("ensure shared domain: %v", err) + } + + tenant, domain, err := st.ResolveTenantByHost(ctx, "console.svc.plus") + if err != nil { + t.Fatalf("resolve shared tenant: %v", err) + } + if tenant.ID != SharedXWorkmateTenantID { + t.Fatalf("expected shared tenant id, got %q", tenant.ID) + } + if domain == nil || domain.Domain != SharedXWorkmateDomain { + t.Fatalf("expected shared primary domain, got %#v", domain) + } + + privateTenant := &Tenant{ + ID: "tenant-private-1", + Name: "Tenant One", + Edition: TenantPrivateEdition, + } + if err := st.EnsureTenant(ctx, privateTenant); err != nil { + t.Fatalf("ensure private tenant: %v", err) + } + if err := st.EnsureTenantDomain(ctx, &TenantDomain{ + TenantID: privateTenant.ID, + Domain: "xw-tenant-one.svc.plus", + Kind: TenantDomainKindGenerated, + IsPrimary: true, + Status: TenantDomainStatusVerified, + }); err != nil { + t.Fatalf("ensure private domain: %v", err) + } + if err := st.UpsertTenantMembership(ctx, &TenantMembership{ + TenantID: privateTenant.ID, + UserID: "user-1", + Role: TenantMembershipRoleAdmin, + }); err != nil { + t.Fatalf("ensure private membership: %v", err) + } + if err := st.UpsertXWorkmateProfile(ctx, &XWorkmateProfile{ + TenantID: privateTenant.ID, + UserID: "user-1", + Scope: XWorkmateProfileScopeUserPrivate, + OpenclawURL: "wss://openclaw.tenant-one.svc.plus", + VaultSecretPath: "kv/openclaw", + }); err != nil { + t.Fatalf("upsert private profile: %v", err) + } + + tenant, domain, err = st.ResolveTenantByHost(ctx, "https://xw-tenant-one.svc.plus") + if err != nil { + t.Fatalf("resolve private tenant: %v", err) + } + if tenant.ID != privateTenant.ID { + t.Fatalf("expected tenant %q, got %q", privateTenant.ID, tenant.ID) + } + if domain == nil || domain.Domain != "xw-tenant-one.svc.plus" { + t.Fatalf("expected tenant domain, got %#v", domain) + } + + profile, err := st.GetXWorkmateProfile(ctx, privateTenant.ID, "user-1", XWorkmateProfileScopeUserPrivate) + if err != nil { + t.Fatalf("get private profile: %v", err) + } + if profile.OpenclawURL != "wss://openclaw.tenant-one.svc.plus" { + t.Fatalf("expected persisted openclaw url, got %q", profile.OpenclawURL) + } + + memberships, err := st.ListTenantMembershipsByUser(ctx, "user-1") + if err != nil { + t.Fatalf("list memberships: %v", err) + } + if len(memberships) != 1 { + t.Fatalf("expected 1 tenant membership, got %d", len(memberships)) + } + if memberships[0].TenantName != "Tenant One" { + t.Fatalf("expected tenant name to be populated, got %q", memberships[0].TenantName) + } +}