Skip to content

Commit e7ad87e

Browse files
committed
add key wrapping functions
1 parent 2950813 commit e7ad87e

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

src/cryptojwt/jwk/wrap.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""JWK wrapping"""
2+
3+
import json
4+
5+
from . import JWK
6+
from ..jwe.jwe import JWE
7+
from ..jwx import key_from_jwk_dict
8+
9+
__author__ = 'jschlyter'
10+
11+
DEFAULT_WRAP_PARAMS = {
12+
"EC": {"alg": "ECDH-ES+A128KW", "enc": "A128GCM"},
13+
"RSA": {"alg": "RSA1_5", "enc": "A128CBC-HS256"},
14+
"oct": {"alg": "A128KW", "enc": "A128CBC-HS256"},
15+
}
16+
17+
18+
def wrap_key(key: JWK, wrapping_key: JWK, wrap_params: dict = DEFAULT_WRAP_PARAMS) -> str:
19+
message = json.dumps(key.serialize(private=True)).encode()
20+
try:
21+
enc_params = wrap_params[wrapping_key.kty]
22+
except KeyError:
23+
raise ValueError("Unsupported wrapping key type")
24+
_jwe = JWE(msg=message, **enc_params)
25+
return _jwe.encrypt(keys=[wrapping_key], kid=wrapping_key.kid)
26+
27+
28+
def unwrap_key(jwe: str, wrapping_keys: JWK) -> JWK:
29+
_jwe = JWE()
30+
message = _jwe.decrypt(token=jwe, keys=wrapping_keys)
31+
return key_from_jwk_dict(json.loads(message))

tests/test_10_jwk_wrap.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import os
2+
3+
from cryptojwt.jwk.ec import new_ec_key
4+
from cryptojwt.jwk.hmac import SYMKey
5+
from cryptojwt.jwk.rsa import new_rsa_key
6+
from cryptojwt.jwk.wrap import wrap_key, unwrap_key
7+
8+
__author__ = 'jschlyter'
9+
10+
WRAPPING_KEYS = [
11+
SYMKey(use="enc", key=os.urandom(32)),
12+
new_ec_key(crv="P-256"),
13+
new_ec_key(crv="P-384"),
14+
new_rsa_key(size=2048),
15+
new_rsa_key(size=4096),
16+
]
17+
18+
SECRET_KEYS = [
19+
SYMKey(use="enc", key=os.urandom(32)),
20+
new_ec_key(crv="P-256"),
21+
new_rsa_key(size=2048),
22+
]
23+
24+
25+
def test_wrap_default():
26+
for wrapping_key in WRAPPING_KEYS:
27+
for key in SECRET_KEYS:
28+
wrapped_key = wrap_key(key, wrapping_key)
29+
unwrapped_key = unwrap_key(wrapped_key, [wrapping_key])
30+
assert key == unwrapped_key
31+
32+
def test_wrap_params():
33+
wrap_params = {
34+
"EC": {"alg": "ECDH-ES+A256KW", "enc": "A256GCM"},
35+
"RSA": {"alg": "RSA1_5", "enc": "A256CBC-HS512"},
36+
"oct": {"alg": "A256KW", "enc": "A256CBC-HS512"},
37+
}
38+
for wrapping_key in WRAPPING_KEYS:
39+
for key in SECRET_KEYS:
40+
wrapped_key = wrap_key(key, wrapping_key, wrap_params)
41+
unwrapped_key = unwrap_key(wrapped_key, [wrapping_key])
42+
assert key == unwrapped_key

0 commit comments

Comments
 (0)