Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Guarantee your own commits via a blocking payer #441

Merged
merged 10 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions pkg/api/metadata/cursor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package metadata_test

import (
"context"
"database/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api"
"testing"
"time"

"github.com/xmtp/xmtpd/pkg/api/message"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/testutils"
testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
"github.com/xmtp/xmtpd/pkg/topic"
)

var (
topicA = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicA")).Bytes()
topicB = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicB")).Bytes()
topicC = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicC")).Bytes()

Check failure on line 23 in pkg/api/metadata/cursor_test.go

View workflow job for this annotation

GitHub Actions / Lint-Go

var `topicC` is unused (unused)
)
var allRows []queries.InsertGatewayEnvelopeParams

func setupTest(
t *testing.T,
) (metadata_api.MetadataApiClient, *sql.DB, testUtilsApi.ApiServerMocks, func()) {
allRows = []queries.InsertGatewayEnvelopeParams{
// Initial rows
{
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA),
),
},
{
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA),
),
},
// Later rows
{
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: topicB,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB),
),
},
{
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: topicB,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB),
),
},
{
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA),
),
},
}

return testUtilsApi.NewTestMetadataAPIClient(t)
}

func insertInitialRows(t *testing.T, store *sql.DB) {
testutils.InsertGatewayEnvelopes(t, store, []queries.InsertGatewayEnvelopeParams{
allRows[0], allRows[1],
})
time.Sleep(message.SubscribeWorkerPollTime + 100*time.Millisecond)
}

func insertAdditionalRows(t *testing.T, store *sql.DB, notifyChan ...chan bool) {
testutils.InsertGatewayEnvelopes(t, store, []queries.InsertGatewayEnvelopeParams{
allRows[2], allRows[3], allRows[4],
}, notifyChan...)
}

func TestGetCursorBasic(t *testing.T) {
client, db, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, db)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

cursor, err := client.GetSyncCursor(ctx, &metadata_api.GetSyncCursorRequest{})

require.NoError(t, err)
require.NotNil(t, cursor)

expectedCursor := map[uint32]uint64{
1: 1,
2: 1,
}

require.Equal(t, expectedCursor, cursor.LatestSync.NodeIdToSequenceId)

insertAdditionalRows(t, db)
require.Eventually(t, func() bool {
expectedCursor := map[uint32]uint64{
1: 3,
2: 2,
}

cursor, err := client.GetSyncCursor(ctx, &metadata_api.GetSyncCursorRequest{})
if err != nil {
t.Logf("Error fetching sync cursor: %v", err)
return false
}
if cursor == nil {
t.Log("Cursor is nil")
return false
}

return assert.ObjectsAreEqual(expectedCursor, cursor.LatestSync.NodeIdToSequenceId)
}, 500*time.Millisecond, 50*time.Millisecond)
}

func TestSubscribeSyncCursorBasic(t *testing.T) {
client, db, _, cleanup := setupTest(t)
defer cleanup()
insertInitialRows(t, db)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

stream, err := client.SubscribeSyncCursor(ctx, &metadata_api.GetSyncCursorRequest{})
require.NoError(t, err)
require.NotNil(t, stream)

firstUpdate, err := stream.Recv()
require.NoError(t, err)
require.NotNil(t, firstUpdate)

expectedCursor := map[uint32]uint64{
1: 1,
2: 1,
}

require.Equal(t, expectedCursor, firstUpdate.LatestSync.NodeIdToSequenceId)

insertAdditionalRows(t, db)

require.Eventually(t, func() bool {
expectedCursor := map[uint32]uint64{
1: 3,
2: 2,
}

update, err := stream.Recv()
if err != nil {
t.Logf("Error receiving sync cursor update: %v", err)
return false
}
if update == nil {
t.Log("Received nil update from stream")
return false
}

return assert.ObjectsAreEqual(expectedCursor, update.LatestSync.NodeIdToSequenceId)
}, 500*time.Millisecond, 50*time.Millisecond)
}
71 changes: 71 additions & 0 deletions pkg/api/payer/nodeCursorTracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package payer

