commit 5580f0ae1a502a25c3bf644a5d1b7c530ca48ea0 Author: Haitao Pan Date: Thu Apr 9 06:20:30 2026 +0800 feat(billing): bootstrap billing service diff --git a/README.md b/README.md new file mode 100644 index 0000000..5e7032a --- /dev/null +++ b/README.md @@ -0,0 +1,15 @@ +# billing-service + +`billing-service` is the v1 minute-delta and replay-safe writer for the Cloud +Network Billing & Control Plane. + +It pulls the latest normalized snapshot from `xray-exporter`, computes deltas +from cumulative counters, and writes idempotent usage and billing facts into the +existing `accounts.svc.plus` PostgreSQL schema. + +## Endpoints + +- `POST /v1/jobs/collect-and-rate` +- `POST /v1/jobs/reconcile` +- `GET /healthz` +- `GET /v1/status` diff --git a/cmd/billing-service/main.go b/cmd/billing-service/main.go new file mode 100644 index 0000000..7b89737 --- /dev/null +++ b/cmd/billing-service/main.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "database/sql" + "log" + "net/http" + "os/signal" + "syscall" + + "billing-service/internal/config" + "billing-service/internal/exporter" + "billing-service/internal/httpapi" + "billing-service/internal/repository" + "billing-service/internal/service" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +func main() { + cfg, err := config.Load() + if err != nil { + log.Fatal(err) + } + + db, err := sql.Open("pgx", cfg.DatabaseURL) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + svc := service.New( + cfg, + exporter.NewClient(cfg.ExporterBaseURL), + repository.NewPostgres(db), + ) + svc.Start(ctx) + + server := &http.Server{ + Addr: cfg.ListenAddr, + Handler: httpapi.New(svc).Routes(), + } + + go func() { + <-ctx.Done() + _ = server.Shutdown(context.Background()) + }() + + log.Printf("billing-service listening on %s", cfg.ListenAddr) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatal(err) + } +} diff --git a/docker-compose.postgres.yml b/docker-compose.postgres.yml new file mode 100644 index 0000000..2f47dd7 --- /dev/null +++ b/docker-compose.postgres.yml @@ -0,0 +1,15 @@ +services: + postgres: + image: postgres:16 + container_name: billing-service-test-postgres + environment: + POSTGRES_DB: billing_service_test + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - "55432:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres -d billing_service_test"] + interval: 2s + timeout: 2s + retries: 20 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a023b01 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module billing-service + +go 1.25.1 + +require ( + github.com/google/uuid v1.6.0 + github.com/jackc/pgx/v5 v5.7.6 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/sync v0.13.0 // indirect + golang.org/x/text v0.24.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..084d1d3 --- /dev/null +++ b/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..6954fae --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,84 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" +) + +type Config struct { + ExporterBaseURL string + DatabaseURL string + ListenAddr string + CollectInterval time.Duration + DefaultRegion string + SourceRevision string + PricePerByte float64 + InitialIncludedQuotaBytes int64 + InitialBalance float64 +} + +func Load() (Config, error) { + cfg := Config{ + ExporterBaseURL: strings.TrimRight(strings.TrimSpace(os.Getenv("EXPORTER_BASE_URL")), "/"), + DatabaseURL: strings.TrimSpace(os.Getenv("DATABASE_URL")), + ListenAddr: strings.TrimSpace(os.Getenv("LISTEN_ADDR")), + DefaultRegion: strings.TrimSpace(os.Getenv("DEFAULT_REGION")), + SourceRevision: strings.TrimSpace(os.Getenv("SOURCE_REVISION")), + } + if cfg.ListenAddr == "" { + cfg.ListenAddr = ":8081" + } + if cfg.SourceRevision == "" { + cfg.SourceRevision = "billing-service-v1" + } + + if cfg.ExporterBaseURL == "" { + return Config{}, fmt.Errorf("EXPORTER_BASE_URL is required") + } + if cfg.DatabaseURL == "" { + return Config{}, fmt.Errorf("DATABASE_URL is required") + } + + interval := strings.TrimSpace(os.Getenv("COLLECT_INTERVAL")) + if interval == "" { + cfg.CollectInterval = time.Minute + } else { + parsed, err := time.ParseDuration(interval) + if err != nil { + return Config{}, fmt.Errorf("parse COLLECT_INTERVAL: %w", err) + } + cfg.CollectInterval = parsed + } + + cfg.PricePerByte = parseFloatEnv("PRICE_PER_BYTE", 0) + cfg.InitialBalance = parseFloatEnv("INITIAL_BALANCE", 0) + cfg.InitialIncludedQuotaBytes = parseIntEnv("INITIAL_INCLUDED_QUOTA_BYTES", 0) + return cfg, nil +} + +func parseFloatEnv(key string, fallback float64) float64 { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return fallback + } + parsed, err := strconv.ParseFloat(raw, 64) + if err != nil { + return fallback + } + return parsed +} + +func parseIntEnv(key string, fallback int64) int64 { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return fallback + } + parsed, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return fallback + } + return parsed +} diff --git a/internal/exporter/client.go b/internal/exporter/client.go new file mode 100644 index 0000000..3586db2 --- /dev/null +++ b/internal/exporter/client.go @@ -0,0 +1,54 @@ +package exporter + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "billing-service/internal/model" +) + +type Client struct { + baseURL string + httpClient *http.Client +} + +func NewClient(baseURL string) *Client { + return &Client{ + baseURL: strings.TrimRight(strings.TrimSpace(baseURL), "/"), + httpClient: &http.Client{Timeout: 15 * time.Second}, + } +} + +func (c *Client) FetchLatestSnapshot(ctx context.Context) (model.Snapshot, error) { + endpoint, err := url.JoinPath(c.baseURL, "/v1/snapshots/latest") + if err != nil { + return model.Snapshot{}, fmt.Errorf("build snapshot endpoint: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return model.Snapshot{}, fmt.Errorf("build snapshot request: %w", err) + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return model.Snapshot{}, fmt.Errorf("fetch snapshot: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return model.Snapshot{}, fmt.Errorf("fetch snapshot: unexpected status %s", resp.Status) + } + + var snapshot model.Snapshot + if err := json.NewDecoder(resp.Body).Decode(&snapshot); err != nil { + return model.Snapshot{}, fmt.Errorf("decode snapshot: %w", err) + } + return snapshot, nil +} diff --git a/internal/httpapi/handler.go b/internal/httpapi/handler.go new file mode 100644 index 0000000..4ebda10 --- /dev/null +++ b/internal/httpapi/handler.go @@ -0,0 +1,76 @@ +package httpapi + +import ( + "encoding/json" + "net/http" + + "billing-service/internal/model" + "billing-service/internal/service" +) + +type Handler struct { + service *service.Service +} + +func New(svc *service.Service) *Handler { + return &Handler{service: svc} +} + +func (h *Handler) Routes() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/healthz", h.healthz) + mux.HandleFunc("/v1/status", h.status) + mux.HandleFunc("/v1/jobs/collect-and-rate", h.collectAndRate) + mux.HandleFunc("/v1/jobs/reconcile", h.reconcile) + return mux +} + +func (h *Handler) healthz(w http.ResponseWriter, r *http.Request) { + ok, message := h.service.Health() + status := http.StatusOK + if !ok { + status = http.StatusServiceUnavailable + } + writeJSON(w, status, map[string]any{ + "status": map[bool]string{true: "ok", false: "degraded"}[ok], + "message": message, + }) +} + +func (h *Handler) status(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, h.service.Status()) +} + +func (h *Handler) collectAndRate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + result, err := h.service.RunCollectAndRate(r.Context(), "collect-and-rate") + if err != nil { + writeJSON(w, http.StatusServiceUnavailable, result) + return + } + writeJSON(w, http.StatusOK, result) +} + +func (h *Handler) reconcile(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + result, err := h.service.RunCollectAndRate(r.Context(), "reconcile") + if err != nil { + writeJSON(w, http.StatusServiceUnavailable, result) + return + } + writeJSON(w, http.StatusOK, result) +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} + +var _ = model.JobResult{} diff --git a/internal/model/types.go b/internal/model/types.go new file mode 100644 index 0000000..72dc977 --- /dev/null +++ b/internal/model/types.go @@ -0,0 +1,76 @@ +package model + +import "time" + +type Sample struct { + UUID string `json:"uuid"` + Email string `json:"email"` + InboundTag string `json:"inbound_tag"` + UplinkBytesTotal int64 `json:"uplink_bytes_total"` + DownlinkBytesTotal int64 `json:"downlink_bytes_total"` +} + +type Snapshot struct { + CollectedAt time.Time `json:"collected_at"` + NodeID string `json:"node_id"` + Env string `json:"env"` + Samples []Sample `json:"samples"` +} + +type Checkpoint struct { + NodeID string + AccountUUID string + LastUplinkTotal int64 + LastDownlinkTotal int64 + LastSeenAt time.Time + XrayRevision string + ResetEpoch int64 +} + +type MinuteBucket struct { + BucketStart time.Time + NodeID string + AccountUUID string + Region string + LineCode string + UplinkBytes int64 + DownlinkBytes int64 + TotalBytes int64 + Multiplier float64 + RatingStatus string + SourceRevision string +} + +type LedgerEntry struct { + ID string + AccountUUID string + BucketStart time.Time + BucketEnd time.Time + EntryType string + RatedBytes int64 + AmountDelta float64 + BalanceAfter float64 + PricingRuleVersion string +} + +type QuotaState struct { + AccountUUID string + RemainingIncludedQuota int64 + CurrentBalance float64 + Arrears bool + ThrottleState string + SuspendState string + LastRatedBucketAt *time.Time + EffectiveAt time.Time +} + +type JobResult struct { + Job string `json:"job"` + StartedAt time.Time `json:"started_at"` + FinishedAt time.Time `json:"finished_at"` + ProcessedSamples int `json:"processed_samples"` + WrittenMinutes int `json:"written_minutes"` + ReplayedMinutes int `json:"replayed_minutes"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} diff --git a/internal/repository/postgres.go b/internal/repository/postgres.go new file mode 100644 index 0000000..76a3a4c --- /dev/null +++ b/internal/repository/postgres.go @@ -0,0 +1,235 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "time" + + "billing-service/internal/model" +) + +type Postgres struct { + db *sql.DB +} + +func NewPostgres(db *sql.DB) *Postgres { + return &Postgres{db: db} +} + +func (p *Postgres) GetCheckpoint(ctx context.Context, nodeID, accountUUID string) (*model.Checkpoint, error) { + const query = ` + SELECT node_id, account_uuid, last_uplink_total, last_downlink_total, last_seen_at, xray_revision, reset_epoch + FROM traffic_stat_checkpoints + WHERE node_id = $1 AND account_uuid = $2` + var checkpoint model.Checkpoint + err := p.db.QueryRowContext(ctx, query, nodeID, accountUUID).Scan( + &checkpoint.NodeID, + &checkpoint.AccountUUID, + &checkpoint.LastUplinkTotal, + &checkpoint.LastDownlinkTotal, + &checkpoint.LastSeenAt, + &checkpoint.XrayRevision, + &checkpoint.ResetEpoch, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &checkpoint, nil +} + +func (p *Postgres) UpsertCheckpoint(ctx context.Context, checkpoint model.Checkpoint) error { + const query = ` + INSERT INTO traffic_stat_checkpoints ( + node_id, account_uuid, last_uplink_total, last_downlink_total, last_seen_at, xray_revision, reset_epoch + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (node_id, account_uuid) DO UPDATE SET + last_uplink_total = EXCLUDED.last_uplink_total, + last_downlink_total = EXCLUDED.last_downlink_total, + last_seen_at = EXCLUDED.last_seen_at, + xray_revision = EXCLUDED.xray_revision, + reset_epoch = EXCLUDED.reset_epoch, + updated_at = now()` + _, err := p.db.ExecContext(ctx, query, + checkpoint.NodeID, + checkpoint.AccountUUID, + checkpoint.LastUplinkTotal, + checkpoint.LastDownlinkTotal, + checkpoint.LastSeenAt.UTC(), + checkpoint.XrayRevision, + checkpoint.ResetEpoch, + ) + return err +} + +func (p *Postgres) UpsertMinuteBucket(ctx context.Context, bucket model.MinuteBucket) (bool, error) { + existed, err := p.minuteBucketExists(ctx, bucket) + if err != nil { + return false, err + } + + const query = ` + INSERT INTO traffic_minute_buckets ( + bucket_start, node_id, account_uuid, region, line_code, uplink_bytes, downlink_bytes, total_bytes, multiplier, rating_status, source_revision + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (bucket_start, node_id, account_uuid, region, line_code) DO UPDATE SET + uplink_bytes = EXCLUDED.uplink_bytes, + downlink_bytes = EXCLUDED.downlink_bytes, + total_bytes = EXCLUDED.total_bytes, + multiplier = EXCLUDED.multiplier, + rating_status = EXCLUDED.rating_status, + source_revision = EXCLUDED.source_revision, + updated_at = now()` + _, err = p.db.ExecContext(ctx, query, + bucket.BucketStart.UTC(), + bucket.NodeID, + bucket.AccountUUID, + bucket.Region, + bucket.LineCode, + bucket.UplinkBytes, + bucket.DownlinkBytes, + bucket.TotalBytes, + bucket.Multiplier, + bucket.RatingStatus, + bucket.SourceRevision, + ) + return existed, err +} + +func (p *Postgres) minuteBucketExists(ctx context.Context, bucket model.MinuteBucket) (bool, error) { + const query = ` + SELECT 1 + FROM traffic_minute_buckets + WHERE bucket_start = $1 AND node_id = $2 AND account_uuid = $3 AND region = $4 AND line_code = $5` + var marker int + err := p.db.QueryRowContext(ctx, query, + bucket.BucketStart.UTC(), + bucket.NodeID, + bucket.AccountUUID, + bucket.Region, + bucket.LineCode, + ).Scan(&marker) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (p *Postgres) UpsertLedger(ctx context.Context, entry model.LedgerEntry) (bool, error) { + existed, err := p.ledgerExists(ctx, entry.ID) + if err != nil { + return false, err + } + + const query = ` + INSERT INTO billing_ledger ( + id, account_uuid, bucket_start, bucket_end, entry_type, rated_bytes, amount_delta, balance_after, pricing_rule_version + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (id) DO UPDATE SET + rated_bytes = EXCLUDED.rated_bytes, + amount_delta = EXCLUDED.amount_delta, + balance_after = EXCLUDED.balance_after, + pricing_rule_version = EXCLUDED.pricing_rule_version` + _, err = p.db.ExecContext(ctx, query, + entry.ID, + entry.AccountUUID, + entry.BucketStart.UTC(), + entry.BucketEnd.UTC(), + entry.EntryType, + entry.RatedBytes, + entry.AmountDelta, + entry.BalanceAfter, + entry.PricingRuleVersion, + ) + return existed, err +} + +func (p *Postgres) ledgerExists(ctx context.Context, id string) (bool, error) { + var marker int + err := p.db.QueryRowContext(ctx, `SELECT 1 FROM billing_ledger WHERE id = $1`, id).Scan(&marker) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (p *Postgres) GetQuotaState(ctx context.Context, accountUUID string) (*model.QuotaState, error) { + const query = ` + SELECT account_uuid, remaining_included_quota, current_balance, arrears, throttle_state, suspend_state, last_rated_bucket_at, effective_at + FROM account_quota_states + WHERE account_uuid = $1` + var state model.QuotaState + var lastRated sql.NullTime + err := p.db.QueryRowContext(ctx, query, accountUUID).Scan( + &state.AccountUUID, + &state.RemainingIncludedQuota, + &state.CurrentBalance, + &state.Arrears, + &state.ThrottleState, + &state.SuspendState, + &lastRated, + &state.EffectiveAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + if lastRated.Valid { + value := lastRated.Time + state.LastRatedBucketAt = &value + } + return &state, nil +} + +func (p *Postgres) UpsertQuotaState(ctx context.Context, state model.QuotaState) error { + const query = ` + INSERT INTO account_quota_states ( + account_uuid, remaining_included_quota, current_balance, arrears, throttle_state, suspend_state, last_rated_bucket_at, effective_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (account_uuid) DO UPDATE SET + remaining_included_quota = EXCLUDED.remaining_included_quota, + current_balance = EXCLUDED.current_balance, + arrears = EXCLUDED.arrears, + throttle_state = EXCLUDED.throttle_state, + suspend_state = EXCLUDED.suspend_state, + last_rated_bucket_at = EXCLUDED.last_rated_bucket_at, + effective_at = EXCLUDED.effective_at, + updated_at = now()` + + var lastRated interface{} + if state.LastRatedBucketAt != nil { + lastRated = state.LastRatedBucketAt.UTC() + } + _, err := p.db.ExecContext(ctx, query, + state.AccountUUID, + state.RemainingIncludedQuota, + state.CurrentBalance, + state.Arrears, + state.ThrottleState, + state.SuspendState, + lastRated, + state.EffectiveAt.UTC(), + ) + return err +} + +var _ Repository = (*Postgres)(nil) + +func ensureUTC(ts time.Time) time.Time { + return ts.UTC() +} + +func unexpectedStatus(name string) error { + return fmt.Errorf("unexpected status for %s", name) +} diff --git a/internal/repository/repository.go b/internal/repository/repository.go new file mode 100644 index 0000000..f356f77 --- /dev/null +++ b/internal/repository/repository.go @@ -0,0 +1,16 @@ +package repository + +import ( + "context" + + "billing-service/internal/model" +) + +type Repository interface { + GetCheckpoint(ctx context.Context, nodeID, accountUUID string) (*model.Checkpoint, error) + UpsertCheckpoint(ctx context.Context, checkpoint model.Checkpoint) error + UpsertMinuteBucket(ctx context.Context, bucket model.MinuteBucket) (bool, error) + UpsertLedger(ctx context.Context, entry model.LedgerEntry) (bool, error) + GetQuotaState(ctx context.Context, accountUUID string) (*model.QuotaState, error) + UpsertQuotaState(ctx context.Context, state model.QuotaState) error +} diff --git a/internal/service/postgres_acceptance_test.go b/internal/service/postgres_acceptance_test.go new file mode 100644 index 0000000..0591fa5 --- /dev/null +++ b/internal/service/postgres_acceptance_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "testing" + "time" + + "billing-service/internal/config" + "billing-service/internal/model" + "billing-service/internal/repository" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +func TestPostgresAcceptanceWritesAccountingTables(t *testing.T) { + databaseURL := os.Getenv("TEST_DATABASE_URL") + if databaseURL == "" { + t.Skip("TEST_DATABASE_URL is not set") + } + + db, err := sql.Open("pgx", databaseURL) + if err != nil { + t.Fatalf("open database: %v", err) + } + defer db.Close() + + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + t.Fatalf("ping database: %v", err) + } + + bootstrapPath := filepath.Join("..", "..", "testdata", "postgres", "init.sql") + bootstrapSQL, err := os.ReadFile(bootstrapPath) + if err != nil { + t.Fatalf("read bootstrap sql: %v", err) + } + if _, err := db.ExecContext(ctx, string(bootstrapSQL)); err != nil { + t.Fatalf("apply bootstrap sql: %v", err) + } + + accountUUID := "11111111-1111-1111-1111-111111111111" + if _, err := db.ExecContext(ctx, ` + DELETE FROM billing_ledger; + DELETE FROM traffic_minute_buckets; + DELETE FROM traffic_stat_checkpoints; + DELETE FROM account_quota_states; + DELETE FROM users; + `); err != nil { + t.Fatalf("reset tables: %v", err) + } + if _, err := db.ExecContext(ctx, ` + INSERT INTO users (uuid, username, password, email, proxy_uuid) + VALUES ($1, 'billing-test', 'irrelevant', 'billing@example.com', 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa') + `, accountUUID); err != nil { + t.Fatalf("seed user: %v", err) + } + + svc := New(config.Config{ + DefaultRegion: "", + SourceRevision: "billing-service-acceptance", + PricePerByte: 0.5, + InitialIncludedQuotaBytes: 1000, + InitialBalance: 0, + }, fakeSource{snapshot: model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 11, 0, 45, 0, time.UTC), + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{ + UUID: accountUUID, + Email: "billing@example.com", + InboundTag: "premium", + UplinkBytesTotal: 100, + DownlinkBytesTotal: 50, + }}, + }}, repository.NewPostgres(db)) + + result, err := svc.RunCollectAndRate(ctx, "collect-and-rate") + if err != nil { + t.Fatalf("run collect-and-rate: %v", err) + } + if result.ProcessedSamples != 1 || result.WrittenMinutes != 1 { + t.Fatalf("unexpected result %#v", result) + } + + assertRowCount(t, db, "traffic_stat_checkpoints", 1) + assertRowCount(t, db, "traffic_minute_buckets", 1) + assertRowCount(t, db, "billing_ledger", 1) + assertRowCount(t, db, "account_quota_states", 1) + + var totalBytes int64 + if err := db.QueryRowContext(ctx, `SELECT total_bytes FROM traffic_minute_buckets LIMIT 1`).Scan(&totalBytes); err != nil { + t.Fatalf("query total_bytes: %v", err) + } + if totalBytes != 150 { + t.Fatalf("expected total_bytes 150, got %d", totalBytes) + } +} + +func assertRowCount(t *testing.T, db *sql.DB, table string, want int) { + t.Helper() + + var got int + if err := db.QueryRow(`SELECT count(*) FROM ` + table).Scan(&got); err != nil { + t.Fatalf("count rows in %s: %v", table, err) + } + if got != want { + t.Fatalf("expected %d rows in %s, got %d", want, table, got) + } +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 0000000..ceda198 --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,275 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "billing-service/internal/config" + "billing-service/internal/model" + "billing-service/internal/repository" + + "github.com/google/uuid" +) + +type snapshotSource interface { + FetchLatestSnapshot(ctx context.Context) (model.Snapshot, error) +} + +type Service struct { + cfg config.Config + source snapshotSource + repo repository.Repository + + mu sync.Mutex + lastResult model.JobResult + lastOK bool + lastError string +} + +func New(cfg config.Config, source snapshotSource, repo repository.Repository) *Service { + return &Service{ + cfg: cfg, + source: source, + repo: repo, + } +} + +func (s *Service) Start(ctx context.Context) { + go func() { + _, _ = s.RunCollectAndRate(ctx, "collect-and-rate") + ticker := time.NewTicker(s.cfg.CollectInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _, _ = s.RunCollectAndRate(ctx, "collect-and-rate") + } + } + }() +} + +func (s *Service) RunCollectAndRate(ctx context.Context, job string) (model.JobResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + + startedAt := time.Now().UTC() + result := model.JobResult{ + Job: job, + StartedAt: startedAt, + Status: "ok", + } + + snapshot, err := s.source.FetchLatestSnapshot(ctx) + if err != nil { + result.Status = "error" + result.Error = err.Error() + result.FinishedAt = time.Now().UTC() + s.record(result) + return result, err + } + + for _, sample := range snapshot.Samples { + if err := validateSample(sample); err != nil { + result.Status = "partial" + result.Error = joinError(result.Error, err.Error()) + continue + } + + processed, err := s.processSample(ctx, snapshot, sample, &result) + if err != nil { + result.Status = "partial" + result.Error = joinError(result.Error, err.Error()) + continue + } + if processed { + result.ProcessedSamples++ + } + } + + result.FinishedAt = time.Now().UTC() + s.record(result) + if result.Status == "error" { + return result, errors.New(result.Error) + } + return result, nil +} + +func (s *Service) Status() model.JobResult { + s.mu.Lock() + defer s.mu.Unlock() + return s.lastResult +} + +func (s *Service) Health() (bool, string) { + s.mu.Lock() + defer s.mu.Unlock() + return s.lastOK, s.lastError +} + +func (s *Service) processSample(ctx context.Context, snapshot model.Snapshot, sample model.Sample, result *model.JobResult) (bool, error) { + storageNodeID := composeStorageNodeID(snapshot.Env, snapshot.NodeID) + minuteStart := snapshot.CollectedAt.UTC().Truncate(time.Minute) + + checkpoint, err := s.repo.GetCheckpoint(ctx, storageNodeID, sample.UUID) + if err != nil { + return false, fmt.Errorf("get checkpoint %s: %w", sample.UUID, err) + } + + deltaUplink := sample.UplinkBytesTotal + deltaDownlink := sample.DownlinkBytesTotal + resetEpoch := int64(0) + if checkpoint != nil { + deltaUplink = sample.UplinkBytesTotal - checkpoint.LastUplinkTotal + deltaDownlink = sample.DownlinkBytesTotal - checkpoint.LastDownlinkTotal + resetEpoch = checkpoint.ResetEpoch + } + + if deltaUplink < 0 || deltaDownlink < 0 { + resetEpoch++ + err := s.repo.UpsertCheckpoint(ctx, model.Checkpoint{ + NodeID: storageNodeID, + AccountUUID: sample.UUID, + LastUplinkTotal: sample.UplinkBytesTotal, + LastDownlinkTotal: sample.DownlinkBytesTotal, + LastSeenAt: snapshot.CollectedAt.UTC(), + XrayRevision: s.cfg.SourceRevision, + ResetEpoch: resetEpoch, + }) + if err != nil { + return false, fmt.Errorf("upsert reset checkpoint %s: %w", sample.UUID, err) + } + return false, nil + } + + totalBytes := deltaUplink + deltaDownlink + bucket := model.MinuteBucket{ + BucketStart: minuteStart, + NodeID: storageNodeID, + AccountUUID: sample.UUID, + Region: s.cfg.DefaultRegion, + LineCode: strings.TrimSpace(sample.InboundTag), + UplinkBytes: deltaUplink, + DownlinkBytes: deltaDownlink, + TotalBytes: totalBytes, + Multiplier: 1.0, + RatingStatus: "rated", + SourceRevision: s.cfg.SourceRevision, + } + + minuteExisted, err := s.repo.UpsertMinuteBucket(ctx, bucket) + if err != nil { + return false, fmt.Errorf("upsert minute bucket %s: %w", sample.UUID, err) + } + if minuteExisted { + result.ReplayedMinutes++ + } else { + result.WrittenMinutes++ + } + + amountDelta := -float64(totalBytes) * s.cfg.PricePerByte + entry := model.LedgerEntry{ + ID: deterministicLedgerID(bucket), + AccountUUID: sample.UUID, + BucketStart: minuteStart, + BucketEnd: minuteStart.Add(time.Minute), + EntryType: "traffic_charge", + RatedBytes: totalBytes, + AmountDelta: amountDelta, + PricingRuleVersion: s.cfg.SourceRevision, + } + + quota, err := s.repo.GetQuotaState(ctx, sample.UUID) + if err != nil { + return false, fmt.Errorf("get quota state %s: %w", sample.UUID, err) + } + if quota == nil { + quota = &model.QuotaState{ + AccountUUID: sample.UUID, + RemainingIncludedQuota: s.cfg.InitialIncludedQuotaBytes, + CurrentBalance: s.cfg.InitialBalance, + ThrottleState: "normal", + SuspendState: "active", + EffectiveAt: snapshot.CollectedAt.UTC(), + } + } + entry.BalanceAfter = quota.CurrentBalance + amountDelta + + ledgerExisted, err := s.repo.UpsertLedger(ctx, entry) + if err != nil { + return false, fmt.Errorf("upsert ledger %s: %w", sample.UUID, err) + } + + if !ledgerExisted { + remainingQuota := quota.RemainingIncludedQuota - totalBytes + if remainingQuota < 0 { + remainingQuota = 0 + } + quota.RemainingIncludedQuota = remainingQuota + quota.CurrentBalance = entry.BalanceAfter + quota.EffectiveAt = snapshot.CollectedAt.UTC() + lastRated := minuteStart + quota.LastRatedBucketAt = &lastRated + if err := s.repo.UpsertQuotaState(ctx, *quota); err != nil { + return false, fmt.Errorf("upsert quota state %s: %w", sample.UUID, err) + } + } else { + result.ReplayedMinutes++ + } + + if err := s.repo.UpsertCheckpoint(ctx, model.Checkpoint{ + NodeID: storageNodeID, + AccountUUID: sample.UUID, + LastUplinkTotal: sample.UplinkBytesTotal, + LastDownlinkTotal: sample.DownlinkBytesTotal, + LastSeenAt: snapshot.CollectedAt.UTC(), + XrayRevision: s.cfg.SourceRevision, + ResetEpoch: resetEpoch, + }); err != nil { + return false, fmt.Errorf("upsert checkpoint %s: %w", sample.UUID, err) + } + + return true, nil +} + +func validateSample(sample model.Sample) error { + if strings.TrimSpace(sample.UUID) == "" { + return fmt.Errorf("sample uuid is required") + } + if _, err := uuid.Parse(strings.TrimSpace(sample.UUID)); err != nil { + return fmt.Errorf("sample uuid %q is not a valid UUID", sample.UUID) + } + return nil +} + +func deterministicLedgerID(bucket model.MinuteBucket) string { + key := fmt.Sprintf("%s|%s|%s|%s|%s", bucket.BucketStart.UTC().Format(time.RFC3339), bucket.NodeID, bucket.AccountUUID, bucket.Region, bucket.LineCode) + return uuid.NewSHA1(uuid.NameSpaceOID, []byte(key)).String() +} + +func composeStorageNodeID(env, nodeID string) string { + env = strings.TrimSpace(env) + nodeID = strings.TrimSpace(nodeID) + if env == "" { + return nodeID + } + return env + ":" + nodeID +} + +func joinError(existing, next string) string { + if existing == "" { + return next + } + return existing + "; " + next +} + +func (s *Service) record(result model.JobResult) { + s.lastResult = result + s.lastError = result.Error + s.lastOK = result.Status != "error" +} diff --git a/internal/service/service_test.go b/internal/service/service_test.go new file mode 100644 index 0000000..4445102 --- /dev/null +++ b/internal/service/service_test.go @@ -0,0 +1,302 @@ +package service + +import ( + "context" + "testing" + "time" + + "billing-service/internal/config" + "billing-service/internal/model" + "billing-service/internal/repository" +) + +type fakeSource struct { + snapshot model.Snapshot + err error +} + +func (f fakeSource) FetchLatestSnapshot(context.Context) (model.Snapshot, error) { + return f.snapshot, f.err +} + +type memoryRepo struct { + checkpoints map[string]model.Checkpoint + buckets map[string]model.MinuteBucket + ledgers map[string]model.LedgerEntry + quotas map[string]model.QuotaState +} + +func newMemoryRepo() *memoryRepo { + return &memoryRepo{ + checkpoints: map[string]model.Checkpoint{}, + buckets: map[string]model.MinuteBucket{}, + ledgers: map[string]model.LedgerEntry{}, + quotas: map[string]model.QuotaState{}, + } +} + +func checkpointKey(nodeID, accountUUID string) string { + return nodeID + "\x00" + accountUUID +} + +func bucketKey(bucket model.MinuteBucket) string { + return bucket.BucketStart.UTC().Format(time.RFC3339) + "\x00" + bucket.NodeID + "\x00" + bucket.AccountUUID + "\x00" + bucket.Region + "\x00" + bucket.LineCode +} + +func (m *memoryRepo) GetCheckpoint(ctx context.Context, nodeID, accountUUID string) (*model.Checkpoint, error) { + if checkpoint, ok := m.checkpoints[checkpointKey(nodeID, accountUUID)]; ok { + copy := checkpoint + return ©, nil + } + return nil, nil +} + +func (m *memoryRepo) UpsertCheckpoint(ctx context.Context, checkpoint model.Checkpoint) error { + m.checkpoints[checkpointKey(checkpoint.NodeID, checkpoint.AccountUUID)] = checkpoint + return nil +} + +func (m *memoryRepo) UpsertMinuteBucket(ctx context.Context, bucket model.MinuteBucket) (bool, error) { + key := bucketKey(bucket) + _, existed := m.buckets[key] + m.buckets[key] = bucket + return existed, nil +} + +func (m *memoryRepo) UpsertLedger(ctx context.Context, entry model.LedgerEntry) (bool, error) { + _, existed := m.ledgers[entry.ID] + m.ledgers[entry.ID] = entry + return existed, nil +} + +func (m *memoryRepo) GetQuotaState(ctx context.Context, accountUUID string) (*model.QuotaState, error) { + if quota, ok := m.quotas[accountUUID]; ok { + copy := quota + return ©, nil + } + return nil, nil +} + +func (m *memoryRepo) UpsertQuotaState(ctx context.Context, state model.QuotaState) error { + m.quotas[state.AccountUUID] = state + return nil +} + +var _ repository.Repository = (*memoryRepo)(nil) + +func baseConfig() config.Config { + return config.Config{ + DefaultRegion: "", + SourceRevision: "billing-service-v1", + PricePerByte: 0.5, + InitialIncludedQuotaBytes: 1000, + InitialBalance: 0, + } +} + +func TestDeltaCalculationAndQuotaUpdate(t *testing.T) { + repo := newMemoryRepo() + svc := New(baseConfig(), fakeSource{snapshot: model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 10, 30, 15, 0, time.UTC), + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{ + UUID: "11111111-1111-1111-1111-111111111111", + Email: "user@example.com", + InboundTag: "premium", + UplinkBytesTotal: 100, + DownlinkBytesTotal: 50, + }}, + }}, repo) + + result, err := svc.RunCollectAndRate(context.Background(), "collect-and-rate") + if err != nil { + t.Fatalf("run job: %v", err) + } + if result.ProcessedSamples != 1 || result.WrittenMinutes != 1 { + t.Fatalf("unexpected result %#v", result) + } + quota := repo.quotas["11111111-1111-1111-1111-111111111111"] + if quota.RemainingIncludedQuota != 850 { + t.Fatalf("expected remaining quota 850, got %d", quota.RemainingIncludedQuota) + } + if quota.CurrentBalance != -75 { + t.Fatalf("expected current balance -75, got %v", quota.CurrentBalance) + } +} + +func TestDuplicateMinuteIsReplaySafe(t *testing.T) { + repo := newMemoryRepo() + snapshot := model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 10, 30, 30, 0, time.UTC), + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{ + UUID: "11111111-1111-1111-1111-111111111111", + Email: "user@example.com", + InboundTag: "premium", + UplinkBytesTotal: 100, + DownlinkBytesTotal: 50, + }}, + } + svc := New(baseConfig(), fakeSource{snapshot: snapshot}, repo) + + if _, err := svc.RunCollectAndRate(context.Background(), "collect-and-rate"); err != nil { + t.Fatalf("first run: %v", err) + } + result, err := svc.RunCollectAndRate(context.Background(), "collect-and-rate") + if err != nil { + t.Fatalf("second run: %v", err) + } + if result.ReplayedMinutes == 0 { + t.Fatalf("expected replayed minutes, got %#v", result) + } + if len(repo.ledgers) != 1 { + t.Fatalf("expected 1 ledger entry, got %d", len(repo.ledgers)) + } +} + +func TestNegativeDeltaProtection(t *testing.T) { + repo := newMemoryRepo() + cfg := baseConfig() + accountUUID := "11111111-1111-1111-1111-111111111111" + nodeKey := composeStorageNodeID("prod", "jp-node") + repo.checkpoints[checkpointKey(nodeKey, accountUUID)] = model.Checkpoint{ + NodeID: nodeKey, + AccountUUID: accountUUID, + LastUplinkTotal: 200, + LastDownlinkTotal: 200, + LastSeenAt: time.Now().UTC(), + XrayRevision: "prev", + ResetEpoch: 0, + } + svc := New(cfg, fakeSource{snapshot: model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 10, 31, 0, 0, time.UTC), + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{ + UUID: accountUUID, + InboundTag: "premium", + UplinkBytesTotal: 10, + DownlinkBytesTotal: 20, + }}, + }}, repo) + + result, err := svc.RunCollectAndRate(context.Background(), "collect-and-rate") + if err != nil { + t.Fatalf("run job: %v", err) + } + if result.ProcessedSamples != 0 { + t.Fatalf("expected negative delta sample to be skipped, got %#v", result) + } + if len(repo.buckets) != 0 || len(repo.ledgers) != 0 { + t.Fatalf("expected no writes on negative delta") + } + if repo.checkpoints[checkpointKey(nodeKey, accountUUID)].ResetEpoch != 1 { + t.Fatalf("expected reset epoch increment") + } +} + +func TestRestartRecoveryFromCheckpoint(t *testing.T) { + repo := newMemoryRepo() + accountUUID := "11111111-1111-1111-1111-111111111111" + nodeKey := composeStorageNodeID("prod", "jp-node") + repo.checkpoints[checkpointKey(nodeKey, accountUUID)] = model.Checkpoint{ + NodeID: nodeKey, + AccountUUID: accountUUID, + LastUplinkTotal: 100, + LastDownlinkTotal: 100, + LastSeenAt: time.Now().UTC(), + } + svc := New(baseConfig(), fakeSource{snapshot: model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 10, 32, 0, 0, time.UTC), + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{ + UUID: accountUUID, + InboundTag: "premium", + UplinkBytesTotal: 130, + DownlinkBytesTotal: 140, + }}, + }}, repo) + + result, err := svc.RunCollectAndRate(context.Background(), "collect-and-rate") + if err != nil { + t.Fatalf("run job: %v", err) + } + if result.ProcessedSamples != 1 || result.WrittenMinutes != 1 { + t.Fatalf("unexpected result %#v", result) + } + bucket := repo.buckets[bucketKey(model.MinuteBucket{ + BucketStart: time.Date(2026, 4, 8, 10, 32, 0, 0, time.UTC), + NodeID: nodeKey, + AccountUUID: accountUUID, + Region: "", + LineCode: "premium", + })] + if bucket.TotalBytes != 70 { + t.Fatalf("expected recovered delta 70, got %d", bucket.TotalBytes) + } +} + +func TestMultiEnvIsolation(t *testing.T) { + repo := newMemoryRepo() + accountUUID := "11111111-1111-1111-1111-111111111111" + cfg := baseConfig() + + prodSvc := New(cfg, fakeSource{snapshot: model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 10, 33, 0, 0, time.UTC), + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{UUID: accountUUID, InboundTag: "premium", UplinkBytesTotal: 10, DownlinkBytesTotal: 10}}, + }}, repo) + previewSvc := New(cfg, fakeSource{snapshot: model.Snapshot{ + CollectedAt: time.Date(2026, 4, 8, 10, 33, 0, 0, time.UTC), + NodeID: "jp-node", + Env: "preview", + Samples: []model.Sample{{UUID: accountUUID, InboundTag: "premium", UplinkBytesTotal: 10, DownlinkBytesTotal: 10}}, + }}, repo) + + if _, err := prodSvc.RunCollectAndRate(context.Background(), "collect-and-rate"); err != nil { + t.Fatalf("prod run: %v", err) + } + if _, err := previewSvc.RunCollectAndRate(context.Background(), "collect-and-rate"); err != nil { + t.Fatalf("preview run: %v", err) + } + if len(repo.buckets) != 2 { + t.Fatalf("expected isolated buckets per env, got %d", len(repo.buckets)) + } +} + +func TestLateMinuteReconcileUsesSameMinuteKey(t *testing.T) { + repo := newMemoryRepo() + accountUUID := "11111111-1111-1111-1111-111111111111" + cfg := baseConfig() + collectedAt := time.Date(2026, 4, 8, 10, 34, 50, 0, time.UTC) + snapshot := model.Snapshot{ + CollectedAt: collectedAt, + NodeID: "jp-node", + Env: "prod", + Samples: []model.Sample{{ + UUID: accountUUID, + InboundTag: "premium", + UplinkBytesTotal: 20, + DownlinkBytesTotal: 20, + }}, + } + svc := New(cfg, fakeSource{snapshot: snapshot}, repo) + + if _, err := svc.RunCollectAndRate(context.Background(), "collect-and-rate"); err != nil { + t.Fatalf("first run: %v", err) + } + result, err := svc.RunCollectAndRate(context.Background(), "reconcile") + if err != nil { + t.Fatalf("reconcile run: %v", err) + } + if result.ReplayedMinutes == 0 { + t.Fatalf("expected reconcile to report replayed minute, got %#v", result) + } + if len(repo.buckets) != 1 { + t.Fatalf("expected single logical minute bucket, got %d", len(repo.buckets)) + } +} diff --git a/testdata/postgres/init.sql b/testdata/postgres/init.sql new file mode 100644 index 0000000..81268f0 --- /dev/null +++ b/testdata/postgres/init.sql @@ -0,0 +1,72 @@ +CREATE TABLE IF NOT EXISTS public.users ( + uuid UUID PRIMARY KEY, + username TEXT NOT NULL, + password TEXT NOT NULL, + email TEXT, + role TEXT NOT NULL DEFAULT 'user', + level INTEGER NOT NULL DEFAULT 20, + groups JSONB NOT NULL DEFAULT '[]'::jsonb, + permissions JSONB NOT NULL DEFAULT '[]'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + version BIGINT NOT NULL DEFAULT 0, + origin_node TEXT NOT NULL DEFAULT 'local', + active BOOLEAN NOT NULL DEFAULT TRUE, + proxy_uuid UUID NOT NULL, + CONSTRAINT users_email_optional_ck CHECK (email IS NULL OR length(email) > 0) +); + +CREATE TABLE IF NOT EXISTS public.traffic_stat_checkpoints ( + node_id TEXT NOT NULL, + account_uuid UUID NOT NULL REFERENCES public.users(uuid) ON DELETE CASCADE, + last_uplink_total BIGINT NOT NULL DEFAULT 0, + last_downlink_total BIGINT NOT NULL DEFAULT 0, + last_seen_at TIMESTAMPTZ NOT NULL, + xray_revision TEXT NOT NULL DEFAULT '', + reset_epoch BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (node_id, account_uuid) +); + +CREATE TABLE IF NOT EXISTS public.traffic_minute_buckets ( + bucket_start TIMESTAMPTZ NOT NULL, + node_id TEXT NOT NULL, + account_uuid UUID NOT NULL REFERENCES public.users(uuid) ON DELETE CASCADE, + region TEXT NOT NULL DEFAULT '', + line_code TEXT NOT NULL DEFAULT '', + uplink_bytes BIGINT NOT NULL DEFAULT 0, + downlink_bytes BIGINT NOT NULL DEFAULT 0, + total_bytes BIGINT NOT NULL DEFAULT 0, + multiplier DOUBLE PRECISION NOT NULL DEFAULT 1.0, + rating_status TEXT NOT NULL DEFAULT 'pending', + source_revision TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (bucket_start, node_id, account_uuid, region, line_code) +); + +CREATE TABLE IF NOT EXISTS public.billing_ledger ( + id UUID PRIMARY KEY, + account_uuid UUID NOT NULL REFERENCES public.users(uuid) ON DELETE CASCADE, + bucket_start TIMESTAMPTZ NOT NULL, + bucket_end TIMESTAMPTZ NOT NULL, + entry_type TEXT NOT NULL, + rated_bytes BIGINT NOT NULL DEFAULT 0, + amount_delta DOUBLE PRECISION NOT NULL DEFAULT 0, + balance_after DOUBLE PRECISION NOT NULL DEFAULT 0, + pricing_rule_version TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE IF NOT EXISTS public.account_quota_states ( + account_uuid UUID PRIMARY KEY REFERENCES public.users(uuid) ON DELETE CASCADE, + remaining_included_quota BIGINT NOT NULL DEFAULT 0, + current_balance DOUBLE PRECISION NOT NULL DEFAULT 0, + arrears BOOLEAN NOT NULL DEFAULT false, + throttle_state TEXT NOT NULL DEFAULT 'normal', + suspend_state TEXT NOT NULL DEFAULT 'active', + last_rated_bucket_at TIMESTAMPTZ NULL, + effective_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +);