@@ -131,6 +131,7 @@ def reload(self):
131
131
self ._auth_cache = {}
132
132
self ._init_backends ()
133
133
134
+
134
135
def get_auth_bearer_token (self , bearer_token , authoritative_source , auth_options = None ):
135
136
""" Returns an authentication object.
136
137
@@ -146,17 +147,21 @@ def get_auth_bearer_token(self, bearer_token, authoritative_source, auth_options
146
147
if auth_options is None :
147
148
auth_options = {}
148
149
150
+ # validate arguments
151
+ if authoritative_source is None :
152
+ raise AuthError ("Missing authoritative_source." )
153
+
149
154
backend = "jwt"
150
155
self ._logger .debug ("Using auth backend %s" % backend )
151
156
# Create auth object
152
157
try :
153
158
auth = self ._backends [backend ](backend , bearer_token , authoritative_source , auth_options )
154
- except Exception :
155
- traceback .print_exc ()
159
+ except KeyError :
156
160
raise AuthError ("Invalid auth backend '%s' specified" % backend )
157
161
158
162
return auth
159
163
164
+
160
165
def get_auth (self , username , password , authoritative_source , auth_options = None ):
161
166
""" Returns an authentication object.
162
167
@@ -299,6 +304,7 @@ class JwtAuth(BaseAuth):
299
304
_jwt_rw_group = None
300
305
_jwt_ro_group = None
301
306
_authenticated = None
307
+ _jwks_client = None
302
308
303
309
def __init__ (self , name , jwt_token , authoritative_source ,
304
310
auth_options = None ):
@@ -338,6 +344,14 @@ def __init__(self, name, jwt_token, authoritative_source,
338
344
self ._logger .error ('Unable to load Python jwt module, please verify it is installed' )
339
345
raise AuthError ('Unable to authenticate' )
340
346
347
+ # Set up JWK client as class variable
348
+ if self ._jwks_client is None :
349
+ jwk_url = self ._cfg .get (base_auth_backend , 'jwk_url' )
350
+ if jwk_url is None :
351
+ self ._logger .error ("Missing jwk_url in config" )
352
+ raise AuthError ("Authentication error" )
353
+ JwtAuth ._jwks_client = jwt .PyJWKClient (jwk_url )
354
+
341
355
# Decode token
342
356
try :
343
357
payload = jwt .decode (
@@ -347,6 +361,7 @@ def __init__(self, name, jwt_token, authoritative_source,
347
361
except jwt .exceptions .DecodeError :
348
362
raise AuthError ('Failed to decode JWT token' )
349
363
364
+
350
365
@create_span_authenticate
351
366
def authenticate (self ):
352
367
""" Verify authentication.
@@ -360,26 +375,12 @@ def authenticate(self):
360
375
return self ._authenticated
361
376
362
377
try :
363
- self ._token = self ._cfg .get ('auth.backends.' +
364
- self .auth_backend , 'jwk_url' )
365
- # Fetch JWKs (done when initializing JwtAuth-class),
366
- # keep the keys in the class instance
367
- jwk_request_response = requests .get (self ._token )
368
- jwks = jwk_request_response .json ()
369
- jwk_keys = {}
370
- for jwk in jwks ['keys' ]:
371
- kid = jwk ['kid' ]
372
- jwk_keys [kid ] = jwt .algorithms .RSAAlgorithm .from_jwk (json .dumps (jwk ))
373
-
374
- # Upon auth with a JWT-token
375
- # Retrieve key for token
376
- jwt_headers = jwt .get_unverified_header (self ._jwt_token )
377
- jwt_jwk_key = jwk_keys [jwt_headers ['kid' ]]
378
-
379
378
# Decode and verify token
379
+ jwt_headers = jwt .get_unverified_header (self ._jwt_token )
380
+ signing_key = self ._jwks_client .get_signing_key_from_jwt (self ._jwt_token )
380
381
payload = jwt .decode (
381
382
self ._jwt_token ,
382
- key = jwt_jwk_key ,
383
+ key = signing_key . key ,
383
384
algorithms = [jwt_headers ['alg' ]],
384
385
options = {"verify_aud" : False })
385
386
0 commit comments