import (
"context"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type NodeCursorTracker struct {
ctx context.Context
log *zap.Logger
clientManager *ClientManager
}

func NewNodeCursorTracker(ctx context.Context,
log *zap.Logger, clientManager *ClientManager) *NodeCursorTracker {
return &NodeCursorTracker{ctx: ctx, log: log, clientManager: clientManager}
}

func (ct *NodeCursorTracker) BlockUntilDesiredCursorReached(
ctx context.Context,
nodeId uint32,
desiredOriginatorId uint32,
desiredSequenceId uint64,
) error {
// TODO(mkysel) ideally we wouldn't create and tear down the stream for every request

conn, err := ct.clientManager.GetClient(nodeId)
if err != nil {
return err
}
client := metadata_api.NewMetadataApiClient(conn)
stream, err := client.SubscribeSyncCursor(ctx, &metadata_api.GetSyncCursorRequest{})
if err != nil {
return err
}
for {
select {
case <-ct.ctx.Done():
// server is shutting down
return status.Errorf(codes.Canceled, "node terminated. Cancelled wait for cursor")
case <-ctx.Done():
// client has shut down
return nil
default:
resp, err := stream.Recv()
if err != nil {
if status.Code(err) == codes.Canceled {
return nil
}
// TODO(mkysel): proper handling of failures
return err
}
if err != nil || resp == nil || resp.LatestSync == nil {
return status.Errorf(codes.Internal, "error getting node cursor: %v", err)
}
derefMap := resp.LatestSync.NodeIdToSequenceId
seqId, exists := derefMap[desiredOriginatorId]
if !exists {
continue // Wait for the originator ID to appear
}

// Check if the sequence ID has reached the desired value
if seqId >= desiredSequenceId {
return nil // Desired state achieved
}
}
}
}
20 changes: 19 additions & 1 deletion pkg/api/payer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Service struct {
clientManager *ClientManager
blockchainPublisher blockchain.IBlockchainPublisher
payerPrivateKey *ecdsa.PrivateKey
nodeCursorTracker *NodeCursorTracker
}

func NewPayerApiService(
Expand All @@ -41,12 +42,14 @@ func NewPayerApiService(
payerPrivateKey *ecdsa.PrivateKey,
blockchainPublisher blockchain.IBlockchainPublisher,
) (*Service, error) {
clientManager := NewClientManager(log, registry)
return &Service{
ctx: ctx,
log: log,
clientManager: NewClientManager(log, registry),
clientManager: clientManager,
payerPrivateKey: payerPrivateKey,
blockchainPublisher: blockchainPublisher,
nodeCursorTracker: NewNodeCursorTracker(ctx, log, clientManager),
}, nil
}

Expand Down Expand Up @@ -174,6 +177,9 @@ func (s *Service) publishToBlockchain(
) (*envelopesProto.OriginatorEnvelope, error) {
targetTopic := clientEnvelope.TargetTopic()
identifier := targetTopic.Identifier()
expectedNode := clientEnvelope.Aad().TargetOriginator
desiredOriginatorId := uint32(1) //TODO: determine this from the chain
desiredSequenceId := uint64(0)
kind := targetTopic.Kind()

// Get the group ID as [32]byte
Expand Down Expand Up @@ -214,6 +220,7 @@ func (s *Service) publishToBlockchain(
logMessage.SequenceId,
logMessage.Message,
)
desiredSequenceId = logMessage.SequenceId

case topic.TOPIC_KIND_IDENTITY_UPDATES_V1:
var logMessage *identityupdates.IdentityUpdatesIdentityUpdateCreated
Expand All @@ -230,6 +237,8 @@ func (s *Service) publishToBlockchain(
logMessage.SequenceId,
logMessage.Update,
)
desiredSequenceId = logMessage.SequenceId

default:
return nil, status.Errorf(
codes.InvalidArgument,
Expand All @@ -246,6 +255,15 @@ func (s *Service) publishToBlockchain(
err,
)
}
err = s.nodeCursorTracker.BlockUntilDesiredCursorReached(
ctx,
expectedNode,
desiredOriginatorId,
desiredSequenceId,
)
if err != nil {
return nil, err
}

return &envelopesProto.OriginatorEnvelope{
UnsignedOriginatorEnvelope: unsignedBytes,
Expand Down
36 changes: 36 additions & 0 deletions pkg/testutils/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ import (
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/api"
"github.com/xmtp/xmtpd/pkg/api/message"
"github.com/xmtp/xmtpd/pkg/api/metadata"
"github.com/xmtp/xmtpd/pkg/api/payer"
"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/mocks/blockchain"
mlsvalidateMocks "github.com/xmtp/xmtpd/pkg/mocks/mlsvalidate"
mocks "github.com/xmtp/xmtpd/pkg/mocks/registry"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/metadata_api"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api"
"github.com/xmtp/xmtpd/pkg/registrant"
"github.com/xmtp/xmtpd/pkg/registry"
Expand Down Expand Up @@ -65,6 +67,25 @@ func NewPayerAPIClient(
}
}

func NewMetadataAPIClient(
t *testing.T,
ctx context.Context,
addr string,
) (metadata_api.MetadataApiClient, func()) {
dialAddr := fmt.Sprintf("passthrough://localhost/%s", addr)
conn, err := grpc.NewClient(
dialAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(),
)
require.NoError(t, err)
client := metadata_api.NewMetadataApiClient(conn)
return client, func() {
err := conn.Close()
require.NoError(t, err)
}
}

type ApiServerMocks struct {
MockRegistry *mocks.MockNodeRegistry
MockValidationService *mlsvalidateMocks.MockMLSValidationService
Expand Down Expand Up @@ -117,6 +138,10 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, ApiServerMocks, fu
require.NoError(t, err)
payer_api.RegisterPayerApiServer(grpcServer, payerService)

metadataService, err := metadata.NewMetadataApiService(ctx, log, db)
require.NoError(t, err)
metadata_api.RegisterMetadataApiServer(grpcServer, metadataService)

return nil
}

Expand Down Expand Up @@ -153,3 +178,14 @@ func NewTestReplicationAPIClient(
svcCleanup()
}
}

func NewTestMetadataAPIClient(
t *testing.T,
) (metadata_api.MetadataApiClient, *sql.DB, ApiServerMocks, func()) {
svc, db, allMocks, svcCleanup := NewTestAPIServer(t)
client, clientCleanup := NewMetadataAPIClient(t, context.Background(), svc.Addr().String())
return client, db, allMocks, func() {
clientCleanup()
svcCleanup()
}
}
Loading