Skip to content

Commit

Permalink
Improve version checking system (#506)
Browse files Browse the repository at this point in the history
We were depending on a hardcoded value that does not automatically
update when we release.

Instead, determine the version internally.

Add more tests.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Introduced enhanced version validation to ensure authentication tokens
meet compatibility requirements.
  
- **Refactor**
- Consolidated version checking logic and improved error handling during
authentication processes.
  
- **Tests**
- Expanded and reorganized test suites to cover additional version
scenarios, including prerelease and future versions.
  
- **Chores**
- Updated testing utilities and restructured test packages to streamline
version retrieval and verification processes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
mkysel authored Feb 19, 2025
1 parent a1f9ba5 commit 21f814f
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 80 deletions.
34 changes: 23 additions & 11 deletions pkg/authn/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,36 @@ import (
"github.com/golang-jwt/jwt/v5"
)

const (
// XMTPD_COMPATIBLE_VERSION_CONSTRAINT major or minor serverVersion bumps indicate backwards incompatible changes
XMTPD_COMPATIBLE_VERSION_CONSTRAINT = "^0.2"
)

type XmtpdClaims struct {
Version *semver.Version `json:"version,omitempty"`
jwt.RegisteredClaims
}
type ClaimValidator struct {
constraint semver.Constraints
}

func ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error {
if claims.Version == nil {
return nil
func NewClaimValidator(serverVersion *semver.Version) (*ClaimValidator, error) {
if serverVersion == nil {
return nil, fmt.Errorf("serverVersion is nil")
}
sanitizedVersion, err := serverVersion.SetPrerelease("")
if err != nil {
return nil, err
}

c, err := semver.NewConstraint(XMTPD_COMPATIBLE_VERSION_CONSTRAINT)
// https://github.com/Masterminds/semver?tab=readme-ov-file#caret-range-comparisons-major
constraintStr := fmt.Sprintf("^%s", sanitizedVersion.String())

constraint, err := semver.NewConstraint(constraintStr)
if err != nil {
return err
return nil, err
}

return &ClaimValidator{constraint: *constraint}, nil
}
func (cv *ClaimValidator) ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error {
if claims.Version == nil {
return nil
}

// SemVer implementations generally do not consider pre-releases to be valid next releases
Expand All @@ -33,7 +45,7 @@ func ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error {
if err != nil {
return err
}
if ok := c.Check(&sanitizedVersion); !ok {
if ok := cv.constraint.Check(&sanitizedVersion); !ok {
return fmt.Errorf("serverVersion %s is not compatible", *claims.Version)
}

Expand Down
132 changes: 96 additions & 36 deletions pkg/authn/claims_test.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,14 @@
package authn
package authn_test

import (
"bytes"
"github.com/Masterminds/semver/v3"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/authn"
"github.com/xmtp/xmtpd/pkg/registry"
"github.com/xmtp/xmtpd/pkg/testutils"
"os/exec"
"strings"
"testing"
)

func getLatestTag(t *testing.T) string {
// Prepare the command
cmd := exec.Command("git", "describe", "--tags", "--abbrev=0")

// Capture the output
var out bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &out

// Run the command
err := cmd.Run()
require.NoError(t, err, out.String())
return strings.TrimSpace(out.String())
}

func getLatestVersion(t *testing.T) semver.Version {
tag := getLatestTag(t)
v, err := semver.NewVersion(tag)
require.NoError(t, err)

return *v
}

func newVersionNoError(t *testing.T, version string, pre string, meta string) semver.Version {
v, err := semver.NewVersion(version)
require.NoError(t, err)
Expand All @@ -47,7 +22,7 @@ func newVersionNoError(t *testing.T, version string, pre string, meta string) se
return vmeta
}

func TestClaimsVerifierNoVersion(t *testing.T) {
func TestClaimsNoVersion(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tests := []struct {
Expand All @@ -60,9 +35,17 @@ func TestClaimsVerifierNoVersion(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), tt.version)

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
tokenFactory := authn.NewTokenFactory(
signerPrivateKey,
uint32(SIGNER_NODE_ID),
tt.version,
)

verifier, nodeRegistry := buildVerifier(
t,
uint32(VERIFIER_NODE_ID),
testutils.GetLatestVersion(t),
)
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
SigningKey: &signerPrivateKey.PublicKey,
NodeID: uint32(SIGNER_NODE_ID),
Expand All @@ -80,10 +63,14 @@ func TestClaimsVerifierNoVersion(t *testing.T) {
}
}

func TestClaimsVerifier(t *testing.T) {
func TestClaimsVariousVersions(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

currentVersion := getLatestVersion(t)
currentVersion := *testutils.GetLatestVersion(t)
version013, err := semver.NewVersion("0.1.3")
require.NoError(t, err)
version014, err := semver.NewVersion("0.1.4")
require.NoError(t, err)

tests := []struct {
name string
Expand All @@ -94,7 +81,6 @@ func TestClaimsVerifier(t *testing.T) {
{"next-patch-version", currentVersion.IncPatch(), false},
{"next-minor-version", currentVersion.IncMinor(), true},
{"next-major-version", currentVersion.IncMajor(), true},
{"last-supported-version", newVersionNoError(t, currentVersion.String(), "", ""), false},
{
"with-prerelease-version",
newVersionNoError(t, currentVersion.String(), "17-gdeadbeef", ""),
Expand All @@ -105,13 +91,87 @@ func TestClaimsVerifier(t *testing.T) {
newVersionNoError(t, currentVersion.String(), "", "branch-dev"),
false,
},
{"known-0.1.3", *version013, true},
{"known-0.1.4", *version014, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenFactory := authn.NewTokenFactory(
signerPrivateKey,
uint32(SIGNER_NODE_ID),
&tt.version,
)

verifier, nodeRegistry := buildVerifier(
t,
uint32(VERIFIER_NODE_ID),
testutils.GetLatestVersion(t),
)
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
SigningKey: &signerPrivateKey.PublicKey,
NodeID: uint32(SIGNER_NODE_ID),
}, nil)

token, err := tokenFactory.CreateToken(uint32(VERIFIER_NODE_ID))
require.NoError(t, err)
_, verificationError := verifier.Verify(token.SignedString)
if tt.wantErr {
require.Error(t, verificationError)
} else {
require.NoError(t, verificationError)
}
})
}
}

func TestClaimsValidator(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

currentVersion := *testutils.GetLatestVersion(t)

tests := []struct {
name string
version semver.Version
serverVersion semver.Version
wantErr bool
}{
{"current-version", currentVersion, currentVersion, false},
{
"with-prerelease-version",
currentVersion,
newVersionNoError(t, currentVersion.String(), "17-gdeadbeef", ""),
false,
},
{
"with-metadata-version",
currentVersion,
newVersionNoError(t, currentVersion.String(), "", "branch-dev"),
false,
},
{
"future-major-rejects-us",
currentVersion,
currentVersion.IncMajor(),
true,
},
{
"future-patch-accepts-us",
currentVersion,
currentVersion.IncPatch(),
true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), &tt.version)
tokenFactory := authn.NewTokenFactory(
signerPrivateKey,
uint32(SIGNER_NODE_ID),
&tt.version,
)

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID), &tt.serverVersion)
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
SigningKey: &signerPrivateKey.PublicKey,
NodeID: uint32(SIGNER_NODE_ID),
Expand Down
11 changes: 6 additions & 5 deletions pkg/authn/signingMethod_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package authn
package authn_test

import (
"github.com/xmtp/xmtpd/pkg/authn"
"testing"

"github.com/golang-jwt/jwt/v5"
Expand All @@ -12,7 +13,7 @@ func TestSign(t *testing.T) {
privateKey := testutils.RandomPrivateKey(t)
publicKey := privateKey.Public()

method := &SigningMethodSecp256k1{}
method := &authn.SigningMethodSecp256k1{}

signingString := "test"
signature, err := method.Sign(signingString, privateKey)
Expand All @@ -28,7 +29,7 @@ func TestWrongSigner(t *testing.T) {
badPrivateKey := testutils.RandomPrivateKey(t)
badPublicKey := badPrivateKey.Public()

method := &SigningMethodSecp256k1{}
method := &authn.SigningMethodSecp256k1{}

signingString := "test"
signature, err := method.Sign(signingString, goodPrivateKey)
Expand All @@ -42,7 +43,7 @@ func TestWrongSigningString(t *testing.T) {
privateKey := testutils.RandomPrivateKey(t)
publicKey := privateKey.Public()

method := &SigningMethodSecp256k1{}
method := &authn.SigningMethodSecp256k1{}

signingString := "test"
signature, err := method.Sign(signingString, privateKey)
Expand All @@ -57,7 +58,7 @@ func TestFullJWT(t *testing.T) {
claims := &jwt.RegisteredClaims{
Issuer: "test",
}
token := jwt.NewWithClaims(&SigningMethodSecp256k1{}, claims)
token := jwt.NewWithClaims(&authn.SigningMethodSecp256k1{}, claims)

// Sign and get the complete encoded token as a string using the secret
tokenString, err := token.SignedString(privateKey)
Expand Down
7 changes: 4 additions & 3 deletions pkg/authn/tokenFactory_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package authn
package authn_test

import (
"github.com/Masterminds/semver/v3"
"github.com/xmtp/xmtpd/pkg/authn"
"testing"
"time"

Expand All @@ -11,7 +12,7 @@ import (

func TestTokenFactory(t *testing.T) {
privateKey := testutils.RandomPrivateKey(t)
factory := NewTokenFactory(privateKey, 100, nil)
factory := authn.NewTokenFactory(privateKey, 100, nil)

token, err := factory.CreateToken(200)
require.NoError(t, err)
Expand All @@ -38,7 +39,7 @@ func TestTokenFactoryWithVersion(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
version, err := semver.NewVersion(tt.version)
require.NoError(t, err)
factory := NewTokenFactory(privateKey, 100, version)
factory := authn.NewTokenFactory(privateKey, 100, version)

token, err := factory.CreateToken(200)
require.NoError(t, err)
Expand Down
21 changes: 16 additions & 5 deletions pkg/authn/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package authn

import (
"fmt"
"github.com/Masterminds/semver/v3"
"strconv"
"time"

Expand All @@ -15,16 +16,26 @@ const (
)

type RegistryVerifier struct {
registry registry.NodeRegistry
myNodeID uint32
registry registry.NodeRegistry
myNodeID uint32
validator ClaimValidator
}

/*
A RegistryVerifier connects to the NodeRegistry and verifies JWTs against the registered public keys
based on the JWT's subject field
*/
func NewRegistryVerifier(registry registry.NodeRegistry, myNodeID uint32) *RegistryVerifier {
return &RegistryVerifier{registry: registry, myNodeID: myNodeID}
func NewRegistryVerifier(
registry registry.NodeRegistry,
myNodeID uint32,
serverVersion *semver.Version,
) (*RegistryVerifier, error) {
validator, err := NewClaimValidator(serverVersion)
if err != nil {
return nil, err
}

return &RegistryVerifier{registry: registry, myNodeID: myNodeID, validator: *validator}, nil
}

func (v *RegistryVerifier) Verify(tokenString string) (uint32, error) {
Expand Down Expand Up @@ -104,7 +115,7 @@ func (v *RegistryVerifier) validateClaims(token *jwt.Token) error {
return fmt.Errorf("invalid token")
}

return ValidateVersionClaimIsCompatible(claims)
return v.validator.ValidateVersionClaimIsCompatible(claims)
}

// Parse the subject claim of the JWT and return the node ID as a uint32
Expand Down
Loading

0 comments on commit 21f814f

Please sign in to comment.