Skip to content

Commit 8a60d1a

Browse files
Merge pull request openwallet-foundation#1115 from andrewwhitehead/refactor/jwe
Add generic JWE envelope handling
2 parents 5512566 + 0724642 commit 8a60d1a

File tree

6 files changed

+471
-203
lines changed

6 files changed

+471
-203
lines changed

aries_cloudagent/utils/jwe.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
"""JSON Web Encryption utilities."""
2+
3+
import json
4+
5+
from collections import OrderedDict
6+
from typing import Any, Iterable, List, Mapping, Optional, Union
7+
8+
from marshmallow import fields, Schema, ValidationError
9+
10+
from ..wallet.util import b64_to_bytes, bytes_to_b64
11+
12+
IDENT_ENC_KEY = "encrypted_key"
13+
IDENT_HEADER = "header"
14+
IDENT_PROTECTED = "protected"
15+
IDENT_RECIPIENTS = "recipients"
16+
17+
18+
def b64url(value: Union[bytes, str]) -> str:
19+
"""Encode a string or bytes value as unpadded base64-URL."""
20+
if isinstance(value, str):
21+
value = value.encode("utf-8")
22+
return bytes_to_b64(value, urlsafe=True, pad=False)
23+
24+
25+
def from_b64url(value: str) -> bytes:
26+
"""Decode an unpadded base64-URL value."""
27+
return b64_to_bytes(value, urlsafe=True)
28+
29+
30+
class B64Value(fields.Str):
31+
"""A marshmallow-compatible wrapper for base64-URL values."""
32+
33+
def _serialize(self, value, attr, obj, **kwargs) -> Optional[str]:
34+
if value is None:
35+
return None
36+
if not isinstance(value, bytes):
37+
return TypeError("Expected bytes")
38+
return b64url(value)
39+
40+
def _deserialize(self, value, attr, data, **kwargs) -> Any:
41+
value = super()._deserialize(value, attr, data, **kwargs)
42+
return from_b64url(value)
43+
44+
45+
class JweSchema(Schema):
46+
"""JWE envelope schema."""
47+
48+
protected = fields.Str(required=True)
49+
unprotected = fields.Dict(required=False)
50+
ciphertext = B64Value(required=True)
51+
iv = B64Value(required=True)
52+
tag = B64Value(required=True)
53+
aad = B64Value(required=False)
54+
# flattened:
55+
header = fields.Dict(required=False)
56+
encrypted_key = B64Value(required=False)
57+
58+
59+
class JweRecipientSchema(Schema):
60+
"""JWE recipient schema."""
61+
62+
encrypted_key = B64Value(required=True)
63+
header = fields.Dict(many=True, required=False)
64+
65+
66+
class JweRecipient:
67+
"""A single message recipient."""
68+
69+
def __init__(self, *, encrypted_key: bytes, header: dict = None) -> "JweRecipient":
70+
"""Initialize the JWE recipient."""
71+
self.encrypted_key = encrypted_key
72+
self.header = header or {}
73+
74+
@classmethod
75+
def deserialize(cls, entry: Mapping[str, Any]) -> "JweRecipient":
76+
"""Deserialize a JWE recipient from a mapping."""
77+
vals = JweRecipientSchema().load(entry)
78+
return cls(**vals)
79+
80+
def serialize(self) -> dict:
81+
"""Serialize the JWE recipient to a mapping."""
82+
ret = OrderedDict([("encrypted_key", b64url(self.encrypted_key))])
83+
if self.header:
84+
ret["header"] = self.header
85+
return ret
86+
87+
88+
class JweEnvelope:
89+
"""JWE envelope instance."""
90+
91+
def __init__(
92+
self,
93+
*,
94+
protected: dict = None,
95+
protected_b64: bytes = None,
96+
unprotected: dict = None,
97+
ciphertext: bytes = None,
98+
iv: bytes = None,
99+
tag: bytes = None,
100+
aad: bytes = None,
101+
):
102+
"""Initialize a new JWE envelope instance."""
103+
self.protected = protected
104+
self.protected_b64 = protected_b64
105+
self.unprotected = unprotected or OrderedDict()
106+
self.ciphertext = ciphertext
107+
self.iv = iv
108+
self.tag = tag
109+
self.aad = aad
110+
self._recipients: List[JweRecipient] = []
111+
112+
@classmethod
113+
def from_json(cls, message: Union[bytes, str]) -> "JweEnvelope":
114+
"""Decode a JWE envelope from a JSON string or bytes value."""
115+
return cls._deserialize(JweSchema().loads(message))
116+
117+
@classmethod
118+
def deserialize(cls, message: Mapping[str, Any]) -> "JweEnvelope":
119+
"""Deserialize a JWE envelope from a mapping."""
120+
return cls._deserialize(JweSchema().load(message))
121+
122+
@classmethod
123+
def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope":
124+
protected_b64 = parsed[IDENT_PROTECTED]
125+
try:
126+
protected: dict = json.loads(from_b64url(protected_b64))
127+
except json.JSONDecodeError:
128+
raise ValidationError(
129+
"Invalid JWE: invalid JSON for protected headers"
130+
) from None
131+
unprotected = parsed.get("unprotected") or dict()
132+
if protected.keys() & unprotected.keys():
133+
raise ValidationError("Invalid JWE: duplicate header")
134+
135+
if IDENT_RECIPIENTS in protected:
136+
recips = [
137+
JweRecipient.deserialize(recip)
138+
for recip in protected.pop(IDENT_RECIPIENTS)
139+
]
140+
if IDENT_ENC_KEY in protected or IDENT_HEADER in protected:
141+
raise ValidationError("Invalid JWE: flattened form with recipients")
142+
else:
143+
if IDENT_ENC_KEY not in protected:
144+
raise ValidationError("Invalid JWE: no recipients")
145+
header = protected.pop(IDENT_HEADER) if IDENT_HEADER in protected else None
146+
recips = [
147+
JweRecipient(
148+
encrypted_key=from_b64url(protected.pop(IDENT_ENC_KEY)),
149+
header=header,
150+
)
151+
]
152+
153+
inst = cls(
154+
protected=protected,
155+
protected_b64=protected_b64,
156+
unprotected=unprotected,
157+
ciphertext=parsed["ciphertext"],
158+
iv=parsed.get("iv"),
159+
tag=parsed["tag"],
160+
aad=parsed.get("aad"),
161+
)
162+
all_h = protected.keys() | unprotected.keys()
163+
for recip in recips:
164+
if recip.header and recip.header.keys() & all_h:
165+
raise ValidationError("Invalid JWE: duplicate header")
166+
inst.add_recipient(recip)
167+
168+
return inst
169+
170+
def serialize(self) -> dict:
171+
"""Serialize the JWE envelope to a mapping."""
172+
if self.protected_b64 is None:
173+
raise ValidationError("Missing protected: use set_protected")
174+
if self.ciphertext is None:
175+
raise ValidationError("Missing ciphertext for JWE")
176+
if self.iv is None:
177+
raise ValidationError("Missing iv (nonce) for JWE")
178+
if self.tag is None:
179+
raise ValidationError("Missing tag for JWE")
180+
env = OrderedDict()
181+
env["protected"] = self.protected_b64
182+
if self.unprotected:
183+
env["unprotected"] = self.unprotected
184+
env["iv"] = b64url(self.iv)
185+
env["ciphertext"] = b64url(self.ciphertext)
186+
env["tag"] = b64url(self.tag)
187+
if self.aad:
188+
env["aad"] = b64url(self.aad)
189+
return env
190+
191+
def to_json(self) -> str:
192+
"""Serialize the JWE envelope to a JSON string."""
193+
return json.dumps(self.serialize())
194+
195+
def add_recipient(self, recip: JweRecipient):
196+
"""Add a recipient to the JWE envelope."""
197+
self._recipients.append(recip)
198+
199+
def set_protected(self, protected: Mapping[str, Any], auto_flatten: bool = True):
200+
"""Set the protected headers of the JWE envelope.
201+
202+
This method must be called after adding the message recipients,
203+
or with the recipients pre-encoded and added to the headers.
204+
"""
205+
protected = OrderedDict(protected.items())
206+
have_recips = IDENT_RECIPIENTS in protected or IDENT_ENC_KEY in protected
207+
if not have_recips:
208+
recipients = [recip.serialize() for recip in self._recipients]
209+
if auto_flatten and len(recipients) == 1:
210+
protected[IDENT_ENC_KEY] = recipients[0]["encrypted_key"]
211+
if "header" in recipients[0]:
212+
protected[IDENT_HEADER] = recipients[0]["header"]
213+
elif recipients:
214+
protected[IDENT_RECIPIENTS] = recipients
215+
else:
216+
raise ValidationError("Missing message recipients")
217+
self.protected_b64 = b64url(json.dumps(protected))
218+
219+
@property
220+
def protected_bytes(self) -> bytes:
221+
"""Access the protected data encoded as bytes.
222+
223+
This value is used in the additional authenticated data when encrypting.
224+
"""
225+
return (
226+
self.protected_b64.encode("utf-8")
227+
if self.protected_b64 is not None
228+
else None
229+
)
230+
231+
def set_payload(self, ciphertext: bytes, iv: bytes, tag: bytes, aad: bytes = None):
232+
"""Set the payload of the JWE envelope."""
233+
self.ciphertext = ciphertext
234+
self.iv = iv
235+
self.tag = tag
236+
self.aad = aad
237+
238+
def recipients(self) -> Iterable[JweRecipient]:
239+
"""Accessor for an iterator over the JWE recipients.
240+
241+
The headers for each recipient include protected and unprotected headers from the
242+
outer envelope.
243+
"""
244+
header = self.protected.copy()
245+
header.update(self.unprotected)
246+
for recip in self._recipients:
247+
if recip.header:
248+
recip_h = header.copy()
249+
recip_h.update(recip.header)
250+
yield JweRecipient(encrypted_key=recip.encrypted_key, header=recip_h)
251+
else:
252+
yield JweRecipient(encrypted_key=recip.encrypted_key, header=header)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import json
2+
3+
from unittest import TestCase
4+
5+
from ..jwe import b64url, JweEnvelope, JweRecipient, from_b64url
6+
7+
IV = b"test nonce"
8+
TAG = b"test tag"
9+
AAD = b"test aad"
10+
CIPHERTEXT = b"test ciphertext"
11+
ENC_KEY_1 = b"test enc key 1"
12+
ENC_KEY_2 = b"test enc key 2"
13+
PARAMS = {"alg": "MyAlg"}
14+
UNPROTECTED = {"abc": "ABC"}
15+
16+
17+
class TestJwe(TestCase):
18+
def test_envelope_load_single_recipient(self):
19+
protected = PARAMS.copy()
20+
protected.update(
21+
{
22+
# flattened single recipient
23+
"header": {"def": "DEF"},
24+
"encrypted_key": b64url(ENC_KEY_1),
25+
}
26+
)
27+
message = {
28+
"protected": b64url(json.dumps(protected)),
29+
"unprotected": UNPROTECTED.copy(),
30+
"iv": b64url(IV),
31+
"ciphertext": b64url(CIPHERTEXT),
32+
"tag": b64url(TAG),
33+
"aad": b64url(AAD),
34+
}
35+
loaded = JweEnvelope.deserialize(message)
36+
37+
assert loaded.protected == PARAMS
38+
assert loaded.unprotected == UNPROTECTED
39+
assert loaded.iv == IV
40+
assert loaded.tag == TAG
41+
assert loaded.aad == AAD
42+
assert loaded.ciphertext == CIPHERTEXT
43+
44+
recips = list(loaded.recipients())
45+
assert len(recips) == 1
46+
assert recips[0].encrypted_key == ENC_KEY_1
47+
assert recips[0].header == {"alg": "MyAlg", "abc": "ABC", "def": "DEF"}
48+
49+
def test_envelope_load_multiple_recipients(self):
50+
protected = PARAMS.copy()
51+
protected.update(
52+
{
53+
"recipients": [
54+
{"header": {"def": "DEF"}, "encrypted_key": b64url(ENC_KEY_1)},
55+
{"header": {"ghi": "GHI"}, "encrypted_key": b64url(ENC_KEY_2)},
56+
]
57+
}
58+
)
59+
message = {
60+
"protected": b64url(json.dumps(protected)),
61+
"unprotected": UNPROTECTED.copy(),
62+
"iv": b64url(IV),
63+
"ciphertext": b64url(CIPHERTEXT),
64+
"tag": b64url(TAG),
65+
"aad": b64url(AAD),
66+
}
67+
loaded = JweEnvelope.deserialize(message)
68+
69+
assert loaded.protected == PARAMS
70+
assert loaded.unprotected == UNPROTECTED
71+
assert loaded.iv == IV
72+
assert loaded.tag == TAG
73+
assert loaded.aad == AAD
74+
assert loaded.ciphertext == CIPHERTEXT
75+
76+
recips = list(loaded.recipients())
77+
assert len(recips) == 2
78+
assert recips[0].encrypted_key == ENC_KEY_1
79+
assert recips[0].header == {"alg": "MyAlg", "abc": "ABC", "def": "DEF"}
80+
assert recips[1].encrypted_key == ENC_KEY_2
81+
assert recips[1].header == {"alg": "MyAlg", "abc": "ABC", "ghi": "GHI"}
82+
83+
def test_envelope_serialize_single_recipient(self):
84+
env = JweEnvelope(
85+
unprotected=UNPROTECTED.copy(),
86+
iv=IV,
87+
ciphertext=CIPHERTEXT,
88+
tag=TAG,
89+
aad=AAD,
90+
)
91+
env.add_recipient(JweRecipient(encrypted_key=ENC_KEY_1, header={"def": "DEF"}))
92+
env.set_protected(PARAMS)
93+
message = env.to_json()
94+
loaded = JweEnvelope.from_json(message)
95+
96+
# check in flattened form
97+
prot = json.loads(from_b64url(loaded.protected_b64))
98+
assert "encrypted_key" in prot
99+
100+
assert loaded.protected == PARAMS
101+
assert loaded.unprotected == UNPROTECTED
102+
assert loaded.iv == IV
103+
assert loaded.tag == TAG
104+
assert loaded.aad == AAD
105+
assert loaded.ciphertext == CIPHERTEXT
106+
107+
recips = list(loaded.recipients())
108+
assert len(recips) == 1
109+
assert recips[0].encrypted_key == ENC_KEY_1
110+
assert recips[0].header == {"alg": "MyAlg", "abc": "ABC", "def": "DEF"}
111+
112+
def test_envelope_serialize_multiple_recipients(self):
113+
env = JweEnvelope(
114+
unprotected=UNPROTECTED.copy(),
115+
iv=IV,
116+
ciphertext=CIPHERTEXT,
117+
tag=TAG,
118+
aad=AAD,
119+
)
120+
env.add_recipient(JweRecipient(encrypted_key=ENC_KEY_1, header={"def": "DEF"}))
121+
env.add_recipient(JweRecipient(encrypted_key=ENC_KEY_2, header={"ghi": "GHI"}))
122+
env.set_protected(PARAMS)
123+
message = env.to_json()
124+
loaded = JweEnvelope.from_json(message)
125+
126+
assert loaded.protected == PARAMS
127+
assert loaded.unprotected == UNPROTECTED
128+
assert loaded.iv == IV
129+
assert loaded.tag == TAG
130+
assert loaded.aad == AAD
131+
assert loaded.ciphertext == CIPHERTEXT
132+
133+
recips = list(loaded.recipients())
134+
assert len(recips) == 2
135+
assert recips[0].encrypted_key == ENC_KEY_1
136+
assert recips[0].header == {"alg": "MyAlg", "abc": "ABC", "def": "DEF"}
137+
assert recips[1].encrypted_key == ENC_KEY_2
138+
assert recips[1].header == {"alg": "MyAlg", "abc": "ABC", "ghi": "GHI"}

0 commit comments

Comments
 (0)