Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track unsettled usage #549

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions pkg/api/message/publishWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/fees"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/utils"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -147,17 +148,26 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator
return false
}

originatorID := int32(p.registrant.NodeID())

// On unique constraint conflicts, no error is thrown, but numRows is 0
inserted, err := q.InsertGatewayEnvelope(
inserted, err := db.InsertGatewayEnvelopeAndIncrementUnsettledUsage(
p.ctx,
p.store,
queries.InsertGatewayEnvelopeParams{
OriginatorNodeID: int32(p.registrant.NodeID()),
OriginatorNodeID: originatorID,
OriginatorSequenceID: stagedEnv.ID,
Topic: stagedEnv.Topic,
OriginatorEnvelope: originatorBytes,
PayerID: db.NullInt32(payerId),
GatewayTime: stagedEnv.OriginatorTime,
},
queries.IncrementUnsettledUsageParams{
PayerID: payerId,
OriginatorID: originatorID,
MinutesSinceEpoch: utils.MinutesSinceEpoch(stagedEnv.OriginatorTime),
SpendPicodollars: int64(baseFee) + int64(congestionFee),
},
)
if p.ctx.Err() != nil {
return false
Expand Down
41 changes: 41 additions & 0 deletions pkg/db/gatewayEnvelope.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package db

import (
"context"
"database/sql"

"github.com/xmtp/xmtpd/pkg/db/queries"
)

// InsertGatewayEnvelopeAndIncrementUnsettledUsage inserts a gateway envelope and increments the unsettled usage for the payer.
// It returns the number of rows inserted.
func InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx context.Context,
db *sql.DB,
insertParams queries.InsertGatewayEnvelopeParams,
incrementParams queries.IncrementUnsettledUsageParams,
) (int64, error) {
return RunInTxWithResult(
ctx,
db,
&sql.TxOptions{},
func(ctx context.Context, txQueries *queries.Queries) (int64, error) {
numInserted, err := txQueries.InsertGatewayEnvelope(ctx, insertParams)
if err != nil {
return 0, err
}
// If the numInserted is 0 it means the envelope already exists
// and we don't need to increment the unsettled usage
if numInserted == 0 {
return 0, nil
}

err = txQueries.IncrementUnsettledUsage(ctx, incrementParams)
if err != nil {
return 0, err
}

return numInserted, nil
},
)
}
133 changes: 133 additions & 0 deletions pkg/db/gatewayEnvelope_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package db

import (
"context"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/testutils"
)

func buildParams(
payerID int32,
originatorID int32,
sequenceID int64,
spendPicodollars int64,
) (queries.InsertGatewayEnvelopeParams, queries.IncrementUnsettledUsageParams) {
insertParams := queries.InsertGatewayEnvelopeParams{
OriginatorNodeID: originatorID,
OriginatorSequenceID: sequenceID,
Topic: testutils.RandomBytes(32),
OriginatorEnvelope: testutils.RandomBytes(100),
PayerID: NullInt32(payerID),
}

incrementParams := queries.IncrementUnsettledUsageParams{
PayerID: payerID,
OriginatorID: originatorID,
MinutesSinceEpoch: 1,
SpendPicodollars: spendPicodollars,
}

return insertParams, incrementParams
}

func TestInsertAndIncrement(t *testing.T) {
ctx := context.Background()
db, _, cleanup := testutils.NewDB(t, ctx)
defer cleanup()

querier := queries.New(db)
// Create a payer
payerID := testutils.CreatePayer(t, db, testutils.RandomAddress().Hex())
originatorID := testutils.RandomInt32()
sequenceID := int64(10)

insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100)

numInserted, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx,
db,
insertParams,
incrementParams,
)
require.NoError(t, err)
require.Equal(t, numInserted, int64(1))

payerSpend, err := querier.GetPayerUnsettledUsage(
ctx,
queries.GetPayerUnsettledUsageParams{PayerID: payerID},
)
require.NoError(t, err)
require.Equal(t, payerSpend, int64(100))
}

func TestPayerMustExist(t *testing.T) {
ctx := context.Background()
db, _, cleanup := testutils.NewDB(t, ctx)
defer cleanup()

payerID := testutils.RandomInt32()
originatorID := testutils.RandomInt32()
sequenceID := int64(10)

insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100)

_, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx,
db,
insertParams,
incrementParams,
)
require.Error(t, err)
}

func TestInsertAndIncrementParallel(t *testing.T) {
ctx := context.Background()
db, _, cleanup := testutils.NewDB(t, ctx)
defer cleanup()

querier := queries.New(db)
// Create a payer
payerID := testutils.CreatePayer(t, db, testutils.RandomAddress().Hex())
originatorID := testutils.RandomInt32()
sequenceID := int64(10)
numberOfInserts := 20

insertParams, incrementParams := buildParams(payerID, originatorID, sequenceID, 100)

var wg sync.WaitGroup

totalInserted := int64(0)

attemptInsert := func() {
defer wg.Done()
numInserted, err := InsertGatewayEnvelopeAndIncrementUnsettledUsage(
ctx,
db,
insertParams,
incrementParams,
)
require.NoError(t, err)
atomic.AddInt64(&totalInserted, numInserted)
}

for range numberOfInserts {
wg.Add(1)
go attemptInsert()
}

wg.Wait()

require.Equal(t, totalInserted, int64(1))

payerSpend, err := querier.GetPayerUnsettledUsage(
ctx,
queries.GetPayerUnsettledUsageParams{PayerID: payerID},
)
require.NoError(t, err)
require.Equal(t, payerSpend, int64(100))
}
19 changes: 19 additions & 0 deletions pkg/db/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,22 @@ ON CONFLICT (address)
RETURNING
id;

-- name: IncrementUnsettledUsage :exec
INSERT INTO unsettled_usage(payer_id, originator_id, minutes_since_epoch, spend_picodollars)
VALUES (@payer_id, @originator_id, @minutes_since_epoch, @spend_picodollars)
ON CONFLICT (payer_id, originator_id, minutes_since_epoch)
DO UPDATE SET
spend_picodollars = unsettled_usage.spend_picodollars + @spend_picodollars;

-- name: GetPayerUnsettledUsage :one
SELECT
COALESCE(SUM(spend_picodollars), 0)::BIGINT AS total_spend_picodollars
FROM
unsettled_usage
WHERE
payer_id = @payer_id
AND (@minutes_since_epoch_gt::BIGINT = 0
OR minutes_since_epoch > @minutes_since_epoch_gt::BIGINT)
AND (@minutes_since_epoch_lt::BIGINT = 0
OR minutes_since_epoch < @minutes_since_epoch_lt::BIGINT);

7 changes: 7 additions & 0 deletions pkg/db/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions pkg/db/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions pkg/db/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,32 @@ func RunInTx(
done = true
return tx.Commit()
}

func RunInTxWithResult[T any](
ctx context.Context,
db *sql.DB,
opts *sql.TxOptions,
fn func(ctx context.Context, txQueries *queries.Queries) (T, error),
) (result T, err error) {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return result, err
}

defer func() {
if err != nil {
_ = tx.Rollback()
}
}()

result, err = fn(ctx, queries.New(db).WithTx(tx))
if err != nil {
return result, err
}

if err = tx.Commit(); err != nil {
return result, err
}

return result, nil
}
Loading
Loading