@@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
180
180
181
181
self ._keys = []
182
182
self .remote = False
183
+ self .local = False
183
184
self .cache_time = cache_time
184
185
self .time_out = 0
185
186
self .etag = ""
@@ -189,7 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
189
190
self .keyusage = keyusage
190
191
self .imp_jwks = None
191
192
self .last_updated = 0
192
- self .last_remote = None
193
+ self .last_remote = None # HTTP Date of last remote update
194
+ self .last_local = None # UNIX timestamp of last local update
193
195
194
196
if httpc :
195
197
self .httpc = httpc
@@ -209,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
209
211
self .do_keys (keys )
210
212
else :
211
213
self ._set_source (source , fileformat )
212
-
213
- if not self .remote and self .source : # local file
214
+ if self .local :
214
215
self ._do_local (kid )
215
216
216
217
def _set_source (self , source , fileformat ):
217
218
if source .startswith ("file://" ):
218
219
self .source = source [7 :]
220
+ self .local = True
219
221
elif source .startswith ("http://" ) or source .startswith ("https://" ):
220
222
self .source = source
221
223
self .remote = True
@@ -225,6 +227,7 @@ def _set_source(self, source, fileformat):
225
227
if fileformat .lower () in ['rsa' , 'der' , 'jwks' ]:
226
228
if os .path .isfile (source ):
227
229
self .source = source
230
+ self .local = True
228
231
else :
229
232
raise ImportError ('No such file' )
230
233
else :
@@ -236,6 +239,16 @@ def _do_local(self, kid):
236
239
elif self .fileformat == "der" :
237
240
self .do_local_der (self .source , self .keytype , self .keyusage , kid )
238
241
242
+ def _local_update_required (self ) -> bool :
243
+ stat = os .stat (self .source )
244
+ if self .last_local and stat .st_mtime < self .last_local :
245
+ LOGGER .debug ("%s not modfied" , self .source )
246
+ return False
247
+ else :
248
+ LOGGER .debug ("%s modfied" , self .source )
249
+ self .last_local = stat .st_mtime
250
+ return True
251
+
239
252
def do_keys (self , keys ):
240
253
"""
241
254
Go from JWK description to binary keys
@@ -291,12 +304,15 @@ def do_local_jwk(self, filename):
291
304
292
305
:param filename: Name of the file from which the JWKS should be loaded
293
306
"""
307
+ LOGGER .debug ("Reading JWKS from %s" , filename )
294
308
with open (filename ) as input_file :
295
309
_info = json .load (input_file )
296
310
if 'keys' in _info :
297
311
self .do_keys (_info ["keys" ])
298
312
else :
299
313
self .do_keys ([_info ])
314
+ self .last_local = time .time ()
315
+ self .time_out = self .last_local + self .cache_time
300
316
301
317
def do_local_der (self , filename , keytype , keyusage = None , kid = '' ):
302
318
"""
@@ -306,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
306
322
:param keytype: Presently 'rsa' and 'ec' supported
307
323
:param keyusage: encryption ('enc') or signing ('sig') or both
308
324
"""
325
+ LOGGER .debug ("Reading DER from %s" , filename )
309
326
key_args = {}
310
327
_kty = keytype .lower ()
311
328
if _kty in ['rsa' , 'ec' ]:
@@ -325,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
325
342
key_args ['kid' ] = kid
326
343
327
344
self .do_keys ([key_args ])
345
+ self .last_local = time .time ()
346
+ self .time_out = self .last_local + self .cache_time
328
347
329
348
def do_remote (self ):
330
349
"""
@@ -400,14 +419,12 @@ def _parse_remote_response(self, response):
400
419
401
420
def _uptodate (self ):
402
421
res = False
403
- if not self ._keys :
404
- if self .remote : # verify that it's not to old
405
- if time .time () > self .time_out :
406
- if self .update ():
407
- res = True
408
- elif self .remote :
409
- if self .update ():
410
- res = True
422
+ if self .remote or self .local :
423
+ if time .time () > self .time_out :
424
+ if self .local and not self ._local_update_required ():
425
+ res = True
426
+ elif self .update ():
427
+ res = True
411
428
return res
412
429
413
430
def update (self ):
@@ -425,13 +442,13 @@ def update(self):
425
442
self ._keys = []
426
443
427
444
try :
428
- if self .remote is False :
445
+ if self .local :
429
446
if self .fileformat in ["jwks" , "jwk" ]:
430
447
self .do_local_jwk (self .source )
431
448
elif self .fileformat == "der" :
432
449
self .do_local_der (self .source , self .keytype ,
433
450
self .keyusage )
434
- else :
451
+ elif self . remote :
435
452
res = self .do_remote ()
436
453
except Exception as err :
437
454
LOGGER .error ('Key bundle update failed: %s' , err )
@@ -674,8 +691,11 @@ def dump(self):
674
691
"keys" : _keys ,
675
692
"fileformat" : self .fileformat ,
676
693
"last_updated" : self .last_updated ,
694
+ "last_remote" : self .last_remote ,
695
+ "last_local" : self .last_local ,
677
696
"httpc_params" : self .httpc_params ,
678
697
"remote" : self .remote ,
698
+ "local" : self .local ,
679
699
"imp_jwks" : self .imp_jwks ,
680
700
"time_out" : self .time_out ,
681
701
"cache_time" : self .cache_time
@@ -693,7 +713,10 @@ def load(self, spec):
693
713
self .source = spec .get ("source" , None )
694
714
self .fileformat = spec .get ("fileformat" , "jwks" )
695
715
self .last_updated = spec .get ("last_updated" , 0 )
716
+ self .last_remote = spec .get ("last_remote" , None )
717
+ self .last_local = spec .get ("last_local" , None )
696
718
self .remote = spec .get ("remote" , False )
719
+ self .local = spec .get ("local" , False )
697
720
self .imp_jwks = spec .get ('imp_jwks' , None )
698
721
self .time_out = spec .get ('time_out' , 0 )
699
722
self .cache_time = spec .get ('cache_time' , 0 )
0 commit comments