Skip to content

Commit 9dd3262

Browse files
committed
all: More refactoring, bug fixes, and added tests
1 parent 0046c34 commit 9dd3262

15 files changed

+424
-245
lines changed

README.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ import (
3535

3636
func init() {
3737
// Unless we want to keep the original RS256 implementation alive, override it (recommended)
38-
gcpjwt.OverrideRS256WithIAMJWT() // For signJwt
39-
40-
gcpjwt.OverrideRS256WithIAMBlob() // For signBlob
38+
gcpjwt.SigningMethodIAMJWT.Override() // For signJwt
39+
// OR
40+
gcpjwt.SigningMethodIAMBlob.Override() // For signBlob
4141
}
4242
```
4343

@@ -95,10 +95,10 @@ func validateToken(tokenString string) {
9595
}
9696
config.EnableCache = true // Enable certificates cache
9797

98-
// To Verify (if we called OverrideRS256WithIAMJWT() or OverrideRS256WithIAMBlob())
98+
// To Verify (if we called Override() for our method type prior)
9999
token, err := jwt.Parse(tokenString, gcpjwt.VerfiyKeyfunc(context.Background(), config))
100100

101-
// If we DID NOT call a OverrideRS256 function
101+
// If we DID NOT call the Override() function
102102
// This is basically copying the https://github.com/dgrijalva/jwt-go/blob/master/parser.go#L23 ParseWithClaims function here but forcing our own method vs getting one based on the Alg field
103103
// Or Try and parse, Ignore the result and try with the proper method:
104104
token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {

appengine.go

-103
This file was deleted.

certs.go

+24-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gcpjwt
22

33
import (
4+
"context"
45
"crypto/rsa"
56
"encoding/json"
67
"io/ioutil"
@@ -15,21 +16,28 @@ const (
1516
certificateURL = "https://www.googleapis.com/robot/v1/metadata/x509/"
1617
)
1718

18-
type certResponse struct {
19-
certs certificates
20-
expires time.Time
21-
}
22-
2319
// certificates is a map of key id -> public keys
2420
type certificates map[string]*rsa.PublicKey
2521

26-
func getCertificatesForAccount(hc *http.Client, account string) (*certResponse, error) {
27-
req, err := http.NewRequest(http.MethodGet, certificateURL+account, nil)
22+
func getCertificates(ctx context.Context, config *IAMConfig) (certificates, error) {
23+
if config.EnableCache {
24+
if certsResp, ok := getCertsFromCache(config.ServiceAccount); ok {
25+
return certsResp, nil
26+
}
27+
}
28+
29+
// Default config.Client is a http.DefaultClient
30+
client := config.Client
31+
if client == nil {
32+
client = getDefaultClient(ctx)
33+
}
34+
35+
req, err := http.NewRequest(http.MethodGet, certificateURL+config.ServiceAccount, nil)
2836
if err != nil {
2937
return nil, err
3038
}
3139

32-
resp, err := hc.Do(req)
40+
resp, err := client.Do(req)
3341
if err != nil {
3442
return nil, err
3543
}
@@ -41,19 +49,17 @@ func getCertificatesForAccount(hc *http.Client, account string) (*certResponse,
4149
}
4250

4351
certsRaw := make(map[string]string)
44-
4552
err = json.Unmarshal(b, &certsRaw)
4653
if err != nil {
4754
return nil, err
4855
}
49-
_, expires, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{PrivateCache: true})
5056

51-
if err != nil {
52-
return nil, err
57+
_, expires, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{PrivateCache: true})
58+
if err != nil && config.CacheExpiration > 0 {
59+
expires = time.Now().Add(config.CacheExpiration)
5360
}
5461

5562
certs := make(certificates)
56-
5763
for key, cert := range certsRaw {
5864
rsaKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
5965
if err != nil {
@@ -62,5 +68,9 @@ func getCertificatesForAccount(hc *http.Client, account string) (*certResponse,
6268
certs[key] = rsaKey
6369
}
6470

65-
return &certResponse{certs, expires}, err
71+
if config.EnableCache && !expires.IsZero() {
72+
updateCache(config.ServiceAccount, certs, expires)
73+
}
74+
75+
return certs, nil
6676
}

config.go

+28-11
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package gcpjwt
22

33
import (
44
"context"
5+
"crypto/md5"
56
"errors"
7+
"fmt"
68
"net/http"
9+
"time"
710

811
"golang.org/x/oauth2/google"
912
"google.golang.org/api/iam/v1"
@@ -37,16 +40,6 @@ type GCPConfig struct {
3740
// Client is a user provided *http.Client to use, http.DefaultClient is used otherwise (AppEngine URL Fetch Supported)
3841
// Used for verify requests
3942
Client *http.Client
40-
41-
// EnableCache will enable the in-memory caching of public certificates.
42-
// The cache will expire certificates when an expiration is known or provided and will refresh the cache if
43-
// it is unable to verify a signature from any of the certificates cached.
44-
EnableCache bool
45-
46-
// InjectKeyID will overwrite the provided header with one that contains the Key ID of the key used to sign the JWT.
47-
// Note that the IAM JWT signing method does this on its own and this is only applicable for the IAM Blob and Cloud KMS
48-
// signing methods. For CloudKMS, this will be a hash of the KeyPath configured.
49-
InjectKeyID bool
5043
}
5144

5245
// IAMConfig is relevant for both the signBlob and signJWT IAM API use-cases
@@ -57,21 +50,45 @@ type IAMConfig struct {
5750
// Service account can be the email address or the uniqueId of the service account used to sign the JWT with
5851
ServiceAccount string
5952

53+
// EnableCache will enable the in-memory caching of public certificates.
54+
// The cache will expire certificates when an expiration is known or fallback to the configured CacheExpiration
55+
EnableCache bool
56+
57+
// CacheExpiration is the default time to keep the certificates in cache if no expiration time is provided
58+
// Use a value of 0 to disable the expiration time fallback. Max reccomneded value is 24 hours.
59+
// https://cloud.google.com/iam/docs/understanding-service-accounts#managing_service_account_keys
60+
CacheExpiration time.Duration
61+
6062
// IAMType is a helper used to help clarify which IAM signing method this config is meant for.
6163
// Used for the jwtmiddleware and oauth2 packages.
6264
IAMType iamType
6365

66+
lastKeyID string
67+
6468
GCPConfig
6569
}
6670

71+
// KeyID will return the last used KeyID to sign the JWT - though it should be noted the signJwt method will always
72+
// add its own token header which is not parsed back to the token.
73+
// Helper function for adding the kid header to your token.
74+
func (i *IAMConfig) KeyID() string {
75+
return i.lastKeyID
76+
}
77+
6778
// KMSConfig is used to sign/verify JWTs with Google Cloud KMS
6879
type KMSConfig struct {
69-
// KeyPath is the name of the key to use in the format of '/projects/-/locations/...'
80+
// KeyPath is the name of the key to use in the format of:
81+
// "name=projects/*/locations/*/keyRings/*/cryptoKeys/*/cryptoKeyVersions/*"
7082
KeyPath string
7183

7284
GCPConfig
7385
}
7486

87+
// KeyID will return the MD5 hash of the configured KeyPath. Helper function for adding the kid header to your token.
88+
func (k *KMSConfig) KeyID() string {
89+
return fmt.Sprintf("%x", md5.Sum([]byte(k.KeyPath)))
90+
}
91+
7592
// NewIAMContext returns a new context.Context that carries a provided IAMConfig value
7693
func NewIAMContext(parent context.Context, val *IAMConfig) context.Context {
7794
return context.WithValue(parent, iamConfigKey{}, val)

0 commit comments

Comments
 (0)