Skip to content

Commit 9656a55

Browse files
committed
first cut at proper read/write locking for KeyBundle
1 parent c9e59bc commit 9656a55

File tree

3 files changed

+90
-53
lines changed

3 files changed

+90
-53
lines changed

poetry.lock

Lines changed: 17 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ jwtpeek = "cryptojwt.tools.jwtpeek:main"
4141
python = "^3.6"
4242
cryptography = "^3.4.6"
4343
requests = "^2.25.1"
44+
readerwriterlock = "^1.0.8"
4445

4546
[tool.poetry.dev-dependencies]
4647
alabaster = "^0.7.12"

src/cryptojwt/key_bundle.py

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import json
44
import logging
55
import os
6-
import threading
76
import time
87
from datetime import datetime
98
from functools import cmp_to_key
109
from typing import List
1110
from typing import Optional
1211

1312
import requests
13+
from readerwriterlock import rwlock
1414

1515
from cryptojwt.jwk.ec import NIST2SEC
1616
from cryptojwt.jwk.hmac import new_sym_key
@@ -47,8 +47,6 @@
4747

4848
MAP = {"dec": "enc", "enc": "enc", "ver": "sig", "sig": "sig"}
4949

50-
update_lock = threading.Lock()
51-
5250

5351
def harmonize_usage(use):
5452
"""
@@ -153,6 +151,14 @@ def ec_init(spec):
153151
return _kb
154152

155153

154+
def keys_writer(func):
155+
def wrapper(self, *args, **kwargs):
156+
with self._lock_writer:
157+
return func(self, *args, **kwargs)
158+
159+
return wrapper
160+
161+
156162
class KeyBundle:
157163
"""The Key Bundle"""
158164

@@ -230,6 +236,10 @@ def __init__(
230236
self.source = None
231237
self.time_out = 0
232238

239+
self._lock = rwlock.RWLockFairD()
240+
self._lock_reader = self._lock.gen_rlock()
241+
self._lock_writer = self._lock.gen_wlock()
242+
233243
if httpc:
234244
self.httpc = httpc
235245
else:
@@ -500,6 +510,7 @@ def _uptodate(self):
500510
return self.update()
501511
return False
502512

513+
@keys_writer
503514
def update(self):
504515
"""
505516
Reload the keys if necessary.
@@ -510,35 +521,34 @@ def update(self):
510521
:return: True if update was ok or False if we encountered an error during update.
511522
"""
512523
if self.source:
513-
with update_lock:
514-
_old_keys = self._keys # just in case
524+
_old_keys = self._keys # just in case
515525

516-
# reread everything
517-
self._keys = []
518-
updated = None
526+
# reread everything
527+
self._keys = []
528+
updated = None
519529

520-
try:
521-
if self.local:
522-
if self.fileformat in ["jwks", "jwk"]:
523-
updated = self.do_local_jwk(self.source)
524-
elif self.fileformat == "der":
525-
updated = self.do_local_der(self.source, self.keytype, self.keyusage)
526-
elif self.remote:
527-
updated = self.do_remote()
528-
except Exception as err:
529-
LOGGER.error("Key bundle update failed: %s", err)
530-
self._keys = _old_keys # restore
531-
return False
532-
533-
if updated:
534-
now = time.time()
535-
for _key in _old_keys:
536-
if _key not in self._keys:
537-
if not _key.inactive_since: # If already marked don't mess
538-
_key.inactive_since = now
539-
self._keys.append(_key)
540-
else:
541-
self._keys = _old_keys
530+
try:
531+
if self.local:
532+
if self.fileformat in ["jwks", "jwk"]:
533+
updated = self.do_local_jwk(self.source)
534+
elif self.fileformat == "der":
535+
updated = self.do_local_der(self.source, self.keytype, self.keyusage)
536+
elif self.remote:
537+
updated = self.do_remote()
538+
except Exception as err:
539+
LOGGER.error("Key bundle update failed: %s", err)
540+
self._keys = _old_keys # restore
541+
return False
542+
543+
if updated:
544+
now = time.time()
545+
for _key in _old_keys:
546+
if _key not in self._keys:
547+
if not _key.inactive_since: # If already marked don't mess
548+
_key.inactive_since = now
549+
self._keys.append(_key)
550+
else:
551+
self._keys = _old_keys
542552

543553
return True
544554

@@ -551,32 +561,34 @@ def get(self, typ="", only_active=True):
551561
otherwise the appropriate keys in a list
552562
"""
553563
self._uptodate()
554-
_typs = [typ.lower(), typ.upper()]
555564

556-
if typ:
557-
_keys = [k for k in self._keys if k.kty in _typs]
558-
else:
559-
_keys = self._keys
565+
with self._lock_reader:
566+
if typ:
567+
_typs = [typ.lower(), typ.upper()]
568+
_keys = [k for k in self._keys if k.kty in _typs]
569+
else:
570+
_keys = self._keys
560571

561572
if only_active:
562573
return [k for k in _keys if not k.inactive_since]
563574

564575
return _keys
565576

