From ae46babe9190905181252fe9cc73ef978c653bb1 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Wed, 26 Feb 2025 12:37:24 -0500 Subject: [PATCH] Link payer to gateway envelopes (#536) ### TL;DR Added payer tracking to gateway envelopes by linking them to payer records in the database. It's very hard to fill this link in after-the-fact, since it requires recovering the signer address from the `PayerEnvelope`, so I think it makes sense to store it on every row even if we don't have an immediate use for querying by the field. ### Issues - https://github.com/xmtp/xmtpd/issues/527 - https://github.com/xmtp/xmtpd/issues/529 ### What changed? - Added a `payer_id` column to the `gateway_envelopes` table that references the `payers` table - Modified the `InsertGatewayEnvelope` query to include the `payer_id` field - Updated the publish worker and sync worker to extract payer information from envelopes and store it - Added validation of originator envelopes before processing - Updated tests to include payer IDs in test data ### How to test? - Run existing test suite which has been updated to include payer tracking - Verify that gateway envelopes are properly linked to payer records - Confirm that invalid envelopes are rejected during validation - Check that payer addresses are correctly recovered from signatures ### Why make this change? This change enables tracking of message payers in the system, which is essential for: - Monitoring message usage per payer - Supporting future billing and rate limiting features - Improving system accountability and auditability ## Summary by CodeRabbit - **New Features** - Enhanced envelope processing now includes robust validation and integration of payer information, improving reliability and message tracking. - **Chores** - Updated the database schema and migrations to support a new payer identifier. - Refined internal setups to align with the enhanced payer-data integration. --- pkg/api/message/publishWorker.go | 22 +++++++++++++++++-- pkg/api/message/subscribe_test.go | 11 +++++++++- pkg/api/metadata/cursor_test.go | 16 +++++++++++--- pkg/api/query_test.go | 7 ++++++ pkg/db/queries.sql | 4 ++-- pkg/db/queries/models.go | 1 + pkg/db/queries/queries.sql.go | 9 +++++--- ...6_link-payer-to-gateway-envelopes.down.sql | 3 +++ ...006_link-payer-to-gateway-envelopes.up.sql | 4 ++++ pkg/sync/syncWorker.go | 22 +++++++++++++++++++ pkg/testutils/envelopes/envelopes.go | 5 +++-- pkg/testutils/store.go | 15 +++++++++++++ 12 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql create mode 100644 pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql diff --git a/pkg/api/message/publishWorker.go b/pkg/api/message/publishWorker.go index 4253eeb6..86bdc59d 100644 --- a/pkg/api/message/publishWorker.go +++ b/pkg/api/message/publishWorker.go @@ -8,9 +8,9 @@ import ( "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/envelopes" "github.com/xmtp/xmtpd/pkg/registrant" "go.uber.org/zap" - "google.golang.org/protobuf/proto" ) type publishWorker struct { @@ -110,7 +110,12 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator ) return false } - originatorBytes, err := proto.Marshal(originatorEnv) + validatedEnvelope, err := envelopes.NewOriginatorEnvelope(originatorEnv) + if err != nil { + logger.Error("Failed to validate originator envelope", zap.Error(err)) + return false + } + originatorBytes, err := validatedEnvelope.Bytes() if err != nil { logger.Error("Failed to marshal originator envelope", zap.Error(err)) return false @@ -118,6 +123,18 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator q := queries.New(p.store) + payerAddress, err := validatedEnvelope.UnsignedOriginatorEnvelope.PayerEnvelope.RecoverSigner() + if err != nil { + logger.Error("Failed to recover payer address", zap.Error(err)) + return false + } + + payerId, err := q.FindOrCreatePayer(p.ctx, payerAddress.Hex()) + if err != nil { + logger.Error("Failed to find or create payer", zap.Error(err)) + return false + } + // On unique constraint conflicts, no error is thrown, but numRows is 0 inserted, err := q.InsertGatewayEnvelope( p.ctx, @@ -126,6 +143,7 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator OriginatorSequenceID: stagedEnv.ID, Topic: stagedEnv.Topic, OriginatorEnvelope: originatorBytes, + PayerID: db.NullInt32(payerId), }, ) if p.ctx.Err() != nil { diff --git a/pkg/api/message/subscribe_test.go b/pkg/api/message/subscribe_test.go index 09a5c8dd..932caaa5 100644 --- a/pkg/api/message/subscribe_test.go +++ b/pkg/api/message/subscribe_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/api/message" "github.com/xmtp/xmtpd/pkg/db" + dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" @@ -28,12 +29,16 @@ var allRows []queries.InsertGatewayEnvelopeParams func setupTest( t *testing.T, ) (message_api.ReplicationApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) { + api, db, mocks, cleanup := testUtilsApi.NewTestReplicationAPIClient(t) + + payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db)) allRows = []queries.InsertGatewayEnvelopeParams{ // Initial rows { OriginatorNodeID: 1, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), @@ -43,6 +48,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), @@ -53,6 +59,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), @@ -62,6 +69,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), @@ -71,6 +79,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 3, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), @@ -78,7 +87,7 @@ func setupTest( }, } - return testUtilsApi.NewTestReplicationAPIClient(t) + return api, db, mocks, cleanup } func insertInitialRows(t *testing.T, store *sql.DB) { diff --git a/pkg/api/metadata/cursor_test.go b/pkg/api/metadata/cursor_test.go index 9d0cdd5b..252a8870 100644 --- a/pkg/api/metadata/cursor_test.go +++ b/pkg/api/metadata/cursor_test.go @@ -3,13 +3,15 @@ package metadata_test import ( "context" "database/sql" + "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api" - "testing" - "time" "github.com/xmtp/xmtpd/pkg/api/message" + dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/testutils" testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api" @@ -26,12 +28,16 @@ var allRows []queries.InsertGatewayEnvelopeParams func setupTest( t *testing.T, ) (metadata_api.MetadataApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) { + api, db, mocks, cleanup := testUtilsApi.NewTestMetadataAPIClient(t) + payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db)) + allRows = []queries.InsertGatewayEnvelopeParams{ // Initial rows { OriginatorNodeID: 1, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), @@ -41,6 +47,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), @@ -51,6 +58,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), @@ -60,6 +68,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), @@ -69,6 +78,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 3, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), @@ -76,7 +86,7 @@ func setupTest( }, } - return testUtilsApi.NewTestMetadataAPIClient(t) + return api, db, mocks, cleanup } func insertInitialRows(t *testing.T, store *sql.DB) { diff --git a/pkg/api/query_test.go b/pkg/api/query_test.go index 9468b4d3..18547192 100644 --- a/pkg/api/query_test.go +++ b/pkg/api/query_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/db" + dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" @@ -23,11 +24,13 @@ var ( ) func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams { + payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db)) db_rows := []queries.InsertGatewayEnvelopeParams{ { OriginatorNodeID: 1, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), @@ -37,6 +40,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 2, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), @@ -46,6 +50,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 1, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), @@ -55,6 +60,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 2, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), @@ -64,6 +70,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 1, OriginatorSequenceID: 3, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index dc06efe8..a0e0182e 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -13,8 +13,8 @@ WHERE singleton_id = 1; -- name: InsertGatewayEnvelope :execrows -INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope) - VALUES (@originator_node_id, @originator_sequence_id, @topic, @originator_envelope) +INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id) + VALUES (@originator_node_id, @originator_sequence_id, @topic, @originator_envelope, @payer_id) ON CONFLICT DO NOTHING; diff --git a/pkg/db/queries/models.go b/pkg/db/queries/models.go index 31a7cc22..bfb2c9fc 100644 --- a/pkg/db/queries/models.go +++ b/pkg/db/queries/models.go @@ -30,6 +30,7 @@ type GatewayEnvelope struct { OriginatorSequenceID int64 Topic []byte OriginatorEnvelope []byte + PayerID sql.NullInt32 } type LatestBlock struct { diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 428df05c..f717f173 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -272,8 +272,8 @@ func (q *Queries) InsertBlockchainMessage(ctx context.Context, arg InsertBlockch } const insertGatewayEnvelope = `-- name: InsertGatewayEnvelope :execrows -INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope) - VALUES ($1, $2, $3, $4) +INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING ` @@ -283,6 +283,7 @@ type InsertGatewayEnvelopeParams struct { OriginatorSequenceID int64 Topic []byte OriginatorEnvelope []byte + PayerID sql.NullInt32 } func (q *Queries) InsertGatewayEnvelope(ctx context.Context, arg InsertGatewayEnvelopeParams) (int64, error) { @@ -291,6 +292,7 @@ func (q *Queries) InsertGatewayEnvelope(ctx context.Context, arg InsertGatewayEn arg.OriginatorSequenceID, arg.Topic, arg.OriginatorEnvelope, + arg.PayerID, ) if err != nil { return 0, err @@ -368,7 +370,7 @@ func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFro const selectGatewayEnvelopes = `-- name: SelectGatewayEnvelopes :many SELECT - gateway_time, originator_node_id, originator_sequence_id, topic, originator_envelope + gateway_time, originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id FROM select_gateway_envelopes($1::INT[], $2::BIGINT[], $3::BYTEA[], $4::INT[], $5::INT) ` @@ -402,6 +404,7 @@ func (q *Queries) SelectGatewayEnvelopes(ctx context.Context, arg SelectGatewayE &i.OriginatorSequenceID, &i.Topic, &i.OriginatorEnvelope, + &i.PayerID, ); err != nil { return nil, err } diff --git a/pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql b/pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql new file mode 100644 index 00000000..25421e1a --- /dev/null +++ b/pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE gateway_envelopes + DROP COLUMN payer_id; + diff --git a/pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql b/pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql new file mode 100644 index 00000000..8a609821 --- /dev/null +++ b/pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql @@ -0,0 +1,4 @@ +ALTER TABLE gateway_envelopes +-- Leave column nullable since blockchain originated messages won't have a payer_id + ADD COLUMN payer_id INT REFERENCES payers(id); + diff --git a/pkg/sync/syncWorker.go b/pkg/sync/syncWorker.go index 6cbf0373..628a704e 100644 --- a/pkg/sync/syncWorker.go +++ b/pkg/sync/syncWorker.go @@ -435,6 +435,12 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { return } + payerId, err := s.getPayerID(env) + if err != nil { + s.log.Error("Failed to get payer ID", zap.Error(err)) + return + } + q := queries.New(s.store) inserted, err := q.InsertGatewayEnvelope( s.ctx, @@ -443,6 +449,7 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { OriginatorSequenceID: int64(env.OriginatorSequenceID()), Topic: env.TargetTopic().Bytes(), OriginatorEnvelope: originatorBytes, + PayerID: db.NullInt32(payerId), }, ) if err != nil { @@ -454,3 +461,18 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { return } } + +func (s *syncWorker) getPayerID(env *envUtils.OriginatorEnvelope) (int32, error) { + payerAddress, err := env.UnsignedOriginatorEnvelope.PayerEnvelope.RecoverSigner() + if err != nil { + return 0, err + } + + q := queries.New(s.store) + payerId, err := q.FindOrCreatePayer(s.ctx, payerAddress.Hex()) + if err != nil { + return 0, err + } + + return payerId, nil +} diff --git a/pkg/testutils/envelopes/envelopes.go b/pkg/testutils/envelopes/envelopes.go index 71eafa2a..4cd203e0 100644 --- a/pkg/testutils/envelopes/envelopes.go +++ b/pkg/testutils/envelopes/envelopes.go @@ -1,15 +1,15 @@ package testutils import ( - "github.com/ethereum/go-ethereum/crypto" - "github.com/xmtp/xmtpd/pkg/utils" "testing" + "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/proto/identity/associations" mlsv1 "github.com/xmtp/xmtpd/pkg/proto/mls/api/v1" envelopes "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/topic" + "github.com/xmtp/xmtpd/pkg/utils" "google.golang.org/protobuf/proto" ) @@ -89,6 +89,7 @@ func CreatePayerEnvelope( if len(clientEnv) == 0 { clientEnv = append(clientEnv, CreateClientEnvelope()) } + clientEnvBytes, err := proto.Marshal(clientEnv[0]) require.NoError(t, err) diff --git a/pkg/testutils/store.go b/pkg/testutils/store.go index 42b9c87f..2691c9d7 100644 --- a/pkg/testutils/store.go +++ b/pkg/testutils/store.go @@ -110,3 +110,18 @@ func InsertGatewayEnvelopes( } } } + +func CreatePayer(t *testing.T, db *sql.DB, address ...string) int32 { + q := queries.New(db) + var payerAddress string + if len(address) > 0 { + payerAddress = address[0] + } else { + payerAddress = RandomString(42) + } + + id, err := q.FindOrCreatePayer(context.Background(), payerAddress) + require.NoError(t, err) + + return id +}