diff --git a/pkg/authn/claims.go b/pkg/authn/claims.go index 02728586..e5ab3e9c 100644 --- a/pkg/authn/claims.go +++ b/pkg/authn/claims.go @@ -6,24 +6,36 @@ import ( "github.com/golang-jwt/jwt/v5" ) -const ( - // XMTPD_COMPATIBLE_VERSION_CONSTRAINT major or minor serverVersion bumps indicate backwards incompatible changes - XMTPD_COMPATIBLE_VERSION_CONSTRAINT = "^0.2" -) - type XmtpdClaims struct { Version *semver.Version `json:"version,omitempty"` jwt.RegisteredClaims } +type ClaimValidator struct { + constraint semver.Constraints +} -func ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error { - if claims.Version == nil { - return nil +func NewClaimValidator(serverVersion *semver.Version) (*ClaimValidator, error) { + if serverVersion == nil { + return nil, fmt.Errorf("serverVersion is nil") + } + sanitizedVersion, err := serverVersion.SetPrerelease("") + if err != nil { + return nil, err } - c, err := semver.NewConstraint(XMTPD_COMPATIBLE_VERSION_CONSTRAINT) + // https://github.com/Masterminds/semver?tab=readme-ov-file#caret-range-comparisons-major + constraintStr := fmt.Sprintf("^%s", sanitizedVersion.String()) + + constraint, err := semver.NewConstraint(constraintStr) if err != nil { - return err + return nil, err + } + + return &ClaimValidator{constraint: *constraint}, nil +} +func (cv *ClaimValidator) ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error { + if claims.Version == nil { + return nil } // SemVer implementations generally do not consider pre-releases to be valid next releases @@ -33,7 +45,7 @@ func ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error { if err != nil { return err } - if ok := c.Check(&sanitizedVersion); !ok { + if ok := cv.constraint.Check(&sanitizedVersion); !ok { return fmt.Errorf("serverVersion %s is not compatible", *claims.Version) } diff --git a/pkg/authn/claims_test.go b/pkg/authn/claims_test.go index 058365f1..a85ba906 100644 --- a/pkg/authn/claims_test.go +++ b/pkg/authn/claims_test.go @@ -1,39 +1,14 @@ -package authn +package authn_test import ( - "bytes" "github.com/Masterminds/semver/v3" "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/authn" "github.com/xmtp/xmtpd/pkg/registry" "github.com/xmtp/xmtpd/pkg/testutils" - "os/exec" - "strings" "testing" ) -func getLatestTag(t *testing.T) string { - // Prepare the command - cmd := exec.Command("git", "describe", "--tags", "--abbrev=0") - - // Capture the output - var out bytes.Buffer - cmd.Stdout = &out - cmd.Stderr = &out - - // Run the command - err := cmd.Run() - require.NoError(t, err, out.String()) - return strings.TrimSpace(out.String()) -} - -func getLatestVersion(t *testing.T) semver.Version { - tag := getLatestTag(t) - v, err := semver.NewVersion(tag) - require.NoError(t, err) - - return *v -} - func newVersionNoError(t *testing.T, version string, pre string, meta string) semver.Version { v, err := semver.NewVersion(version) require.NoError(t, err) @@ -47,7 +22,7 @@ func newVersionNoError(t *testing.T, version string, pre string, meta string) se return vmeta } -func TestClaimsVerifierNoVersion(t *testing.T) { +func TestClaimsNoVersion(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) tests := []struct { @@ -60,9 +35,17 @@ func TestClaimsVerifierNoVersion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), tt.version) - - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + tokenFactory := authn.NewTokenFactory( + signerPrivateKey, + uint32(SIGNER_NODE_ID), + tt.version, + ) + + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), @@ -80,10 +63,14 @@ func TestClaimsVerifierNoVersion(t *testing.T) { } } -func TestClaimsVerifier(t *testing.T) { +func TestClaimsVariousVersions(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - currentVersion := getLatestVersion(t) + currentVersion := *testutils.GetLatestVersion(t) + version013, err := semver.NewVersion("0.1.3") + require.NoError(t, err) + version014, err := semver.NewVersion("0.1.4") + require.NoError(t, err) tests := []struct { name string @@ -94,7 +81,6 @@ func TestClaimsVerifier(t *testing.T) { {"next-patch-version", currentVersion.IncPatch(), false}, {"next-minor-version", currentVersion.IncMinor(), true}, {"next-major-version", currentVersion.IncMajor(), true}, - {"last-supported-version", newVersionNoError(t, currentVersion.String(), "", ""), false}, { "with-prerelease-version", newVersionNoError(t, currentVersion.String(), "17-gdeadbeef", ""), @@ -105,13 +91,87 @@ func TestClaimsVerifier(t *testing.T) { newVersionNoError(t, currentVersion.String(), "", "branch-dev"), false, }, + {"known-0.1.3", *version013, true}, + {"known-0.1.4", *version014, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenFactory := authn.NewTokenFactory( + signerPrivateKey, + uint32(SIGNER_NODE_ID), + &tt.version, + ) + + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) + nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ + SigningKey: &signerPrivateKey.PublicKey, + NodeID: uint32(SIGNER_NODE_ID), + }, nil) + + token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID)) + require.NoError(t, err) + _, verificationError := verifier.Verify(token.SignedString) + if tt.wantErr { + require.Error(t, verificationError) + } else { + require.NoError(t, verificationError) + } + }) + } +} + +func TestClaimsValidator(t *testing.T) { + signerPrivateKey := testutils.RandomPrivateKey(t) + + currentVersion := *testutils.GetLatestVersion(t) + + tests := []struct { + name string + version semver.Version + serverVersion semver.Version + wantErr bool + }{ + {"current-version", currentVersion, currentVersion, false}, + { + "with-prerelease-version", + currentVersion, + newVersionNoError(t, currentVersion.String(), "17-gdeadbeef", ""), + false, + }, + { + "with-metadata-version", + currentVersion, + newVersionNoError(t, currentVersion.String(), "", "branch-dev"), + false, + }, + { + "future-major-rejects-us", + currentVersion, + currentVersion.IncMajor(), + true, + }, + { + "future-patch-accepts-us", + currentVersion, + currentVersion.IncPatch(), + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), &tt.version) + tokenFactory := authn.NewTokenFactory( + signerPrivateKey, + uint32(SIGNER_NODE_ID), + &tt.version, + ) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID), &tt.serverVersion) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), diff --git a/pkg/authn/signingMethod_test.go b/pkg/authn/signingMethod_test.go index ef566d88..f2f2f9cd 100644 --- a/pkg/authn/signingMethod_test.go +++ b/pkg/authn/signingMethod_test.go @@ -1,6 +1,7 @@ -package authn +package authn_test import ( + "github.com/xmtp/xmtpd/pkg/authn" "testing" "github.com/golang-jwt/jwt/v5" @@ -12,7 +13,7 @@ func TestSign(t *testing.T) { privateKey := testutils.RandomPrivateKey(t) publicKey := privateKey.Public() - method := &SigningMethodSecp256k1{} + method := &authn.SigningMethodSecp256k1{} signingString := "test" signature, err := method.Sign(signingString, privateKey) @@ -28,7 +29,7 @@ func TestWrongSigner(t *testing.T) { badPrivateKey := testutils.RandomPrivateKey(t) badPublicKey := badPrivateKey.Public() - method := &SigningMethodSecp256k1{} + method := &authn.SigningMethodSecp256k1{} signingString := "test" signature, err := method.Sign(signingString, goodPrivateKey) @@ -42,7 +43,7 @@ func TestWrongSigningString(t *testing.T) { privateKey := testutils.RandomPrivateKey(t) publicKey := privateKey.Public() - method := &SigningMethodSecp256k1{} + method := &authn.SigningMethodSecp256k1{} signingString := "test" signature, err := method.Sign(signingString, privateKey) @@ -57,7 +58,7 @@ func TestFullJWT(t *testing.T) { claims := &jwt.RegisteredClaims{ Issuer: "test", } - token := jwt.NewWithClaims(&SigningMethodSecp256k1{}, claims) + token := jwt.NewWithClaims(&authn.SigningMethodSecp256k1{}, claims) // Sign and get the complete encoded token as a string using the secret tokenString, err := token.SignedString(privateKey) diff --git a/pkg/authn/tokenFactory_test.go b/pkg/authn/tokenFactory_test.go index 3ebd991b..3e407fab 100644 --- a/pkg/authn/tokenFactory_test.go +++ b/pkg/authn/tokenFactory_test.go @@ -1,7 +1,8 @@ -package authn +package authn_test import ( "github.com/Masterminds/semver/v3" + "github.com/xmtp/xmtpd/pkg/authn" "testing" "time" @@ -11,7 +12,7 @@ import ( func TestTokenFactory(t *testing.T) { privateKey := testutils.RandomPrivateKey(t) - factory := NewTokenFactory(privateKey, 100, nil) + factory := authn.NewTokenFactory(privateKey, 100, nil) token, err := factory.CreateToken(200) require.NoError(t, err) @@ -38,7 +39,7 @@ func TestTokenFactoryWithVersion(t *testing.T) { t.Run(tt.name, func(t *testing.T) { version, err := semver.NewVersion(tt.version) require.NoError(t, err) - factory := NewTokenFactory(privateKey, 100, version) + factory := authn.NewTokenFactory(privateKey, 100, version) token, err := factory.CreateToken(200) require.NoError(t, err) diff --git a/pkg/authn/verifier.go b/pkg/authn/verifier.go index bc074882..405a01e0 100644 --- a/pkg/authn/verifier.go +++ b/pkg/authn/verifier.go @@ -2,6 +2,7 @@ package authn import ( "fmt" + "github.com/Masterminds/semver/v3" "strconv" "time" @@ -15,16 +16,26 @@ const ( ) type RegistryVerifier struct { - registry registry.NodeRegistry - myNodeID uint32 + registry registry.NodeRegistry + myNodeID uint32 + validator ClaimValidator } /* A RegistryVerifier connects to the NodeRegistry and verifies JWTs against the registered public keys based on the JWT's subject field */ -func NewRegistryVerifier(registry registry.NodeRegistry, myNodeID uint32) *RegistryVerifier { - return &RegistryVerifier{registry: registry, myNodeID: myNodeID} +func NewRegistryVerifier( + registry registry.NodeRegistry, + myNodeID uint32, + serverVersion *semver.Version, +) (*RegistryVerifier, error) { + validator, err := NewClaimValidator(serverVersion) + if err != nil { + return nil, err + } + + return &RegistryVerifier{registry: registry, myNodeID: myNodeID, validator: *validator}, nil } func (v *RegistryVerifier) Verify(tokenString string) (uint32, error) { @@ -104,7 +115,7 @@ func (v *RegistryVerifier) validateClaims(token *jwt.Token) error { return fmt.Errorf("invalid token") } - return ValidateVersionClaimIsCompatible(claims) + return v.validator.ValidateVersionClaimIsCompatible(claims) } // Parse the subject claim of the JWT and return the node ID as a uint32 diff --git a/pkg/authn/verifier_test.go b/pkg/authn/verifier_test.go index 3a0f4378..3c223cb3 100644 --- a/pkg/authn/verifier_test.go +++ b/pkg/authn/verifier_test.go @@ -1,8 +1,10 @@ -package authn +package authn_test import ( "crypto/ecdsa" "errors" + "github.com/Masterminds/semver/v3" + "github.com/xmtp/xmtpd/pkg/authn" "strconv" "testing" "time" @@ -21,10 +23,15 @@ const ( func buildVerifier( t *testing.T, - verifierNodeID uint32, -) (*RegistryVerifier, *registryMocks.MockNodeRegistry) { + verifierNodeID uint32, version *semver.Version, +) (*authn.RegistryVerifier, *registryMocks.MockNodeRegistry) { mockRegistry := registryMocks.NewMockNodeRegistry(t) - verifier := NewRegistryVerifier(mockRegistry, verifierNodeID) + verifier, err := authn.NewRegistryVerifier( + mockRegistry, + verifierNodeID, + version, + ) + require.NoError(t, err) return verifier, mockRegistry } @@ -37,7 +44,7 @@ func buildJwt( issuedAt time.Time, expiresAt time.Time, ) string { - token := jwt.NewWithClaims(&SigningMethodSecp256k1{}, &jwt.RegisteredClaims{ + token := jwt.NewWithClaims(&authn.SigningMethodSecp256k1{}, &jwt.RegisteredClaims{ Subject: strconv.Itoa(int(signerNodeID)), Audience: []string{strconv.Itoa(int(verifierNodeID))}, ExpiresAt: jwt.NewNumericDate(expiresAt), @@ -53,9 +60,13 @@ func buildJwt( func TestVerifier(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) + tokenFactory := authn.NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), @@ -79,9 +90,13 @@ func TestVerifier(t *testing.T) { func TestWrongAudience(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) + tokenFactory := authn.NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), @@ -97,9 +112,13 @@ func TestWrongAudience(t *testing.T) { func TestUnknownNode(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) + tokenFactory := authn.NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(nil, errors.New("node not found")) token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID)) @@ -112,9 +131,13 @@ func TestUnknownNode(t *testing.T) { func TestWrongPublicKey(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) + tokenFactory := authn.NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), nil) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) wrongPublicKey := testutils.RandomPrivateKey(t).PublicKey nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ @@ -132,7 +155,11 @@ func TestWrongPublicKey(t *testing.T) { func TestExpiredToken(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), @@ -154,7 +181,11 @@ func TestExpiredToken(t *testing.T) { func TestTokenDurationTooLong(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), @@ -176,7 +207,11 @@ func TestTokenDurationTooLong(t *testing.T) { func TestTokenClockSkew(t *testing.T) { signerPrivateKey := testutils.RandomPrivateKey(t) - verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID)) + verifier, nodeRegistry := buildVerifier( + t, + uint32(VERIFIER_NODE_ID), + testutils.GetLatestVersion(t), + ) nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(®istry.Node{ SigningKey: &signerPrivateKey.PublicKey, NodeID: uint32(SIGNER_NODE_ID), diff --git a/pkg/server/server.go b/pkg/server/server.go index 439a5cf0..fba43689 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -136,7 +136,9 @@ func NewReplicationServer( s, writerDB, blockchainPublisher, - listenAddress) + listenAddress, + serverVersion, + ) if err != nil { return nil, err } @@ -170,6 +172,7 @@ func startAPIServer( writerDB *sql.DB, blockchainPublisher blockchain.IBlockchainPublisher, listenAddress string, + serverVersion *semver.Version, ) error { var err error @@ -242,7 +245,14 @@ func startAPIServer( var jwtVerifier authn.JWTVerifier if s.nodeRegistry != nil && s.registrant != nil { - jwtVerifier = authn.NewRegistryVerifier(s.nodeRegistry, s.registrant.NodeID()) + jwtVerifier, err = authn.NewRegistryVerifier( + s.nodeRegistry, + s.registrant.NodeID(), + serverVersion, + ) + if err != nil { + return err + } } s.apiServer, err = api.NewAPIServer( diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 3fcee8d3..26a1b778 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -75,7 +75,7 @@ func NewTestServer( Enable: true, PrivateKey: hex.EncodeToString(crypto.FromECDSA(privateKey)), }, - }, registry, db, messagePublisher, fmt.Sprintf("localhost:%d", port), nil) + }, registry, db, messagePublisher, fmt.Sprintf("localhost:%d", port), testutils.GetLatestVersion(t)) require.NoError(t, err) return server diff --git a/pkg/testutils/api/api.go b/pkg/testutils/api/api.go index e829c267..d782816b 100644 --- a/pkg/testutils/api/api.go +++ b/pkg/testutils/api/api.go @@ -115,7 +115,12 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, ApiServerMocks, fu mockMessagePublisher := blockchain.NewMockIBlockchainPublisher(t) mockValidationService := mlsvalidateMocks.NewMockMLSValidationService(t) - jwtVerifier := authn.NewRegistryVerifier(mockRegistry, registrant.NodeID()) + jwtVerifier, err := authn.NewRegistryVerifier( + mockRegistry, + registrant.NodeID(), + testutils.GetLatestVersion(t), + ) + require.NoError(t, err) serviceRegistrationFunc := func(grpcServer *grpc.Server) error { replicationService, err := message.NewReplicationApiService( diff --git a/pkg/testutils/versioning.go b/pkg/testutils/versioning.go new file mode 100644 index 00000000..d4149d65 --- /dev/null +++ b/pkg/testutils/versioning.go @@ -0,0 +1,33 @@ +package testutils + +import ( + "bytes" + "github.com/Masterminds/semver/v3" + "github.com/stretchr/testify/require" + "os/exec" + "strings" + "testing" +) + +func GetLatestTag(t *testing.T) string { + // Prepare the command + cmd := exec.Command("git", "describe", "--tags", "--abbrev=0") + + // Capture the output + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &out + + // Run the command + err := cmd.Run() + require.NoError(t, err, out.String()) + return strings.TrimSpace(out.String()) +} + +func GetLatestVersion(t *testing.T) *semver.Version { + tag := GetLatestTag(t) + v, err := semver.NewVersion(tag) + require.NoError(t, err) + + return v +}