-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtokenvalidationmiddleware.go
240 lines (202 loc) · 5.91 KB
/
tokenvalidationmiddleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
package tokenvalidationmiddleware
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"strings"
"github.com/dgrijalva/jwt-go"
)
// TokenValidationMiddleware holds the configuration used when validating
// bearer tokens on incoming requests.
type TokenValidationMiddleware struct {
// Options carries the settings read by ValidateBearerToken: the key
// function handed to the JWT parser and the optional expected signing method.
Options Options
}
// New returns a middleware whose Options are left at their zero value.
// Use SetOptions or NewWithOptions to supply a real configuration.
func New() *TokenValidationMiddleware {
	var defaults Options
	return &TokenValidationMiddleware{Options: defaults}
}
// NewWithOptions returns a middleware configured with a copy of opts.
func NewWithOptions(opts Options) *TokenValidationMiddleware {
	m := new(TokenValidationMiddleware)
	m.Options = opts
	return m
}
// contains reports whether in occurs in s. Only elements whose dynamic type
// is string can match; elements of any other type are skipped.
func contains(s []interface{}, in string) bool {
	for i := 0; i < len(s); i++ {
		if str, ok := s[i].(string); ok && str == in {
			return true
		}
	}
	return false
}
// GetDefaultValidator returns a key function for the JWT parser that
// validates the token's "aud" and "iss" claims against options and then
// resolves the RSA public key for the token from the issuer's published keys.
//
// options must be non-nil; it is read each time the returned function runs.
// The audience is checked only when options.VerifyAudience is set, and the
// issuer only when options.VerifyIssuer is set. When audience verification is
// enabled, a token whose "aud" claim is missing or of an unsupported type is
// rejected.
//
// On success the returned function yields the *rsa.PublicKey. On failure it
// returns a non-nil error; the first return value is then the token itself
// (kept for backward compatibility) and must not be used as a key.
func GetDefaultValidator(options *Options) func(token *jwt.Token) (interface{}, error) {
	return func(token *jwt.Token) (interface{}, error) {
		// Checked assertion: a token carrying another Claims implementation
		// is rejected instead of causing a panic.
		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			return token, errors.New("Invalid claims")
		}

		// Verify 'aud' claim
		if options.VerifyAudience {
			// audOK stays false when the claim is absent or has an
			// unexpected type, so such tokens no longer slip through.
			audOK := false
			switch tokenAud := claims["aud"].(type) {
			case string:
				audOK = claims.VerifyAudience(options.Audience, true)
			case []interface{}:
				audOK = contains(tokenAud, options.Audience)
			case []string:
				for _, a := range tokenAud {
					if a == options.Audience {
						audOK = true
						break
					}
				}
			}
			if !audOK {
				return token, errors.New("Invalid audience")
			}
		}

		// Verify 'iss' claim
		if options.VerifyIssuer && !claims.VerifyIssuer(options.Issuer, true) {
			return token, errors.New("Invalid issuer")
		}

		cert, err := getPemCert(token, options.Issuer)
		if err != nil {
			// Keep the original message as a prefix but preserve the cause.
			return token, fmt.Errorf("Unable to get signing key: %w", err)
		}
		return cert, nil
	}
}
// NewRSA256Validator returns a middleware that expects RS256-signed bearer
// tokens and uses the key function from GetDefaultValidator(options) to
// validate claims and look up the signing key.
func NewRSA256Validator(options *Options) *TokenValidationMiddleware {
	rsaOpts := Options{
		ValidationKeyFunc: GetDefaultValidator(options),
		SigningMethod:     jwt.SigningMethodRS256,
	}
	return New().SetOptions(rsaOpts)
}
// SetOptions replaces the middleware's Options with the given value and
// returns the same middleware so that calls can be chained (as
// NewRSA256Validator does).
func (m *TokenValidationMiddleware) SetOptions(options Options) *TokenValidationMiddleware {
m.Options = options
return m
}
// ValidateBearerToken validates the bearer token carried by r.
//
// The token is taken from the Authorization header, parsed with
// m.Options.ValidationKeyFunc and, when m.Options.SigningMethod is set, its
// "alg" header is compared with that method. On success the parsed token is
// stored in the request context under the key "user", the request pointed to
// by r is overwritten with the updated copy, and (true, nil) is returned.
// Any failure yields (false, err).
func (m *TokenValidationMiddleware) ValidateBearerToken(r *http.Request) (bool, error) {
	// Extract bearer token from auth header
	raw, err := GetBearerToken(r)
	if err != nil {
		return false, fmt.Errorf("Error extracting token: %w", err)
	}

	parsed, err := jwt.Parse(raw, m.Options.ValidationKeyFunc)
	if err != nil {
		return false, fmt.Errorf("Error parsing token: %w", err)
	}

	if method := m.Options.SigningMethod; method != nil {
		gotAlg := parsed.Header["alg"]
		if method.Alg() != gotAlg {
			message := fmt.Sprintf("Expected %s signing method but token specified %s",
				method.Alg(),
				gotAlg)
			return false, fmt.Errorf("Error validating token algorithm: %s", message)
		}
	}

	// Expose the parsed token to downstream handlers through the request context.
	ctx := context.WithValue(r.Context(), "user", parsed)
	*r = *r.WithContext(ctx)
	return true, nil
}
// GetBearerToken extracts the bearer token from the request's Authorization
// header.
//
// The header must consist of the scheme "Bearer" (matched case-insensitively)
// followed by exactly one non-empty token. Any amount of whitespace between
// and around the two parts is tolerated. An error is returned for a missing
// header, a different scheme, an empty token, or extra fields.
func GetBearerToken(r *http.Request) (string, error) {
	// Fields (rather than Split on a single space) drops empty parts, so
	// "Bearer " no longer yields an empty token and repeated separators
	// between scheme and token are accepted.
	parts := strings.Fields(r.Header.Get("Authorization"))
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
		return "", errors.New("Authorization header format must be Bearer {token}")
	}
	return parts[1], nil
}
// ValidateScope reports whether the space-separated scope claim of
// tokenString contains scope.
//
// It returns false when the token cannot be parsed into *Claims.
//
// NOTE(review): the key function passed to the parser is nil, so the token's
// signature is presumably not verified here. Confirm that callers validate
// the token (for example with ValidateBearerToken) before trusting the result.
func ValidateScope(scope string, tokenString string) bool {
	// The parse error is deliberately ignored: only the decoded claims are
	// needed, and their presence is checked explicitly below.
	token, _ := jwt.ParseWithClaims(tokenString, &Claims{}, nil)
	if token == nil {
		return false
	}

	// The parser can return a non-nil token whose Claims were never set (for
	// example when the header segment is malformed). Guard against that
	// instead of dereferencing a nil *Claims.
	claims, ok := token.Claims.(*Claims)
	if !ok || claims == nil {
		return false
	}

	for _, granted := range strings.Split(claims.Scope, " ") {
		if granted == scope {
			return true
		}
	}
	return false
}
// RequestHasScope reports whether the bearer token in the request's
// Authorization header carries the given scope. It returns false when no
// bearer token can be extracted from the header.
func RequestHasScope(scope string, r *http.Request) bool {
	if bearer, err := GetBearerToken(r); err == nil {
		return ValidateScope(scope, bearer)
	}
	return false
}
// httpClient is the shared client used to fetch the OpenID configuration and
// JWKS documents. It is a zero-value http.Client, so no request timeout is
// configured.
// NOTE(review): consider setting a Timeout so a slow identity provider cannot
// stall token validation indefinitely.
var httpClient = new(http.Client)
// getPemCert resolves the RSA public key that signed token.
//
// It fetches the issuer's "/.well-known/openid-configuration" document,
// follows its JWKS URI, selects the key whose "kid" equals the token's "kid"
// header and builds an *rsa.PublicKey from that key's modulus and exponent.
//
// An error is returned when either HTTP request fails or does not answer
// 200 OK, when a response cannot be decoded, when no key matches, or when
// the key material cannot be converted.
func getPemCert(token *jwt.Token, issuer string) (*rsa.PublicKey, error) {
	wke := fmt.Sprintf("%s/.well-known/openid-configuration", strings.TrimSuffix(issuer, "/"))
	resp, err := httpClient.Get(wke)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching OpenID configuration from %s: unexpected status %s", wke, resp.Status)
	}

	// read the payload
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var openIDConfig OpenIDConfig
	if err := json.Unmarshal(body, &openIDConfig); err != nil {
		return nil, err
	}

	jwksResp, err := httpClient.Get(openIDConfig.JwksURI)
	if err != nil {
		return nil, err
	}
	// The body must be closed on every path or the connection is leaked.
	defer jwksResp.Body.Close()
	if jwksResp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching JWKS from %s: unexpected status %s", openIDConfig.JwksURI, jwksResp.Status)
	}

	var jwks Jwks
	if err := json.NewDecoder(jwksResp.Body).Decode(&jwks); err != nil {
		return nil, err
	}

	// If several keys share the token's kid, the last one wins (unchanged
	// from the previous behavior).
	var nStr, eStr string
	for k := range jwks.Keys {
		if token.Header["kid"] == jwks.Keys[k].Kid {
			nStr = jwks.Keys[k].N
			eStr = jwks.Keys[k].E
		}
	}
	if nStr == "" || eStr == "" {
		return nil, errors.New("Unable to find appropriate key")
	}

	return genXC5(nStr, eStr)
}
// genXC5 builds an RSA public key from a JWK's modulus and public exponent.
//
// nStr and eStr are the unpadded base64url encodings of the big-endian
// modulus and exponent. The exponent is decoded rather than assumed, so any
// value in the range [2, 2^31-1] is accepted, not only the common 65537.
//
// An error is returned when either value is not valid base64url, when the
// modulus is empty, or when the exponent falls outside that range.
func genXC5(nStr string, eStr string) (*rsa.PublicKey, error) {
	// decode the base64 bytes for n
	nb, err := base64.RawURLEncoding.DecodeString(nStr)
	if err != nil {
		return nil, fmt.Errorf("decoding modulus: %w", err)
	}
	if len(nb) == 0 {
		return nil, fmt.Errorf("modulus is empty")
	}

	eb, err := base64.RawURLEncoding.DecodeString(eStr)
	if err != nil {
		return nil, fmt.Errorf("decoding exponent: %w", err)
	}

	// Accumulate the big-endian bytes into an int. The bound is checked
	// before each shift so the value can never exceed 2^31-1, which keeps
	// the arithmetic safe even where int is 32 bits wide.
	const maxExponent = 1<<31 - 1
	e := 0
	for _, b := range eb {
		if e > maxExponent>>8 {
			return nil, fmt.Errorf("exponent %q is too large", eStr)
		}
		e = e<<8 | int(b)
	}
	if e < 2 {
		return nil, fmt.Errorf("exponent %q is too small", eStr)
	}

	return &rsa.PublicKey{
		N: new(big.Int).SetBytes(nb),
		E: e,
	}, nil
}