diff --git a/pkg/api/message/publishWorker.go b/pkg/api/message/publishWorker.go index 4253eeb6..86bdc59d 100644 --- a/pkg/api/message/publishWorker.go +++ b/pkg/api/message/publishWorker.go @@ -8,9 +8,9 @@ import ( "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/envelopes" "github.com/xmtp/xmtpd/pkg/registrant" "go.uber.org/zap" - "google.golang.org/protobuf/proto" ) type publishWorker struct { @@ -110,7 +110,12 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator ) return false } - originatorBytes, err := proto.Marshal(originatorEnv) + validatedEnvelope, err := envelopes.NewOriginatorEnvelope(originatorEnv) + if err != nil { + logger.Error("Failed to validate originator envelope", zap.Error(err)) + return false + } + originatorBytes, err := validatedEnvelope.Bytes() if err != nil { logger.Error("Failed to marshal originator envelope", zap.Error(err)) return false @@ -118,6 +123,18 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator q := queries.New(p.store) + payerAddress, err := validatedEnvelope.UnsignedOriginatorEnvelope.PayerEnvelope.RecoverSigner() + if err != nil { + logger.Error("Failed to recover payer address", zap.Error(err)) + return false + } + + payerId, err := q.FindOrCreatePayer(p.ctx, payerAddress.Hex()) + if err != nil { + logger.Error("Failed to find or create payer", zap.Error(err)) + return false + } + // On unique constraint conflicts, no error is thrown, but numRows is 0 inserted, err := q.InsertGatewayEnvelope( p.ctx, @@ -126,6 +143,7 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator OriginatorSequenceID: stagedEnv.ID, Topic: stagedEnv.Topic, OriginatorEnvelope: originatorBytes, + PayerID: db.NullInt32(payerId), }, ) if p.ctx.Err() != nil { diff --git a/pkg/api/message/subscribe_test.go b/pkg/api/message/subscribe_test.go index 09a5c8dd..932caaa5 100644 --- a/pkg/api/message/subscribe_test.go +++ b/pkg/api/message/subscribe_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/api/message" "github.com/xmtp/xmtpd/pkg/db" + dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" @@ -28,12 +29,16 @@ var allRows []queries.InsertGatewayEnvelopeParams func setupTest( t *testing.T, ) (message_api.ReplicationApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) { + api, db, mocks, cleanup := testUtilsApi.NewTestReplicationAPIClient(t) + + payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db)) allRows = []queries.InsertGatewayEnvelopeParams{ // Initial rows { OriginatorNodeID: 1, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), @@ -43,6 +48,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), @@ -53,6 +59,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), @@ -62,6 +69,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), @@ -71,6 +79,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 3, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), @@ -78,7 +87,7 @@ func setupTest( }, } - return testUtilsApi.NewTestReplicationAPIClient(t) + return api, db, mocks, cleanup } func insertInitialRows(t *testing.T, store *sql.DB) { diff --git a/pkg/api/metadata/cursor_test.go b/pkg/api/metadata/cursor_test.go index 9d0cdd5b..252a8870 100644 --- a/pkg/api/metadata/cursor_test.go +++ b/pkg/api/metadata/cursor_test.go @@ -3,13 +3,15 @@ package metadata_test import ( "context" "database/sql" + "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api" - "testing" - "time" "github.com/xmtp/xmtpd/pkg/api/message" + dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/testutils" testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api" @@ -26,12 +28,16 @@ var allRows []queries.InsertGatewayEnvelopeParams func setupTest( t *testing.T, ) (metadata_api.MetadataApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) { + api, db, mocks, cleanup := testUtilsApi.NewTestMetadataAPIClient(t) + payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db)) + allRows = []queries.InsertGatewayEnvelopeParams{ // Initial rows { OriginatorNodeID: 1, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), @@ -41,6 +47,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), @@ -51,6 +58,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), @@ -60,6 +68,7 @@ func setupTest( OriginatorNodeID: 2, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), @@ -69,6 +78,7 @@ func setupTest( OriginatorNodeID: 1, OriginatorSequenceID: 3, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), @@ -76,7 +86,7 @@ func setupTest( }, } - return testUtilsApi.NewTestMetadataAPIClient(t) + return api, db, mocks, cleanup } func insertInitialRows(t *testing.T, store *sql.DB) { diff --git a/pkg/api/query_test.go b/pkg/api/query_test.go index 9468b4d3..18547192 100644 --- a/pkg/api/query_test.go +++ b/pkg/api/query_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/xmtpd/pkg/db" + dbUtils "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" @@ -23,11 +24,13 @@ var ( ) func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams { + payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db)) db_rows := []queries.InsertGatewayEnvelopeParams{ { OriginatorNodeID: 1, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), @@ -37,6 +40,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 2, OriginatorSequenceID: 1, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), @@ -46,6 +50,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 1, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), @@ -55,6 +60,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 2, OriginatorSequenceID: 2, Topic: topicB, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), @@ -64,6 +70,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar OriginatorNodeID: 1, OriginatorSequenceID: 3, Topic: topicA, + PayerID: payerId, OriginatorEnvelope: testutils.Marshal( t, envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index dc06efe8..a0e0182e 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -13,8 +13,8 @@ WHERE singleton_id = 1; -- name: InsertGatewayEnvelope :execrows -INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope) - VALUES (@originator_node_id, @originator_sequence_id, @topic, @originator_envelope) +INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id) + VALUES (@originator_node_id, @originator_sequence_id, @topic, @originator_envelope, @payer_id) ON CONFLICT DO NOTHING; diff --git a/pkg/db/queries/models.go b/pkg/db/queries/models.go index 31a7cc22..bfb2c9fc 100644 --- a/pkg/db/queries/models.go +++ b/pkg/db/queries/models.go @@ -30,6 +30,7 @@ type GatewayEnvelope struct { OriginatorSequenceID int64 Topic []byte OriginatorEnvelope []byte + PayerID sql.NullInt32 } type LatestBlock struct { diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 428df05c..f717f173 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -272,8 +272,8 @@ func (q *Queries) InsertBlockchainMessage(ctx context.Context, arg InsertBlockch } const insertGatewayEnvelope = `-- name: InsertGatewayEnvelope :execrows -INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope) - VALUES ($1, $2, $3, $4) +INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING ` @@ -283,6 +283,7 @@ type InsertGatewayEnvelopeParams struct { OriginatorSequenceID int64 Topic []byte OriginatorEnvelope []byte + PayerID sql.NullInt32 } func (q *Queries) InsertGatewayEnvelope(ctx context.Context, arg InsertGatewayEnvelopeParams) (int64, error) { @@ -291,6 +292,7 @@ func (q *Queries) InsertGatewayEnvelope(ctx context.Context, arg InsertGatewayEn arg.OriginatorSequenceID, arg.Topic, arg.OriginatorEnvelope, + arg.PayerID, ) if err != nil { return 0, err @@ -368,7 +370,7 @@ func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFro const selectGatewayEnvelopes = `-- name: SelectGatewayEnvelopes :many SELECT - gateway_time, originator_node_id, originator_sequence_id, topic, originator_envelope + gateway_time, originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id FROM select_gateway_envelopes($1::INT[], $2::BIGINT[], $3::BYTEA[], $4::INT[], $5::INT) ` @@ -402,6 +404,7 @@ func (q *Queries) SelectGatewayEnvelopes(ctx context.Context, arg SelectGatewayE &i.OriginatorSequenceID, &i.Topic, &i.OriginatorEnvelope, + &i.PayerID, ); err != nil { return nil, err } diff --git a/pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql b/pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql new file mode 100644 index 00000000..25421e1a --- /dev/null +++ b/pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE gateway_envelopes + DROP COLUMN payer_id; + diff --git a/pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql b/pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql new file mode 100644 index 00000000..8a609821 --- /dev/null +++ b/pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql @@ -0,0 +1,4 @@ +ALTER TABLE gateway_envelopes +-- Leave column nullable since blockchain originated messages won't have a payer_id + ADD COLUMN payer_id INT REFERENCES payers(id); + diff --git a/pkg/sync/syncWorker.go b/pkg/sync/syncWorker.go index 6cbf0373..628a704e 100644 --- a/pkg/sync/syncWorker.go +++ b/pkg/sync/syncWorker.go @@ -435,6 +435,12 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { return } + payerId, err := s.getPayerID(env) + if err != nil { + s.log.Error("Failed to get payer ID", zap.Error(err)) + return + } + q := queries.New(s.store) inserted, err := q.InsertGatewayEnvelope( s.ctx, @@ -443,6 +449,7 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { OriginatorSequenceID: int64(env.OriginatorSequenceID()), Topic: env.TargetTopic().Bytes(), OriginatorEnvelope: originatorBytes, + PayerID: db.NullInt32(payerId), }, ) if err != nil { @@ -454,3 +461,18 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { return } } + +func (s *syncWorker) getPayerID(env *envUtils.OriginatorEnvelope) (int32, error) { + payerAddress, err := env.UnsignedOriginatorEnvelope.PayerEnvelope.RecoverSigner() + if err != nil { + return 0, err + } + + q := queries.New(s.store) + payerId, err := q.FindOrCreatePayer(s.ctx, payerAddress.Hex()) + if err != nil { + return 0, err + } + + return payerId, nil +} diff --git a/pkg/testutils/store.go b/pkg/testutils/store.go index 42b9c87f..2691c9d7 100644 --- a/pkg/testutils/store.go +++ b/pkg/testutils/store.go @@ -110,3 +110,18 @@ func InsertGatewayEnvelopes( } } } + +func CreatePayer(t *testing.T, db *sql.DB, address ...string) int32 { + q := queries.New(db) + var payerAddress string + if len(address) > 0 { + payerAddress = address[0] + } else { + payerAddress = RandomString(42) + } + + id, err := q.FindOrCreatePayer(context.Background(), payerAddress) + require.NoError(t, err) + + return id +}