Skip to content

Commit

Permalink
Track unsettled usage
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Feb 25, 2025
1 parent b21cb0d commit 4d5834b
Show file tree
Hide file tree
Showing 14 changed files with 457 additions and 4 deletions.
14 changes: 12 additions & 2 deletions pkg/api/message/publishWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/fees"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/utils"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -147,17 +148,26 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator
return false
}

originatorID := int32(p.registrant.NodeID())

// On unique constraint conflicts, no error is thrown, but numRows is 0
inserted, err := q.InsertGatewayEnvelope(
inserted, err := db.InsertGatewayEnvelopeAndIncrementUnsettledUsage(
p.ctx,
p.store,
queries.InsertGatewayEnvelopeParams{
OriginatorNodeID: int32(p.registrant.NodeID()),
OriginatorNodeID: originatorID,
OriginatorSequenceID: stagedEnv.ID,
Topic: stagedEnv.Topic,
OriginatorEnvelope: originatorBytes,
PayerID: db.NullInt32(payerId),
GatewayTime: stagedEnv.OriginatorTime,
},
queries.IncrementUnsettledUsageParams{
PayerID: payerId,
OriginatorID: originatorID,
MinutesSinceEpoch: utils.MinutesSinceEpoch(stagedEnv.OriginatorTime),
SpendPicodollars: int64(baseFee) + int64(congestionFee),
},
)
if p.ctx.Err() != nil {
return false
Expand Down
41 changes: 41 additions & 0 deletions pkg/db/gatewayEnvelope.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package db

import (
"context"
"database/sql"

"github.com/xmtp/xmtpd/pkg/db/queries"
)

// InsertGatewayEnvelopeAndIncrementUnsettledUsage inserts a gateway envelope and increments the unsettled usage for the payer.
// It returns the number of rows inserted.
func InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx context.Context,
db *sql.DB,
insertParams queries.InsertGatewayEnvelopeParams,
incrementParams queries.IncrementUnsettledUsageParams,
) (int64, error) {
return RunInTxWithResult(
ctx,
db,
&sql.TxOptions{},
func(ctx context.Context, txQueries *queries.Queries) (int64, error) {
numInserted, err := txQueries.InsertGatewayEnvelope(ctx, insertParams)
if err != nil {
return 0, err
}
// If the numInserted is 0 it means the envelope already exists
// and we don't need to increment the unsettled usage
if numInserted == 0 {
return 0, nil
}

err = txQueries.IncrementUnsettledUsage(ctx, incrementParams)
if err != nil {
return 0, err
}

return numInserted, nil
},
)
}
133 changes: 133 additions & 0 deletions pkg/db/gatewayEnvelope_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package db

import (
"context"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/testutils"
)

func buildParams(
payerID int32,
originatorID int32,
sequenceID int64,
spendPicodollars int64,
) (queries.InsertGatewayEnvelopeParams, queries.IncrementUnsettledUsageParams) {
insertParams := queries.InsertGatewayEnvelopeParams{
OriginatorNodeID: originatorID,
OriginatorSequenceID: sequenceID,
Topic: testutils.RandomBytes(32),
OriginatorEnvelope: testutils.RandomBytes(100),
PayerID: NullInt32(payerID),
}

incrementParams := queries.IncrementUnsettledUsageParams{
PayerID: payerID,
OriginatorID: originatorID,
MinutesSinceEpoch: 1,
SpendPicodollars: 100,
}

return insertParams, incrementParams
}

func TestInsertAndIncrement(t *testing.T) {
ctx := context.Background()
db, _, cleanup := testutils.NewDB(t, ctx)
defer cleanup()

querier := queries.New(db)
// Create a payer
payerID := testutils.CreatePayer(t, db, testutils.RandomAddress().Hex())
originatorID := testutils.RandomInt32()
sequenceID := int64(10)

insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100)

numInserted, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx,
db,
insertParams,
incrementParams,
)
require.NoError(t, err)
require.Equal(t, numInserted, int64(1))

payerSpend, err := querier.GetPayerUnsettledUsage(
ctx,
queries.GetPayerUnsettledUsageParams{PayerID: payerID},
)
require.NoError(t, err)
require.Equal(t, payerSpend, int64(100))
}

func TestPayerMustExist(t *testing.T) {
ctx := context.Background()
db, _, cleanup := testutils.NewDB(t, ctx)
defer cleanup()

payerID := testutils.RandomInt32()
originatorID := testutils.RandomInt32()
sequenceID := int64(10)

insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100)

_, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx,
db,
insertParams,
incrementParams,
)
require.Error(t, err)
}

func TestInsertAndIncrementParallel(t *testing.T) {
ctx := context.Background()
db, _, cleanup := testutils.NewDB(t, ctx)
defer cleanup()

querier := queries.New(db)
// Create a payer
payerID := testutils.CreatePayer(t, db, testutils.RandomAddress().Hex())
originatorID := testutils.RandomInt32()
sequenceID := int64(10)
numberOfInserts := 20

insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100)

var wg sync.WaitGroup

totalInserted := int64(0)

attemptInsert := func() {
defer wg.Done()
numInserted, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx,
db,
insertParams,
incrementParams,
)
require.NoError(t, err)
atomic.AddInt64(&totalInserted, numInserted)
}

for range numberOfInserts {
wg.Add(1)
go attemptInsert()
}

wg.Wait()

require.Equal(t, totalInserted, int64(1))

payerSpend, err := querier.GetPayerUnsettledUsage(
ctx,
queries.GetPayerUnsettledUsageParams{PayerID: payerID},
)
require.NoError(t, err)
require.Equal(t, payerSpend, int64(100))
}
19 changes: 19 additions & 0 deletions pkg/db/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,22 @@ ON CONFLICT (address)
RETURNING
id;

-- name: IncrementUnsettledUsage :exec
INSERT INTO unsettled_usage(payer_id, originator_id, minutes_since_epoch, spend_picodollars)
VALUES (@payer_id, @originator_id, @minutes_since_epoch, @spend_picodollars)
ON CONFLICT (payer_id, originator_id, minutes_since_epoch)
DO UPDATE SET
spend_picodollars = unsettled_usage.spend_picodollars + @spend_picodollars;

-- name: GetPayerUnsettledUsage :one
SELECT
SUM(spend_picodollars) AS total_spend_picodollars
FROM
unsettled_usage
WHERE
payer_id = @payer_id
AND (@minutes_since_epoch_gt::BIGINT = 0
OR minutes_since_epoch > @minutes_since_epoch_gt::BIGINT)
AND (@minutes_since_epoch_lt::BIGINT = 0
OR minutes_since_epoch < @minutes_since_epoch_lt::BIGINT);

7 changes: 7 additions & 0 deletions pkg/db/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions pkg/db/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions pkg/db/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,39 @@ func RunInTx(
done = true
return tx.Commit()
}

func RunInTxWithResult[T any](
ctx context.Context,
db *sql.DB,
opts *sql.TxOptions,
fn func(ctx context.Context, txQueries *queries.Queries) (T, error),
) (T, error) {
querier := queries.New(db)
tx, err := db.BeginTx(ctx, opts)
if err != nil {
var zero T
return zero, err
}

var done bool

defer func() {
if !done {
_ = tx.Rollback()
}
}()

result, err := fn(ctx, querier.WithTx(tx))
if err != nil {
var zero T
return zero, err
}

done = true
if err := tx.Commit(); err != nil {
var zero T
return zero, err
}

return result, nil
}
Loading

0 comments on commit 4d5834b

Please sign in to comment.