566-
def keys(self):
577+
def keys(self, update: bool = True):
567578
"""
568579
Return all keys after having updated them
569580
570581
:return: List of all keys
571582
"""
572-
self._uptodate()
573-
574-
return self._keys
583+
if update:
584+
self._uptodate()
585+
with self._lock_reader:
586+
return self._keys
575587

576588
def active_keys(self):
577589
"""Return the set of active keys."""
578590
_res = []
579-
for k in self._keys:
591+
for k in self.keys():
580592
try:
581593
ias = k.inactive_since
582594
except ValueError:
@@ -586,6 +598,7 @@ def active_keys(self):
586598
_res.append(k)
587599
return _res
588600

601+
@keys_writer
589602
def remove_keys_by_type(self, typ):
590603
"""
591604
Remove keys that are of a specific type.
@@ -605,9 +618,8 @@ def jwks(self, private=False):
605618
:param private: Whether private key information should be included.
606619
:return: A JWKS JSON representation of the keys in this bundle
607620
"""
608-
self._uptodate()
609621
keys = list()
610-
for k in self._keys:
622+
for k in self.keys():
611623
if private:
612624
key = k.serialize(private)
613625
else:
@@ -617,6 +629,7 @@ def jwks(self, private=False):
617629
keys.append(key)
618630
return json.dumps({"keys": keys})
619631

632+
@keys_writer
620633
def append(self, key):
621634
"""
622635
Add a key to list of keys in this bundle
@@ -625,10 +638,12 @@ def append(self, key):
625638
"""
626639
self._keys.append(key)
627640

641+
@keys_writer
628642
def extend(self, keys):
629643
"""Add a key to the list of keys."""
630644
self._keys.extend(keys)
631645

646+
@keys_writer
632647
def remove(self, key):
633648
"""
634649
Remove a specific key from this bundle
@@ -648,6 +663,7 @@ def __len__(self):
648663
"""
649664
return len(self._keys)
650665

666+
@keys_writer
651667
def set(self, keys):
652668
"""Set the keys to the set provided."""
653669
self._keys = keys
@@ -659,13 +675,15 @@ def get_key_with_kid(self, kid):
659675
:param kid: The Key ID
660676
:return: The key or None
661677
"""
678+
self._uptodate()
679+
with self._lock_reader:
680+
return self._get_key_with_kid(kid)
681+
682+
def _get_key_with_kid(self, kid):
662683
for key in self._keys:
663684
if key.kid == kid:
664685
return key
665686

666-
# Try updating since there might have been an update to the key file
667-
self.update()
668-
669687
for key in self._keys:
670688
if key.kid == kid:
671689
return key
@@ -680,16 +698,16 @@ def kids(self):
680698
The reason might be that there are some keys with no key ID.
681699
:return: A list of all the key IDs that exists in this bundle
682700
"""
683-
self._uptodate()
684-
return [key.kid for key in self._keys if key.kid != ""]
701+
return [key.kid for key in self.keys() if key.kid != ""]
685702

703+
@keys_writer
686704
def mark_as_inactive(self, kid):
687705
"""
688706
Mark a specific key as inactive based on the keys KeyID.
689707
690708
:param kid: The Key Identifier
691709
"""
692-
k = self.get_key_with_kid(kid)
710+
k = self._get_key_with_kid(kid)
693711
if k:
694712
self._keys.remove(k)
695713
k.inactive_since = time.time()
@@ -698,17 +716,19 @@ def mark_as_inactive(self, kid):
698716
else:
699717
return False
700718

719+
@keys_writer
701720
def mark_all_as_inactive(self):
702721
"""
703722
Mark a specific key as inactive based on the keys KeyID.
704723
"""
705-
_keys = self.keys()
724+
_keys = self._keys
706725
_updated = []
707726
for k in _keys:
708727
k.inactive_since = time.time()
709728
_updated.append(k)
710729
self._keys = _updated
711730

731+
@keys_writer
712732
def remove_outdated(self, after, when=0):
713733
"""
714734
Remove keys that should not be available any more.
@@ -775,7 +795,7 @@ def difference(self, bundle):
775795
if not isinstance(bundle, KeyBundle):
776796
return ValueError("Not a KeyBundle instance")
777797

778-
return [k for k in self._keys if k not in bundle]
798+
return [k for k in self.keys() if k not in bundle]
779799

780800
def dump(self, exclude_attributes: Optional[List[str]] = None):
781801
if exclude_attributes is None:
@@ -785,7 +805,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None):
785805

786806
if "keys" not in exclude_attributes:
787807
_keys = []
788-
for _k in self._keys:
808+
for _k in self.keys(update=False):
789809
_ser = _k.to_dict()
790810
if _k.inactive_since:
791811
_ser["inactive_since"] = _k.inactive_since
@@ -819,6 +839,7 @@ def load(self, spec):
819839

820840
return self
821841

842+
@keys_writer
822843
def flush(self):
823844
self._keys = []
824845
self.cache_time = (300,)

0 commit comments

Comments
 (0)