|
| 1 | +"""JSON Web Encryption utilities.""" |
| 2 | + |
| 3 | +import json |
| 4 | + |
| 5 | +from collections import OrderedDict |
| 6 | +from typing import Any, Iterable, List, Mapping, Optional, Union |
| 7 | + |
| 8 | +from marshmallow import fields, Schema, ValidationError |
| 9 | + |
| 10 | +from ..wallet.util import b64_to_bytes, bytes_to_b64 |
| 11 | + |
| 12 | +IDENT_ENC_KEY = "encrypted_key" |
| 13 | +IDENT_HEADER = "header" |
| 14 | +IDENT_PROTECTED = "protected" |
| 15 | +IDENT_RECIPIENTS = "recipients" |
| 16 | + |
| 17 | + |
| 18 | +def b64url(value: Union[bytes, str]) -> str: |
| 19 | + """Encode a string or bytes value as unpadded base64-URL.""" |
| 20 | + if isinstance(value, str): |
| 21 | + value = value.encode("utf-8") |
| 22 | + return bytes_to_b64(value, urlsafe=True, pad=False) |
| 23 | + |
| 24 | + |
| 25 | +def from_b64url(value: str) -> bytes: |
| 26 | + """Decode an unpadded base64-URL value.""" |
| 27 | + return b64_to_bytes(value, urlsafe=True) |
| 28 | + |
| 29 | + |
| 30 | +class B64Value(fields.Str): |
| 31 | + """A marshmallow-compatible wrapper for base64-URL values.""" |
| 32 | + |
| 33 | + def _serialize(self, value, attr, obj, **kwargs) -> Optional[str]: |
| 34 | + if value is None: |
| 35 | + return None |
| 36 | + if not isinstance(value, bytes): |
| 37 | + return TypeError("Expected bytes") |
| 38 | + return b64url(value) |
| 39 | + |
| 40 | + def _deserialize(self, value, attr, data, **kwargs) -> Any: |
| 41 | + value = super()._deserialize(value, attr, data, **kwargs) |
| 42 | + return from_b64url(value) |
| 43 | + |
| 44 | + |
| 45 | +class JweSchema(Schema): |
| 46 | + """JWE envelope schema.""" |
| 47 | + |
| 48 | + protected = fields.Str(required=True) |
| 49 | + unprotected = fields.Dict(required=False) |
| 50 | + ciphertext = B64Value(required=True) |
| 51 | + iv = B64Value(required=True) |
| 52 | + tag = B64Value(required=True) |
| 53 | + aad = B64Value(required=False) |
| 54 | + # flattened: |
| 55 | + header = fields.Dict(required=False) |
| 56 | + encrypted_key = B64Value(required=False) |
| 57 | + |
| 58 | + |
| 59 | +class JweRecipientSchema(Schema): |
| 60 | + """JWE recipient schema.""" |
| 61 | + |
| 62 | + encrypted_key = B64Value(required=True) |
| 63 | + header = fields.Dict(many=True, required=False) |
| 64 | + |
| 65 | + |
| 66 | +class JweRecipient: |
| 67 | + """A single message recipient.""" |
| 68 | + |
| 69 | + def __init__(self, *, encrypted_key: bytes, header: dict = None) -> "JweRecipient": |
| 70 | + """Initialize the JWE recipient.""" |
| 71 | + self.encrypted_key = encrypted_key |
| 72 | + self.header = header or {} |
| 73 | + |
| 74 | + @classmethod |
| 75 | + def deserialize(cls, entry: Mapping[str, Any]) -> "JweRecipient": |
| 76 | + """Deserialize a JWE recipient from a mapping.""" |
| 77 | + vals = JweRecipientSchema().load(entry) |
| 78 | + return cls(**vals) |
| 79 | + |
| 80 | + def serialize(self) -> dict: |
| 81 | + """Serialize the JWE recipient to a mapping.""" |
| 82 | + ret = OrderedDict([("encrypted_key", b64url(self.encrypted_key))]) |
| 83 | + if self.header: |
| 84 | + ret["header"] = self.header |
| 85 | + return ret |
| 86 | + |
| 87 | + |
| 88 | +class JweEnvelope: |
| 89 | + """JWE envelope instance.""" |
| 90 | + |
| 91 | + def __init__( |
| 92 | + self, |
| 93 | + *, |
| 94 | + protected: dict = None, |
| 95 | + protected_b64: bytes = None, |
| 96 | + unprotected: dict = None, |
| 97 | + ciphertext: bytes = None, |
| 98 | + iv: bytes = None, |
| 99 | + tag: bytes = None, |
| 100 | + aad: bytes = None, |
| 101 | + ): |
| 102 | + """Initialize a new JWE envelope instance.""" |
| 103 | + self.protected = protected |
| 104 | + self.protected_b64 = protected_b64 |
| 105 | + self.unprotected = unprotected or OrderedDict() |
| 106 | + self.ciphertext = ciphertext |
| 107 | + self.iv = iv |
| 108 | + self.tag = tag |
| 109 | + self.aad = aad |
| 110 | + self._recipients: List[JweRecipient] = [] |
| 111 | + |
| 112 | + @classmethod |
| 113 | + def from_json(cls, message: Union[bytes, str]) -> "JweEnvelope": |
| 114 | + """Decode a JWE envelope from a JSON string or bytes value.""" |
| 115 | + return cls._deserialize(JweSchema().loads(message)) |
| 116 | + |
| 117 | + @classmethod |
| 118 | + def deserialize(cls, message: Mapping[str, Any]) -> "JweEnvelope": |
| 119 | + """Deserialize a JWE envelope from a mapping.""" |
| 120 | + return cls._deserialize(JweSchema().load(message)) |
| 121 | + |
| 122 | + @classmethod |
| 123 | + def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope": |
| 124 | + protected_b64 = parsed[IDENT_PROTECTED] |
| 125 | + try: |
| 126 | + protected: dict = json.loads(from_b64url(protected_b64)) |
| 127 | + except json.JSONDecodeError: |
| 128 | + raise ValidationError( |
| 129 | + "Invalid JWE: invalid JSON for protected headers" |
| 130 | + ) from None |
| 131 | + unprotected = parsed.get("unprotected") or dict() |
| 132 | + if protected.keys() & unprotected.keys(): |
| 133 | + raise ValidationError("Invalid JWE: duplicate header") |
| 134 | + |
| 135 | + if IDENT_RECIPIENTS in protected: |
| 136 | + recips = [ |
| 137 | + JweRecipient.deserialize(recip) |
| 138 | + for recip in protected.pop(IDENT_RECIPIENTS) |
| 139 | + ] |
| 140 | + if IDENT_ENC_KEY in protected or IDENT_HEADER in protected: |
| 141 | + raise ValidationError("Invalid JWE: flattened form with recipients") |
| 142 | + else: |
| 143 | + if IDENT_ENC_KEY not in protected: |
| 144 | + raise ValidationError("Invalid JWE: no recipients") |
| 145 | + header = protected.pop(IDENT_HEADER) if IDENT_HEADER in protected else None |
| 146 | + recips = [ |
| 147 | + JweRecipient( |
| 148 | + encrypted_key=from_b64url(protected.pop(IDENT_ENC_KEY)), |
| 149 | + header=header, |
| 150 | + ) |
| 151 | + ] |
| 152 | + |
| 153 | + inst = cls( |
| 154 | + protected=protected, |
| 155 | + protected_b64=protected_b64, |
| 156 | + unprotected=unprotected, |
| 157 | + ciphertext=parsed["ciphertext"], |
| 158 | + iv=parsed.get("iv"), |
| 159 | + tag=parsed["tag"], |
| 160 | + aad=parsed.get("aad"), |
| 161 | + ) |
| 162 | + all_h = protected.keys() | unprotected.keys() |
| 163 | + for recip in recips: |
| 164 | + if recip.header and recip.header.keys() & all_h: |
| 165 | + raise ValidationError("Invalid JWE: duplicate header") |
| 166 | + inst.add_recipient(recip) |
| 167 | + |
| 168 | + return inst |
| 169 | + |
| 170 | + def serialize(self) -> dict: |
| 171 | + """Serialize the JWE envelope to a mapping.""" |
| 172 | + if self.protected_b64 is None: |
| 173 | + raise ValidationError("Missing protected: use set_protected") |
| 174 | + if self.ciphertext is None: |
| 175 | + raise ValidationError("Missing ciphertext for JWE") |
| 176 | + if self.iv is None: |
| 177 | + raise ValidationError("Missing iv (nonce) for JWE") |
| 178 | + if self.tag is None: |
| 179 | + raise ValidationError("Missing tag for JWE") |
| 180 | + env = OrderedDict() |
| 181 | + env["protected"] = self.protected_b64 |
| 182 | + if self.unprotected: |
| 183 | + env["unprotected"] = self.unprotected |
| 184 | + env["iv"] = b64url(self.iv) |
| 185 | + env["ciphertext"] = b64url(self.ciphertext) |
| 186 | + env["tag"] = b64url(self.tag) |
| 187 | + if self.aad: |
| 188 | + env["aad"] = b64url(self.aad) |
| 189 | + return env |
| 190 | + |
| 191 | + def to_json(self) -> str: |
| 192 | + """Serialize the JWE envelope to a JSON string.""" |
| 193 | + return json.dumps(self.serialize()) |
| 194 | + |
| 195 | + def add_recipient(self, recip: JweRecipient): |
| 196 | + """Add a recipient to the JWE envelope.""" |
| 197 | + self._recipients.append(recip) |
| 198 | + |
| 199 | + def set_protected(self, protected: Mapping[str, Any], auto_flatten: bool = True): |
| 200 | + """Set the protected headers of the JWE envelope. |
| 201 | +
|
| 202 | + This method must be called after adding the message recipients, |
| 203 | + or with the recipients pre-encoded and added to the headers. |
| 204 | + """ |
| 205 | + protected = OrderedDict(protected.items()) |
| 206 | + have_recips = IDENT_RECIPIENTS in protected or IDENT_ENC_KEY in protected |
| 207 | + if not have_recips: |
| 208 | + recipients = [recip.serialize() for recip in self._recipients] |
| 209 | + if auto_flatten and len(recipients) == 1: |
| 210 | + protected[IDENT_ENC_KEY] = recipients[0]["encrypted_key"] |
| 211 | + if "header" in recipients[0]: |
| 212 | + protected[IDENT_HEADER] = recipients[0]["header"] |
| 213 | + elif recipients: |
| 214 | + protected[IDENT_RECIPIENTS] = recipients |
| 215 | + else: |
| 216 | + raise ValidationError("Missing message recipients") |
| 217 | + self.protected_b64 = b64url(json.dumps(protected)) |
| 218 | + |
| 219 | + @property |
| 220 | + def protected_bytes(self) -> bytes: |
| 221 | + """Access the protected data encoded as bytes. |
| 222 | +
|
| 223 | + This value is used in the additional authenticated data when encrypting. |
| 224 | + """ |
| 225 | + return ( |
| 226 | + self.protected_b64.encode("utf-8") |
| 227 | + if self.protected_b64 is not None |
| 228 | + else None |
| 229 | + ) |
| 230 | + |
| 231 | + def set_payload(self, ciphertext: bytes, iv: bytes, tag: bytes, aad: bytes = None): |
| 232 | + """Set the payload of the JWE envelope.""" |
| 233 | + self.ciphertext = ciphertext |
| 234 | + self.iv = iv |
| 235 | + self.tag = tag |
| 236 | + self.aad = aad |
| 237 | + |
| 238 | + def recipients(self) -> Iterable[JweRecipient]: |
| 239 | + """Accessor for an iterator over the JWE recipients. |
| 240 | +
|
| 241 | + The headers for each recipient include protected and unprotected headers from the |
| 242 | + outer envelope. |
| 243 | + """ |
| 244 | + header = self.protected.copy() |
| 245 | + header.update(self.unprotected) |
| 246 | + for recip in self._recipients: |
| 247 | + if recip.header: |
| 248 | + recip_h = header.copy() |
| 249 | + recip_h.update(recip.header) |
| 250 | + yield JweRecipient(encrypted_key=recip.encrypted_key, header=recip_h) |
| 251 | + else: |
| 252 | + yield JweRecipient(encrypted_key=recip.encrypted_key, header=header) |
0 commit comments