3
3
import json
4
4
import logging
5
5
import os
6
- import threading
7
6
import time
8
7
from datetime import datetime
9
8
from functools import cmp_to_key
10
9
from typing import List
11
10
from typing import Optional
12
11
13
12
import requests
13
+ from readerwriterlock import rwlock
14
14
15
15
from cryptojwt .jwk .ec import NIST2SEC
16
16
from cryptojwt .jwk .hmac import new_sym_key
47
47
48
48
MAP = {"dec" : "enc" , "enc" : "enc" , "ver" : "sig" , "sig" : "sig" }
49
49
50
- update_lock = threading .Lock ()
51
-
52
50
53
51
def harmonize_usage (use ):
54
52
"""
@@ -153,6 +151,14 @@ def ec_init(spec):
153
151
return _kb
154
152
155
153
154
+ def keys_writer (func ):
155
+ def wrapper (self , * args , ** kwargs ):
156
+ with self ._lock_writer :
157
+ return func (self , * args , ** kwargs )
158
+
159
+ return wrapper
160
+
161
+
156
162
class KeyBundle :
157
163
"""The Key Bundle"""
158
164
@@ -230,6 +236,10 @@ def __init__(
230
236
self .source = None
231
237
self .time_out = 0
232
238
239
+ self ._lock = rwlock .RWLockFairD ()
240
+ self ._lock_reader = self ._lock .gen_rlock ()
241
+ self ._lock_writer = self ._lock .gen_wlock ()
242
+
233
243
if httpc :
234
244
self .httpc = httpc
235
245
else :
@@ -500,6 +510,7 @@ def _uptodate(self):
500
510
return self .update ()
501
511
return False
502
512
513
+ @keys_writer
503
514
def update (self ):
504
515
"""
505
516
Reload the keys if necessary.
@@ -510,35 +521,34 @@ def update(self):
510
521
:return: True if update was ok or False if we encountered an error during update.
511
522
"""
512
523
if self .source :
513
- with update_lock :
514
- _old_keys = self ._keys # just in case
524
+ _old_keys = self ._keys # just in case
515
525
516
- # reread everything
517
- self ._keys = []
518
- updated = None
526
+ # reread everything
527
+ self ._keys = []
528
+ updated = None
519
529
520
- try :
521
- if self .local :
522
- if self .fileformat in ["jwks" , "jwk" ]:
523
- updated = self .do_local_jwk (self .source )
524
- elif self .fileformat == "der" :
525
- updated = self .do_local_der (self .source , self .keytype , self .keyusage )
526
- elif self .remote :
527
- updated = self .do_remote ()
528
- except Exception as err :
529
- LOGGER .error ("Key bundle update failed: %s" , err )
530
- self ._keys = _old_keys # restore
531
- return False
532
-
533
- if updated :
534
- now = time .time ()
535
- for _key in _old_keys :
536
- if _key not in self ._keys :
537
- if not _key .inactive_since : # If already marked don't mess
538
- _key .inactive_since = now
539
- self ._keys .append (_key )
540
- else :
541
- self ._keys = _old_keys
530
+ try :
531
+ if self .local :
532
+ if self .fileformat in ["jwks" , "jwk" ]:
533
+ updated = self .do_local_jwk (self .source )
534
+ elif self .fileformat == "der" :
535
+ updated = self .do_local_der (self .source , self .keytype , self .keyusage )
536
+ elif self .remote :
537
+ updated = self .do_remote ()
538
+ except Exception as err :
539
+ LOGGER .error ("Key bundle update failed: %s" , err )
540
+ self ._keys = _old_keys # restore
541
+ return False
542
+
543
+ if updated :
544
+ now = time .time ()
545
+ for _key in _old_keys :
546
+ if _key not in self ._keys :
547
+ if not _key .inactive_since : # If already marked don't mess
548
+ _key .inactive_since = now
549
+ self ._keys .append (_key )
550
+ else :
551
+ self ._keys = _old_keys
542
552
543
553
return True
544
554
@@ -551,32 +561,34 @@ def get(self, typ="", only_active=True):
551
561
otherwise the appropriate keys in a list
552
562
"""
553
563
self ._uptodate ()
554
- _typs = [typ .lower (), typ .upper ()]
555
564
556
- if typ :
557
- _keys = [k for k in self ._keys if k .kty in _typs ]
558
- else :
559
- _keys = self ._keys
565
+ with self ._lock_reader :
566
+ if typ :
567
+ _typs = [typ .lower (), typ .upper ()]
568
+ _keys = [k for k in self ._keys if k .kty in _typs ]
569
+ else :
570
+ _keys = self ._keys
560
571
561
572
if only_active :
562
573
return [k for k in _keys if not k .inactive_since ]
563
574
564
575
return _keys
565
576
566
- def keys (self ):
577
+ def keys (self , update : bool = True ):
567
578
"""
568
579
Return all keys after having updated them
569
580
570
581
:return: List of all keys
571
582
"""
572
- self ._uptodate ()
573
-
574
- return self ._keys
583
+ if update :
584
+ self ._uptodate ()
585
+ with self ._lock_reader :
586
+ return self ._keys
575
587
576
588
def active_keys (self ):
577
589
"""Return the set of active keys."""
578
590
_res = []
579
- for k in self ._keys :
591
+ for k in self .keys () :
580
592
try :
581
593
ias = k .inactive_since
582
594
except ValueError :
@@ -586,6 +598,7 @@ def active_keys(self):
586
598
_res .append (k )
587
599
return _res
588
600
601
+ @keys_writer
589
602
def remove_keys_by_type (self , typ ):
590
603
"""
591
604
Remove keys that are of a specific type.
@@ -605,9 +618,8 @@ def jwks(self, private=False):
605
618
:param private: Whether private key information should be included.
606
619
:return: A JWKS JSON representation of the keys in this bundle
607
620
"""
608
- self ._uptodate ()
609
621
keys = list ()
610
- for k in self ._keys :
622
+ for k in self .keys () :
611
623
if private :
612
624
key = k .serialize (private )
613
625
else :
@@ -617,6 +629,7 @@ def jwks(self, private=False):
617
629
keys .append (key )
618
630
return json .dumps ({"keys" : keys })
619
631
632
+ @keys_writer
620
633
def append (self , key ):
621
634
"""
622
635
Add a key to list of keys in this bundle
@@ -625,10 +638,12 @@ def append(self, key):
625
638
"""
626
639
self ._keys .append (key )
627
640
641
+ @keys_writer
628
642
def extend (self , keys ):
629
643
"""Add a key to the list of keys."""
630
644
self ._keys .extend (keys )
631
645
646
+ @keys_writer
632
647
def remove (self , key ):
633
648
"""
634
649
Remove a specific key from this bundle
@@ -648,6 +663,7 @@ def __len__(self):
648
663
"""
649
664
return len (self ._keys )
650
665
666
+ @keys_writer
651
667
def set (self , keys ):
652
668
"""Set the keys to the set provided."""
653
669
self ._keys = keys
@@ -659,13 +675,15 @@ def get_key_with_kid(self, kid):
659
675
:param kid: The Key ID
660
676
:return: The key or None
661
677
"""
678
+ self ._uptodate ()
679
+ with self ._lock_reader :
680
+ return self ._get_key_with_kid (kid )
681
+
682
+ def _get_key_with_kid (self , kid ):
662
683
for key in self ._keys :
663
684
if key .kid == kid :
664
685
return key
665
686
666
- # Try updating since there might have been an update to the key file
667
- self .update ()
668
-
669
687
for key in self ._keys :
670
688
if key .kid == kid :
671
689
return key
@@ -680,16 +698,16 @@ def kids(self):
680
698
The reason might be that there are some keys with no key ID.
681
699
:return: A list of all the key IDs that exists in this bundle
682
700
"""
683
- self ._uptodate ()
684
- return [key .kid for key in self ._keys if key .kid != "" ]
701
+ return [key .kid for key in self .keys () if key .kid != "" ]
685
702
703
+ @keys_writer
686
704
def mark_as_inactive (self , kid ):
687
705
"""
688
706
Mark a specific key as inactive based on the keys KeyID.
689
707
690
708
:param kid: The Key Identifier
691
709
"""
692
- k = self .get_key_with_kid (kid )
710
+ k = self ._get_key_with_kid (kid )
693
711
if k :
694
712
self ._keys .remove (k )
695
713
k .inactive_since = time .time ()
@@ -698,17 +716,19 @@ def mark_as_inactive(self, kid):
698
716
else :
699
717
return False
700
718
719
+ @keys_writer
701
720
def mark_all_as_inactive (self ):
702
721
"""
703
722
Mark a specific key as inactive based on the keys KeyID.
704
723
"""
705
- _keys = self .keys ()
724
+ _keys = self ._keys
706
725
_updated = []
707
726
for k in _keys :
708
727
k .inactive_since = time .time ()
709
728
_updated .append (k )
710
729
self ._keys = _updated
711
730
731
+ @keys_writer
712
732
def remove_outdated (self , after , when = 0 ):
713
733
"""
714
734
Remove keys that should not be available any more.
@@ -775,7 +795,7 @@ def difference(self, bundle):
775
795
if not isinstance (bundle , KeyBundle ):
776
796
return ValueError ("Not a KeyBundle instance" )
777
797
778
- return [k for k in self ._keys if k not in bundle ]
798
+ return [k for k in self .keys () if k not in bundle ]
779
799
780
800
def dump (self , exclude_attributes : Optional [List [str ]] = None ):
781
801
if exclude_attributes is None :
@@ -785,7 +805,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None):
785
805
786
806
if "keys" not in exclude_attributes :
787
807
_keys = []
788
- for _k in self ._keys :
808
+ for _k in self .keys ( update = False ) :
789
809
_ser = _k .to_dict ()
790
810
if _k .inactive_since :
791
811
_ser ["inactive_since" ] = _k .inactive_since
@@ -819,6 +839,7 @@ def load(self, spec):
819
839
820
840
return self
821
841
842
+ @keys_writer
822
843
def flush (self ):
823
844
self ._keys = []
824
845
self .cache_time = (300 ,)
0 commit comments