From 4d5834b59608aae3b147d5a975bce9933af28d14 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Mon, 24 Feb 2025 19:41:33 -0500 Subject: [PATCH] Track unsettled usage --- pkg/api/message/publishWorker.go | 14 +- pkg/db/gatewayEnvelope.go | 41 ++++++ pkg/db/gatewayEnvelope_test.go | 133 ++++++++++++++++++ pkg/db/queries.sql | 19 +++ pkg/db/queries/models.go | 7 + pkg/db/queries/queries.sql.go | 51 +++++++ pkg/db/tx.go | 36 +++++ pkg/db/unsettledUsage_test.go | 100 +++++++++++++ pkg/envelopes/originator.go | 5 + pkg/migrations/00007_unsettled-usage.down.sql | 4 + pkg/migrations/00007_unsettled-usage.up.sql | 9 ++ pkg/sync/syncWorker.go | 21 ++- pkg/testutils/random.go | 4 + pkg/utils/time.go | 17 +++ 14 files changed, 457 insertions(+), 4 deletions(-) create mode 100644 pkg/db/gatewayEnvelope.go create mode 100644 pkg/db/gatewayEnvelope_test.go create mode 100644 pkg/db/unsettledUsage_test.go create mode 100644 pkg/migrations/00007_unsettled-usage.down.sql create mode 100644 pkg/migrations/00007_unsettled-usage.up.sql create mode 100644 pkg/utils/time.go diff --git a/pkg/api/message/publishWorker.go b/pkg/api/message/publishWorker.go index 0f2d5356..81dfa25e 100644 --- a/pkg/api/message/publishWorker.go +++ b/pkg/api/message/publishWorker.go @@ -13,6 +13,7 @@ import ( "github.com/xmtp/xmtpd/pkg/envelopes" "github.com/xmtp/xmtpd/pkg/fees" "github.com/xmtp/xmtpd/pkg/registrant" + "github.com/xmtp/xmtpd/pkg/utils" "go.uber.org/zap" ) @@ -147,17 +148,26 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator return false } + originatorID := int32(p.registrant.NodeID()) + // On unique constraint conflicts, no error is thrown, but numRows is 0 - inserted, err := q.InsertGatewayEnvelope( + inserted, err := db.InsertGatewayEnvelopeAndIncrementUnsettledUsage( p.ctx, + p.store, queries.InsertGatewayEnvelopeParams{ - OriginatorNodeID: int32(p.registrant.NodeID()), + OriginatorNodeID: originatorID, OriginatorSequenceID: stagedEnv.ID, Topic: stagedEnv.Topic, OriginatorEnvelope: originatorBytes, PayerID: db.NullInt32(payerId), GatewayTime: stagedEnv.OriginatorTime, }, + queries.IncrementUnsettledUsageParams{ + PayerID: payerId, + OriginatorID: originatorID, + MinutesSinceEpoch: utils.MinutesSinceEpoch(stagedEnv.OriginatorTime), + SpendPicodollars: int64(baseFee) + int64(congestionFee), + }, ) if p.ctx.Err() != nil { return false diff --git a/pkg/db/gatewayEnvelope.go b/pkg/db/gatewayEnvelope.go new file mode 100644 index 00000000..677c0f0e --- /dev/null +++ b/pkg/db/gatewayEnvelope.go @@ -0,0 +1,41 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/xmtp/xmtpd/pkg/db/queries" +) + +// InsertGatewayEnvelopeAndIncrementUnsettledUsage inserts a gateway envelope and increments the unsettled usage for the payer. +// It returns the number of rows inserted. +func InsertGatewayEnvelopeAndIncrementUnsettledUsage( + ctx context.Context, + db *sql.DB, + insertParams queries.InsertGatewayEnvelopeParams, + incrementParams queries.IncrementUnsettledUsageParams, +) (int64, error) { + return RunInTxWithResult( + ctx, + db, + &sql.TxOptions{}, + func(ctx context.Context, txQueries *queries.Queries) (int64, error) { + numInserted, err := txQueries.InsertGatewayEnvelope(ctx, insertParams) + if err != nil { + return 0, err + } + // If the numInserted is 0 it means the envelope already exists + // and we don't need to increment the unsettled usage + if numInserted == 0 { + return 0, nil + } + + err = txQueries.IncrementUnsettledUsage(ctx, incrementParams) + if err != nil { + return 0, err + } + + return numInserted, nil + }, + ) +} diff --git a/pkg/db/gatewayEnvelope_test.go b/pkg/db/gatewayEnvelope_test.go new file mode 100644 index 00000000..653c8728 --- /dev/null +++ b/pkg/db/gatewayEnvelope_test.go @@ -0,0 +1,133 @@ +package db + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/testutils" +) + +func buildParams( + payerID int32, + originatorID int32, + sequenceID int64, + spendPicodollars int64, +) (queries.InsertGatewayEnvelopeParams, queries.IncrementUnsettledUsageParams) { + insertParams := queries.InsertGatewayEnvelopeParams{ + OriginatorNodeID: originatorID, + OriginatorSequenceID: sequenceID, + Topic: testutils.RandomBytes(32), + OriginatorEnvelope: testutils.RandomBytes(100), + PayerID: NullInt32(payerID), + } + + incrementParams := queries.IncrementUnsettledUsageParams{ + PayerID: payerID, + OriginatorID: originatorID, + MinutesSinceEpoch: 1, + SpendPicodollars: 100, + } + + return insertParams, incrementParams +} + +func TestInsertAndIncrement(t *testing.T) { + ctx := context.Background() + db, _, cleanup := testutils.NewDB(t, ctx) + defer cleanup() + + querier := queries.New(db) + // Create a payer + payerID := testutils.CreatePayer(t, db, testutils.RandomAddress().Hex()) + originatorID := testutils.RandomInt32() + sequenceID := int64(10) + + insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100) + + numInserted, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage( + ctx, + db, + insertParams, + incrementParams, + ) + require.NoError(t, err) + require.Equal(t, numInserted, int64(1)) + + payerSpend, err := querier.GetPayerUnsettledUsage( + ctx, + queries.GetPayerUnsettledUsageParams{PayerID: payerID}, + ) + require.NoError(t, err) + require.Equal(t, payerSpend, int64(100)) +} + +func TestPayerMustExist(t *testing.T) { + ctx := context.Background() + db, _, cleanup := testutils.NewDB(t, ctx) + defer cleanup() + + payerID := testutils.RandomInt32() + originatorID := testutils.RandomInt32() + sequenceID := int64(10) + + insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100) + + _, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage( + ctx, + db, + insertParams, + incrementParams, + ) + require.Error(t, err) +} + +func TestInsertAndIncrementParallel(t *testing.T) { + ctx := context.Background() + db, _, cleanup := testutils.NewDB(t, ctx) + defer cleanup() + + querier := queries.New(db) + // Create a payer + payerID := testutils.CreatePayer(t, db, testutils.RandomAddress().Hex()) + originatorID := testutils.RandomInt32() + sequenceID := int64(10) + numberOfInserts := 20 + + insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100) + + var wg sync.WaitGroup + + totalInserted := int64(0) + + attemptInsert := func() { + defer wg.Done() + numInserted, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage( + ctx, + db, + insertParams, + incrementParams, + ) + require.NoError(t, err) + atomic.AddInt64(&totalInserted, numInserted) + } + + for range numberOfInserts { + wg.Add(1) + go attemptInsert() + } + + wg.Wait() + + require.Equal(t, totalInserted, int64(1)) + + payerSpend, err := querier.GetPayerUnsettledUsage( + ctx, + queries.GetPayerUnsettledUsageParams{PayerID: payerID}, + ) + require.NoError(t, err) + require.Equal(t, payerSpend, int64(100)) +} diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index 030aacd5..f7880640 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -180,3 +180,22 @@ ON CONFLICT (address) RETURNING id; +-- name: IncrementUnsettledUsage :exec +INSERT INTO unsettled_usage(payer_id, originator_id, minutes_since_epoch, spend_picodollars) + VALUES (@payer_id, @originator_id, @minutes_since_epoch, @spend_picodollars) +ON CONFLICT (payer_id, originator_id, minutes_since_epoch) + DO UPDATE SET + spend_picodollars = unsettled_usage.spend_picodollars + @spend_picodollars; + +-- name: GetPayerUnsettledUsage :one +SELECT + SUM(spend_picodollars) AS total_spend_picodollars +FROM + unsettled_usage +WHERE + payer_id = @payer_id + AND (@minutes_since_epoch_gt::BIGINT = 0 + OR minutes_since_epoch > @minutes_since_epoch_gt::BIGINT) + AND (@minutes_since_epoch_lt::BIGINT = 0 + OR minutes_since_epoch < @minutes_since_epoch_lt::BIGINT); + diff --git a/pkg/db/queries/models.go b/pkg/db/queries/models.go index bfb2c9fc..22ff23e3 100644 --- a/pkg/db/queries/models.go +++ b/pkg/db/queries/models.go @@ -56,3 +56,10 @@ type StagedOriginatorEnvelope struct { Topic []byte PayerEnvelope []byte } + +type UnsettledUsage struct { + PayerID int32 + OriginatorID int32 + MinutesSinceEpoch int32 + SpendPicodollars int64 +} diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 6554d65d..de1a44a5 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -220,6 +220,57 @@ func (q *Queries) GetLatestSequenceId(ctx context.Context, originatorNodeID int3 return originator_sequence_id, err } +const getPayerUnsettledUsage = `-- name: GetPayerUnsettledUsage :one +SELECT + SUM(spend_picodollars) AS total_spend_picodollars +FROM + unsettled_usage +WHERE + payer_id = $1 + AND ($2::BIGINT = 0 + OR minutes_since_epoch > $2::BIGINT) + AND ($3::BIGINT = 0 + OR minutes_since_epoch < $3::BIGINT) +` + +type GetPayerUnsettledUsageParams struct { + PayerID int32 + MinutesSinceEpochGt int64 + MinutesSinceEpochLt int64 +} + +func (q *Queries) GetPayerUnsettledUsage(ctx context.Context, arg GetPayerUnsettledUsageParams) (int64, error) { + row := q.db.QueryRowContext(ctx, getPayerUnsettledUsage, arg.PayerID, arg.MinutesSinceEpochGt, arg.MinutesSinceEpochLt) + var total_spend_picodollars int64 + err := row.Scan(&total_spend_picodollars) + return total_spend_picodollars, err +} + +const incrementUnsettledUsage = `-- name: IncrementUnsettledUsage :exec +INSERT INTO unsettled_usage(payer_id, originator_id, minutes_since_epoch, spend_picodollars) + VALUES ($1, $2, $3, $4) +ON CONFLICT (payer_id, originator_id, minutes_since_epoch) + DO UPDATE SET + spend_picodollars = unsettled_usage.spend_picodollars + $4 +` + +type IncrementUnsettledUsageParams struct { + PayerID int32 + OriginatorID int32 + MinutesSinceEpoch int32 + SpendPicodollars int64 +} + +func (q *Queries) IncrementUnsettledUsage(ctx context.Context, arg IncrementUnsettledUsageParams) error { + _, err := q.db.ExecContext(ctx, incrementUnsettledUsage, + arg.PayerID, + arg.OriginatorID, + arg.MinutesSinceEpoch, + arg.SpendPicodollars, + ) + return err +} + const insertAddressLog = `-- name: InsertAddressLog :execrows INSERT INTO address_log(address, inbox_id, association_sequence_id, revocation_sequence_id) VALUES ($1, decode($2, 'hex'), $3, NULL) diff --git a/pkg/db/tx.go b/pkg/db/tx.go index 830b6784..12029d83 100644 --- a/pkg/db/tx.go +++ b/pkg/db/tx.go @@ -34,3 +34,39 @@ func RunInTx( done = true return tx.Commit() } + +func RunInTxWithResult[T any]( + ctx context.Context, + db *sql.DB, + opts *sql.TxOptions, + fn func(ctx context.Context, txQueries *queries.Queries) (T, error), +) (T, error) { + querier := queries.New(db) + tx, err := db.BeginTx(ctx, opts) + if err != nil { + var zero T + return zero, err + } + + var done bool + + defer func() { + if !done { + _ = tx.Rollback() + } + }() + + result, err := fn(ctx, querier.WithTx(tx)) + if err != nil { + var zero T + return zero, err + } + + done = true + if err := tx.Commit(); err != nil { + var zero T + return zero, err + } + + return result, nil +} diff --git a/pkg/db/unsettledUsage_test.go b/pkg/db/unsettledUsage_test.go new file mode 100644 index 00000000..7a29f636 --- /dev/null +++ b/pkg/db/unsettledUsage_test.go @@ -0,0 +1,100 @@ +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/testutils" + "github.com/xmtp/xmtpd/pkg/utils" +) + +func TestIncrementUnsettledUsage(t *testing.T) { + ctx := context.Background() + db, _, cleanup := testutils.NewDB(t, ctx) + defer cleanup() + + querier := queries.New(db) + payerId := testutils.RandomInt32() + originatorId := testutils.RandomInt32() + minutesSinceEpoch := utils.MinutesSinceEpochNow() + + require.NoError(t, querier.IncrementUnsettledUsage(ctx, queries.IncrementUnsettledUsageParams{ + PayerID: payerId, + OriginatorID: originatorId, + MinutesSinceEpoch: minutesSinceEpoch, + SpendPicodollars: 100, + })) + + unsettledUsage, err := querier.GetPayerUnsettledUsage( + ctx, + queries.GetPayerUnsettledUsageParams{ + PayerID: payerId, + }, + ) + require.NoError(t, err) + require.Equal(t, unsettledUsage, int64(100)) + + require.NoError(t, querier.IncrementUnsettledUsage(ctx, queries.IncrementUnsettledUsageParams{ + PayerID: payerId, + OriginatorID: originatorId, + MinutesSinceEpoch: minutesSinceEpoch, + SpendPicodollars: 100, + })) + + unsettledUsage, err = querier.GetPayerUnsettledUsage( + ctx, + queries.GetPayerUnsettledUsageParams{ + PayerID: payerId, + }, + ) + require.NoError(t, err) + require.Equal(t, unsettledUsage, int64(200)) +} + +func TestGetUnsettledUsage(t *testing.T) { + ctx := context.Background() + db, _, cleanup := testutils.NewDB(t, ctx) + defer cleanup() + + querier := queries.New(db) + payerId := testutils.RandomInt32() + originatorId := testutils.RandomInt32() + + addUsage := func(minutesSinceEpoch int32, spendPicodollars int64) { + require.NoError( + t, + querier.IncrementUnsettledUsage(ctx, queries.IncrementUnsettledUsageParams{ + PayerID: payerId, + OriginatorID: originatorId, + MinutesSinceEpoch: minutesSinceEpoch, + SpendPicodollars: spendPicodollars, + }), + ) + } + + addUsage(1, 100) + addUsage(2, 200) + addUsage(3, 300) + + unsettledUsage, err := querier.GetPayerUnsettledUsage( + ctx, + queries.GetPayerUnsettledUsageParams{ + PayerID: payerId, + MinutesSinceEpochGt: 2, + }, + ) + require.NoError(t, err) + require.Equal(t, unsettledUsage, int64(300)) + + unsettledUsage, err = querier.GetPayerUnsettledUsage( + ctx, + queries.GetPayerUnsettledUsageParams{ + PayerID: payerId, + MinutesSinceEpochGt: 1, + }, + ) + require.NoError(t, err) + require.Equal(t, unsettledUsage, int64(500)) +} diff --git a/pkg/envelopes/originator.go b/pkg/envelopes/originator.go index b86b9f94..db9c80e5 100644 --- a/pkg/envelopes/originator.go +++ b/pkg/envelopes/originator.go @@ -2,6 +2,7 @@ package envelopes import ( "errors" + "time" envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/topic" @@ -62,6 +63,10 @@ func (o *OriginatorEnvelope) OriginatorNs() int64 { return o.UnsignedOriginatorEnvelope.OriginatorNs() } +func (o *OriginatorEnvelope) OriginatorTime() time.Time { + return utils.NsToDate(o.OriginatorNs()) +} + func (o *OriginatorEnvelope) TargetTopic() topic.Topic { return o.UnsignedOriginatorEnvelope.TargetTopic() } diff --git a/pkg/migrations/00007_unsettled-usage.down.sql b/pkg/migrations/00007_unsettled-usage.down.sql new file mode 100644 index 00000000..d9815448 --- /dev/null +++ b/pkg/migrations/00007_unsettled-usage.down.sql @@ -0,0 +1,4 @@ +DROP TABLE IF EXISTS unsettled_usage; + +DROP INDEX IF EXISTS idx_unsettled_usage_payer_id; + diff --git a/pkg/migrations/00007_unsettled-usage.up.sql b/pkg/migrations/00007_unsettled-usage.up.sql new file mode 100644 index 00000000..b30c682f --- /dev/null +++ b/pkg/migrations/00007_unsettled-usage.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE unsettled_usage( + payer_id INTEGER NOT NULL, + originator_id INTEGER NOT NULL, + minutes_since_epoch INTEGER NOT NULL, + spend_picodollars BIGINT NOT NULL, + PRIMARY KEY (payer_id, originator_id, minutes_since_epoch) +); + +CREATE INDEX idx_unsettled_usage_payer_id ON unsettled_usage(payer_id); \ No newline at end of file diff --git a/pkg/sync/syncWorker.go b/pkg/sync/syncWorker.go index 628a704e..5b157936 100644 --- a/pkg/sync/syncWorker.go +++ b/pkg/sync/syncWorker.go @@ -11,12 +11,14 @@ import ( "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" envUtils "github.com/xmtp/xmtpd/pkg/envelopes" + "github.com/xmtp/xmtpd/pkg/fees" clientInterceptors "github.com/xmtp/xmtpd/pkg/interceptors/client" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" "github.com/xmtp/xmtpd/pkg/registrant" "github.com/xmtp/xmtpd/pkg/registry" "github.com/xmtp/xmtpd/pkg/tracing" + "github.com/xmtp/xmtpd/pkg/utils" "go.uber.org/zap" "google.golang.org/grpc" ) @@ -31,6 +33,7 @@ type syncWorker struct { subscriptions map[uint32]struct{} subscriptionsMutex sync.RWMutex cancel context.CancelFunc + feesCalculator fees.IFeeCalculator } type originatorStream struct { @@ -441,9 +444,12 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { return } - q := queries.New(s.store) - inserted, err := q.InsertGatewayEnvelope( + originatorID := int32(env.OriginatorNodeID()) + originatorTime := utils.NsToDate(env.OriginatorNs()) + + inserted, err := db.InsertGatewayEnvelopeAndIncrementUnsettledUsage( s.ctx, + s.store, queries.InsertGatewayEnvelopeParams{ OriginatorNodeID: int32(env.OriginatorNodeID()), OriginatorSequenceID: int64(env.OriginatorSequenceID()), @@ -451,6 +457,17 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { OriginatorEnvelope: originatorBytes, PayerID: db.NullInt32(payerId), }, + queries.IncrementUnsettledUsageParams{ + PayerID: payerId, + OriginatorID: originatorID, + MinutesSinceEpoch: utils.MinutesSinceEpoch(originatorTime), + // TODO:(nm) Independently calculate fees + SpendPicodollars: int64( + env.UnsignedOriginatorEnvelope.BaseFee(), + ) + int64( + env.UnsignedOriginatorEnvelope.CongestionFee(), + ), + }, ) if err != nil { s.log.Error("Failed to insert gateway envelope", zap.Error(err)) diff --git a/pkg/testutils/random.go b/pkg/testutils/random.go index e12dbbb8..cd48746c 100644 --- a/pkg/testutils/random.go +++ b/pkg/testutils/random.go @@ -67,3 +67,7 @@ func RandomBlockHash() common.Hash { bytes := RandomBytes(32) return common.BytesToHash(bytes) } + +func RandomInt32() int32 { + return rand.Int31() +} diff --git a/pkg/utils/time.go b/pkg/utils/time.go new file mode 100644 index 00000000..f760acb0 --- /dev/null +++ b/pkg/utils/time.go @@ -0,0 +1,17 @@ +package utils + +import "time" + +func MinutesSinceEpoch(timestamp time.Time) int32 { + durationSinceEpoch := timestamp.Sub(time.Unix(0, 0)) + + return int32(durationSinceEpoch.Minutes()) +} + +func MinutesSinceEpochNow() int32 { + return MinutesSinceEpoch(time.Now()) +} + +func NsToDate(ns int64) time.Time { + return time.Unix(0, ns) +}