Skip to content

Commit

Permalink
Link payer to gateway envelopes
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Feb 21, 2025
1 parent 1758eb6 commit 9eb284b
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 11 deletions.
22 changes: 20 additions & 2 deletions pkg/api/message/publishWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/registrant"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
)

type publishWorker struct {
Expand Down Expand Up @@ -110,14 +110,31 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator
)
return false
}
originatorBytes, err := proto.Marshal(originatorEnv)
validatedEnvelope, err := envelopes.NewOriginatorEnvelope(originatorEnv)
if err != nil {
logger.Error("Failed to validate originator envelope", zap.Error(err))
return false
}
originatorBytes, err := validatedEnvelope.Bytes()
if err != nil {
logger.Error("Failed to marshal originator envelope", zap.Error(err))
return false
}

q := queries.New(p.store)

payerAddress, err := validatedEnvelope.UnsignedOriginatorEnvelope.PayerEnvelope.RecoverSigner()
if err != nil {
logger.Error("Failed to recover payer address", zap.Error(err))
return false
}

payerId, err := q.FindOrCreatePayer(p.ctx, payerAddress.Hex())
if err != nil {
logger.Error("Failed to find or create payer", zap.Error(err))
return false
}

// On unique constraint conflicts, no error is thrown, but numRows is 0
inserted, err := q.InsertGatewayEnvelope(
p.ctx,
Expand All @@ -126,6 +143,7 @@ func (p *publishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginator
OriginatorSequenceID: stagedEnv.ID,
Topic: stagedEnv.Topic,
OriginatorEnvelope: originatorBytes,
PayerID: db.NullInt32(payerId),
},
)
if p.ctx.Err() != nil {
Expand Down
11 changes: 10 additions & 1 deletion pkg/api/message/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/api/message"
"github.com/xmtp/xmtpd/pkg/db"
dbUtils "github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
Expand All @@ -28,12 +29,16 @@ var allRows []queries.InsertGatewayEnvelopeParams
func setupTest(
t *testing.T,
) (message_api.ReplicationApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) {
api, db, mocks, cleanup := testUtilsApi.NewTestReplicationAPIClient(t)

payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db))
allRows = []queries.InsertGatewayEnvelopeParams{
// Initial rows
{
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA),
Expand All @@ -43,6 +48,7 @@ func setupTest(
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA),
Expand All @@ -53,6 +59,7 @@ func setupTest(
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: topicB,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB),
Expand All @@ -62,6 +69,7 @@ func setupTest(
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: topicB,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB),
Expand All @@ -71,14 +79,15 @@ func setupTest(
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA),
),
},
}

return testUtilsApi.NewTestReplicationAPIClient(t)
return api, db, mocks, cleanup
}

func insertInitialRows(t *testing.T, store *sql.DB) {
Expand Down
16 changes: 13 additions & 3 deletions pkg/api/metadata/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package metadata_test
import (
"context"
"database/sql"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api"
"testing"
"time"

"github.com/xmtp/xmtpd/pkg/api/message"
dbUtils "github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/testutils"
testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api"
Expand All @@ -26,12 +28,16 @@ var allRows []queries.InsertGatewayEnvelopeParams
func setupTest(
t *testing.T,
) (metadata_api.MetadataApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) {
api, db, mocks, cleanup := testUtilsApi.NewTestMetadataAPIClient(t)
payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db))

allRows = []queries.InsertGatewayEnvelopeParams{
// Initial rows
{
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA),
Expand All @@ -41,6 +47,7 @@ func setupTest(
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA),
Expand All @@ -51,6 +58,7 @@ func setupTest(
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: topicB,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB),
Expand All @@ -60,6 +68,7 @@ func setupTest(
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: topicB,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB),
Expand All @@ -69,14 +78,15 @@ func setupTest(
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA),
),
},
}

return testUtilsApi.NewTestMetadataAPIClient(t)
return api, db, mocks, cleanup
}

func insertInitialRows(t *testing.T, store *sql.DB) {
Expand Down
7 changes: 7 additions & 0 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/db"
dbUtils "github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
Expand All @@ -23,11 +24,13 @@ var (
)

func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams {
payerId := dbUtils.NullInt32(testutils.CreatePayer(t, db))
db_rows := []queries.InsertGatewayEnvelopeParams{
{
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA),
Expand All @@ -37,6 +40,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA),
Expand All @@ -46,6 +50,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: topicB,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB),
Expand All @@ -55,6 +60,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: topicB,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB),
Expand All @@ -64,6 +70,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: topicA,
PayerID: payerId,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA),
Expand Down
4 changes: 2 additions & 2 deletions pkg/db/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ WHERE
singleton_id = 1;

-- name: InsertGatewayEnvelope :execrows
INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope)
VALUES (@originator_node_id, @originator_sequence_id, @topic, @originator_envelope)
INSERT INTO gateway_envelopes(originator_node_id, originator_sequence_id, topic, originator_envelope, payer_id)
VALUES (@originator_node_id, @originator_sequence_id, @topic, @originator_envelope, @payer_id)
ON CONFLICT
DO NOTHING;

Expand Down
1 change: 1 addition & 0 deletions pkg/db/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions pkg/db/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pkg/migrations/00006_link-payer-to-gateway-envelopes.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE gateway_envelopes
DROP COLUMN payer_id;

4 changes: 4 additions & 0 deletions pkg/migrations/00006_link-payer-to-gateway-envelopes.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
ALTER TABLE gateway_envelopes
-- Leave column nullable since blockchain originated messages won't have a payer_id
ADD COLUMN payer_id INT REFERENCES payers(id);

22 changes: 22 additions & 0 deletions pkg/sync/syncWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,12 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) {
return
}

payerId, err := s.getPayerID(env)
if err != nil {
s.log.Error("Failed to get payer ID", zap.Error(err))
return
}

q := queries.New(s.store)
inserted, err := q.InsertGatewayEnvelope(
s.ctx,
Expand All @@ -443,6 +449,7 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) {
OriginatorSequenceID: int64(env.OriginatorSequenceID()),
Topic: env.TargetTopic().Bytes(),
OriginatorEnvelope: originatorBytes,
PayerID: db.NullInt32(payerId),
},
)
if err != nil {
Expand All @@ -454,3 +461,18 @@ func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) {
return
}
}

func (s *syncWorker) getPayerID(env *envUtils.OriginatorEnvelope) (int32, error) {
payerAddress, err := env.UnsignedOriginatorEnvelope.PayerEnvelope.RecoverSigner()
if err != nil {
return 0, err
}

q := queries.New(s.store)
payerId, err := q.FindOrCreatePayer(s.ctx, payerAddress.Hex())
if err != nil {
return 0, err
}

return payerId, nil
}
15 changes: 15 additions & 0 deletions pkg/testutils/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,18 @@ func InsertGatewayEnvelopes(
}
}
}

func CreatePayer(t *testing.T, db *sql.DB, address ...string) int32 {
q := queries.New(db)
var payerAddress string
if len(address) > 0 {
payerAddress = address[0]
} else {
payerAddress = RandomString(42)
}

id, err := q.FindOrCreatePayer(context.Background(), payerAddress)
require.NoError(t, err)

return id
}

0 comments on commit 9eb284b

Please sign in to comment.