Skip to content

Commit 811b7ec

Browse files
kennypKenny Parnell
and
Kenny Parnell
authored
Add ability to override where tokens are read from. (#221)
* Add ability to override where tokens are read from. * Allows for testing code that reads the token. * Fix lint errors. --------- Co-authored-by: Kenny Parnell <[email protected]>
1 parent 4958255 commit 811b7ec

File tree

5 files changed

+303
-23
lines changed

5 files changed

+303
-23
lines changed

dynoid/dynoid.go

+20-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/base64"
66
"encoding/json"
77
"fmt"
8+
"io/fs"
89
"log/slog"
910
"os"
1011
"path"
@@ -90,7 +91,7 @@ func (s *Subject) MarshalText() ([]byte, error) {
9091

9192
func (s *Subject) UnmarshalText(text []byte) error {
9293
if s == nil {
93-
*s = Subject{}
94+
return fmt.Errorf("cannot unmarshal to a nil pointer")
9495
}
9596

9697
sub := string(text)
@@ -142,11 +143,28 @@ func LocalTokenPath(audience string) string {
142143
return fmt.Sprintf("/etc/heroku/dyno-id/%s/token", audience)
143144
}
144145

146+
type osReader struct{}
147+
148+
func (*osReader) Open(name string) (fs.File, error) {
149+
return os.Open(name)
150+
}
151+
152+
func (*osReader) ReadFile(name string) ([]byte, error) {
153+
return os.ReadFile(name)
154+
}
155+
156+
// DefaultFS is used by [ReadLocal] and [ReadLocalToken] to retrieve tokens.
157+
//
158+
// By default they are retrieved via [os.Open] and [os.ReadFile].
159+
//
160+
// This is useful when testing code that uses DynoID.
161+
var DefaultFS fs.ReadFileFS = &osReader{}
162+
145163
// ReadLocal reads the local machines token for the given audience
146164
//
147165
// Suitable for passing as a bearer token
148166
func ReadLocal(audience string) (string, error) {
149-
rawToken, err := os.ReadFile(LocalTokenPath(audience))
167+
rawToken, err := DefaultFS.ReadFile(LocalTokenPath(audience))
150168
if err != nil {
151169
return "", err
152170
}

dynoid/dynoid_test.go

+79-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,90 @@
1-
package dynoid
1+
package dynoid_test
22

33
import (
44
"context"
55
"testing"
66

7+
"github.com/google/uuid"
8+
9+
"github.com/heroku/x/dynoid"
710
"github.com/heroku/x/dynoid/dynoidtest"
811
)
912

1013
func TestVerification(t *testing.T) {
11-
ctx, iss, err := dynoidtest.NewWithContext(context.Background())
14+
ctx, token := GenerateIDToken(t, "heroku")
15+
16+
verifier := dynoid.NewWithCallback("heroku", dynoid.AllowHerokuHost(dynoidtest.IssuerHost))
17+
18+
if _, err := verifier.Verify(ctx, token); err != nil {
19+
t.Error(err)
20+
}
21+
}
22+
23+
func TestMarshalUnmarshal(t *testing.T) {
24+
in := dynoid.Subject{
25+
AppID: "7eeecc9f-b17f-4027-9aa1-ceb8427036c6",
26+
AppName: "testing",
27+
Dyno: "web.1",
28+
}
29+
30+
var out dynoid.Subject
31+
32+
if err := out.UnmarshalText([]byte(in.String())); err != nil {
33+
t.Fatalf("failed to unmarshal (%v)", err)
34+
}
35+
36+
if out.AppID != in.AppID {
37+
t.Fatalf("AppID missmatch (%q != %q)", out.AppID, in.AppID)
38+
}
39+
40+
if out.AppName != in.AppName {
41+
t.Fatalf("AppName missmatch (%q != %q)", out.AppName, in.AppName)
42+
}
43+
44+
if out.Dyno != in.Dyno {
45+
t.Fatalf("Dyno missmatch (%q != %q)", out.Dyno, in.Dyno)
46+
}
47+
}
48+
49+
func TestReading(t *testing.T) {
50+
oldFS := dynoid.DefaultFS
51+
defer func() {
52+
dynoid.DefaultFS = oldFS
53+
}()
54+
55+
spaceID := uuid.NewString()
56+
appID := uuid.NewString()
57+
58+
ctx, tk := GenerateIDToken(t, "heroku",
59+
dynoidtest.WithSpaceID(spaceID),
60+
dynoidtest.WithTokenOpts(dynoidtest.WithSubject(&dynoid.Subject{
61+
AppID: appID,
62+
AppName: "testapp",
63+
Dyno: "run.1",
64+
})),
65+
)
66+
dynoid.DefaultFS = dynoidtest.NewFS(map[string]string{
67+
"heroku": tk,
68+
})
69+
70+
token, err := dynoid.ReadLocalToken(ctx, "heroku")
71+
if err != nil {
72+
t.Fatalf("failed to read token (%v)", err)
73+
}
74+
75+
if token.SpaceID != spaceID {
76+
t.Fatalf("Unexpected SpaceID %q", token.SpaceID)
77+
}
78+
79+
if token.Subject.AppID != appID {
80+
t.Fatalf("Unexpected AppID %q", token.Subject.AppID)
81+
}
82+
}
83+
84+
func GenerateIDToken(t *testing.T, audience string, opts ...dynoidtest.IssuerOpt) (context.Context, string) {
85+
t.Helper()
86+
87+
ctx, iss, err := dynoidtest.NewWithContext(context.Background(), opts...)
1288
if err != nil {
1389
t.Fatal(err)
1490
}
@@ -18,9 +94,5 @@ func TestVerification(t *testing.T) {
1894
t.Fatal(err)
1995
}
2096

21-
verifier := NewWithCallback("heroku", AllowHerokuHost(dynoidtest.IssuerHost))
22-
23-
if _, err = verifier.Verify(ctx, token); err != nil {
24-
t.Error(err)
25-
}
97+
return ctx, token
2698
}

dynoid/dynoidtest/dynoidtest.go

+89-13
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
// dynoidtest provides helper functions for testing code that uses DynoID
12
package dynoidtest
23

34
import (
45
"context"
56
"crypto/rand"
67
"crypto/rsa"
78
"encoding/json"
9+
"fmt"
810
"net/http"
911
"net/http/httptest"
1012
"strings"
@@ -14,43 +16,115 @@ import (
1416
"github.com/coreos/go-oidc/v3/oidc"
1517
"github.com/golang-jwt/jwt/v4"
1618
jose "gopkg.in/square/go-jose.v2"
19+
20+
"github.com/heroku/x/dynoid"
1721
)
1822

1923
const (
20-
Audience = "heroku"
24+
// IssuerHost is the host used by the dynoidtest.Issuer
2125
IssuerHost = "heroku.local"
26+
27+
DefaultSpaceID = "test" // space id used when one is not provided
28+
DefaultAppID = "00000000-0000-0000-0000-000000000001" // app id used when one is not provided
29+
DefaultAppName = "sushi" // app name used when one is not provided
30+
DefaultDyno = "web.1" // dyno used when one is not provided
2231
)
2332

33+
// Issuer generates test tokens and provides a client for verifying them.
2434
type Issuer struct {
25-
key *rsa.PrivateKey
35+
key *rsa.PrivateKey
36+
spaceID string
37+
tokenOpts []TokenOpt
38+
}
39+
40+
// IssuerOpt allows the behavior of the issuer to be modified.
41+
type IssuerOpt interface {
42+
apply(*Issuer) error
43+
}
44+
45+
type issuerOptFunc func(*Issuer) error
46+
47+
func (f issuerOptFunc) apply(i *Issuer) error {
48+
return f(i)
2649
}
2750

28-
func New() (*Issuer, error) {
29-
_, iss, err := NewWithContext(context.Background())
51+
// WithSpaceID allows a spaceID to be supplied instead of using the default
52+
func WithSpaceID(spaceID string) IssuerOpt {
53+
return issuerOptFunc(func(i *Issuer) error {
54+
i.spaceID = spaceID
55+
return nil
56+
})
57+
}
58+
59+
// WithTokenOpts allows a default set of TokenOpt to be applied to every token
60+
// generated by the issuer
61+
func WithTokenOpts(opts ...TokenOpt) IssuerOpt {
62+
return issuerOptFunc(func(i *Issuer) error {
63+
i.tokenOpts = append(i.tokenOpts, opts...)
64+
return nil
65+
})
66+
}
67+
68+
// Create a new Issuer with the supplied opts applied
69+
func New(opts ...IssuerOpt) (*Issuer, error) {
70+
_, iss, err := NewWithContext(context.Background(), opts...)
3071
return iss, err
3172
}
3273

33-
func NewWithContext(ctx context.Context) (context.Context, *Issuer, error) {
74+
// Create a new Issuer with the supplied opts applied inheriting from the provided context
75+
func NewWithContext(ctx context.Context, opts ...IssuerOpt) (context.Context, *Issuer, error) {
3476
key, err := rsa.GenerateKey(rand.Reader, 2048)
3577
if err != nil {
3678
return ctx, nil, err
3779
}
3880

39-
iss := &Issuer{key: key}
81+
iss := &Issuer{key: key, spaceID: DefaultSpaceID, tokenOpts: []TokenOpt{}}
82+
for _, o := range opts {
83+
if err := o.apply(iss); err != nil {
84+
return ctx, nil, err
85+
}
86+
}
87+
4088
ctx = oidc.ClientContext(ctx, iss.HTTPClient())
4189

4290
return ctx, iss, nil
4391
}
4492

45-
func (iss *Issuer) GenerateIDToken(clientID string) (string, error) {
93+
// A TokenOpt modifies the way a token is minted
94+
type TokenOpt interface {
95+
apply(*jwt.RegisteredClaims) error
96+
}
97+
98+
type tokenOptFunc func(*jwt.RegisteredClaims) error
99+
100+
func (f tokenOptFunc) apply(i *jwt.RegisteredClaims) error {
101+
return f(i)
102+
}
103+
104+
// WithSubject allows the Subject to be different than the default
105+
func WithSubject(s *dynoid.Subject) TokenOpt {
106+
return tokenOptFunc(func(c *jwt.RegisteredClaims) error {
107+
c.Subject = s.String()
108+
return nil
109+
})
110+
}
111+
112+
// GenerateIDToken returns a new signed token as a string
113+
func (iss *Issuer) GenerateIDToken(clientID string, opts ...TokenOpt) (string, error) {
46114
now := time.Now()
47115

48116
claims := &jwt.RegisteredClaims{
49117
Audience: jwt.ClaimStrings([]string{clientID}),
50118
ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
51119
IssuedAt: jwt.NewNumericDate(now),
52-
Issuer: "https://oidc.heroku.local/issuers/test",
53-
Subject: "app:00000000-0000-0000-0000-000000000001.sushi::dyno:web.1",
120+
Issuer: fmt.Sprintf("https://oidc.heroku.local/issuers/%s", iss.spaceID),
121+
Subject: (&dynoid.Subject{AppID: DefaultAppID, AppName: DefaultAppName, Dyno: DefaultDyno}).String(),
122+
}
123+
124+
for _, o := range append(iss.tokenOpts, opts...) {
125+
if err := o.apply(claims); err != nil {
126+
return "", err
127+
}
54128
}
55129

56130
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
@@ -59,6 +133,7 @@ func (iss *Issuer) GenerateIDToken(clientID string) (string, error) {
59133
return token.SignedString(iss.key)
60134
}
61135

136+
// HTTPClient returns a client that leverages the Issuer to validate tokens.
62137
func (iss *Issuer) HTTPClient() *http.Client {
63138
return &http.Client{Transport: &roundTripper{issuer: iss}}
64139
}
@@ -72,7 +147,8 @@ type roundTripper struct {
72147
func (rt *roundTripper) init() {
73148
mux := http.NewServeMux()
74149

75-
mux.HandleFunc("/issuers/test/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
150+
basePath := fmt.Sprintf("/issuers/%s/.well-known", rt.issuer.spaceID)
151+
mux.HandleFunc(basePath+"/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
76152
if !strings.EqualFold(r.Method, http.MethodGet) {
77153
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
78154
return
@@ -84,17 +160,17 @@ func (rt *roundTripper) init() {
84160
w.WriteHeader(http.StatusOK)
85161

86162
_, _ = w.Write([]byte(`{` +
87-
`"issuer":"https://oidc.heroku.local/issuers/test",` +
163+
fmt.Sprintf(`"issuer":"https://oidc.heroku.local/issuers/%s",`, rt.issuer.spaceID) +
88164
`"authorization_endpoint":"/dummy/authorization",` +
89-
`"jwks_uri":"https://oidc.heroku.local/issuers/test/.well-known/jwks.json",` +
165+
fmt.Sprintf(`"jwks_uri":"https://oidc.heroku.local/issuers/%s/.well-known/jwks.json",`, rt.issuer.spaceID) +
90166
`"response_types_supported":["implicit"],` +
91167
`"grant_types_supported":["implicit"],` +
92168
`"subject_types_supported":["public"],` +
93169
`"id_token_signing_alg_values_supported":["RS256"]` +
94170
`}`))
95171
})
96172

97-
mux.HandleFunc("/issuers/test/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
173+
mux.HandleFunc(basePath+"/jwks.json", func(w http.ResponseWriter, r *http.Request) {
98174
if !strings.EqualFold(r.Method, http.MethodGet) {
99175
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
100176
return

0 commit comments

Comments
 (0)