Skip to content

Commit ddc626d

Browse files
committed
feat: add jwt framing
Signed-off-by: Daniel Bluhm <[email protected]>
1 parent 2cf75d2 commit ddc626d

File tree

2 files changed

+177
-12
lines changed

2 files changed

+177
-12
lines changed

src/token_status_list/__init__.py

+89-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
"""
88

99
import base64
10-
from typing import Literal, Union
10+
import json
11+
from time import time
12+
from typing import Literal, Optional, Union
1113
import zlib
1214

1315

@@ -25,6 +27,11 @@ def b64url_encode(value: bytes) -> bytes:
2527
return base64.urlsafe_b64encode(value).rstrip(b"=")
2628

2729

30+
def dict_to_b64(value: dict) -> bytes:
31+
"""Transform a dictionary into base64url encoded json dump of dictionary."""
32+
return b64url_encode(json.dumps(value, separators=(",", ":")).encode())
33+
34+
2835
VALID = 0x00
2936
INVALID = 0x01
3037
SUSPENDED = 0x02
@@ -77,6 +84,10 @@ def of_size(cls, bits: Bits, size: int) -> "TokenStatusList":
7784
@classmethod
7885
def with_at_least(cls, bits: Bits, size: int):
7986
"""Create an empty list large enough to accommodate at least the given size."""
87+
# Determine minimum number of bytes to fit size
88+
# This is essentially a fast ceil(n / 2^x)
89+
length = (size + cls.PER_BYTE[bits] - 1) >> cls.SHIFT_BY[bits]
90+
return cls(bits, bytearray(length))
8091

8192
def __getitem__(self, index: int):
8293
"""Retrieve the status of an index."""
@@ -151,3 +162,80 @@ def deserialize(cls, value: dict) -> "TokenStatusList":
151162

152163
parsed_lst = zlib.decompress(b64url_decode(lst.encode()))
153164
return cls(bits, parsed_lst)
165+
166+
def sign_payload(
167+
self,
168+
*,
169+
alg: str,
170+
kid: str,
171+
iss: str,
172+
sub: str,
173+
iat: Optional[int] = None,
174+
exp: Optional[int] = None,
175+
ttl: Optional[int] = None,
176+
) -> bytes:
177+
"""Create a Status List Token payload for signing.
178+
179+
Signing is NOT performed by this function; only the payload to the signature is
180+
prepared. The caller is responsible for producing a signature.
181+
182+
Args:
183+
alg: REQUIRED. The algorithm to be used to sign the payload.
184+
185+
kid: REQUIRED. The kid used to sign the payload.
186+
187+
iss: REQUIRED when also present in the Referenced Token. The iss (issuer)
188+
claim MUST specify a unique string identifier for the entity that issued
189+
the Status List Token. In the absence of an application profile specifying
190+
otherwise, compliant applications MUST compare issuer values using the
191+
Simple String Comparison method defined in Section 6.2.1 of [RFC3986].
192+
The value MUST be equal to that of the iss claim contained within the
193+
Referenced Token.
194+
195+
sub: REQUIRED. The sub (subject) claim MUST specify a unique string identifier
196+
for the Status List Token. The value MUST be equal to that of the uri
197+
claim contained in the status_list claim of the Referenced Token.
198+
199+
iat: REQUIRED. The iat (issued at) claim MUST specify the time at which the
200+
Status List Token was issued.
201+
202+
exp: OPTIONAL. The exp (expiration time) claim, if present, MUST specify the
203+
time at which the Status List Token is considered expired by its issuer.
204+
205+
ttl: OPTIONAL. The ttl (time to live) claim, if present, MUST specify the
206+
maximum amount of time, in seconds, that the Status List Token can be
207+
cached by a consumer before a fresh copy SHOULD be retrieved. The value
208+
of the claim MUST be a positive number.
209+
"""
210+
headers = {
211+
"typ": "statuslist+jwt",
212+
"alg": alg,
213+
"kid": kid,
214+
}
215+
payload = {
216+
"iss": iss,
217+
"sub": sub,
218+
"iat": iat or int(time()),
219+
"status_list": self.serialize(),
220+
}
221+
if exp is not None:
222+
payload["exp"] = exp
223+
224+
if ttl is not None:
225+
payload["ttl"] = ttl
226+
227+
enc_headers = dict_to_b64(headers).decode()
228+
enc_payload = dict_to_b64(payload).decode()
229+
return f"{enc_headers}.{enc_payload}".encode()
230+
231+
def signed_token(self, signed_payload: bytes, signature: bytes) -> str:
232+
"""Finish creating a signed token.
233+
234+
Args:
235+
signed_payload: The value returned from `sign_payload`.
236+
signature: The signature over the signed_payload in bytes.
237+
238+
Returns:
239+
Finished Status List Token.
240+
"""
241+
return f"{signed_payload.decode()}.{b64url_encode(signature)}"

tests/test_token_status_list.py

+88-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Test TokenStatusList."""
22

3+
import json
34
import pytest
5+
from secrets import randbelow
46

57
from token_status_list import INVALID, SUSPENDED, TokenStatusList, VALID, b64url_decode
68

@@ -151,31 +153,44 @@ def test_performance(bits: int):
151153
print("Bits:", bits)
152154

153155
# Create a large TokenStatusList
154-
size = 1000000 # Number of tokens
155-
token_list = TokenStatusList.of_size(bits, size)
156+
size = 1000000 # Number of indices
157+
status_list = TokenStatusList.of_size(bits, size)
158+
159+
# Generate random statuses
160+
statuses = []
161+
while len(statuses) < size:
162+
run = randbelow(10)
163+
status = randbelow(2)
164+
statuses.extend([status] * run)
165+
166+
diff = len(statuses) - size
167+
if diff > 1:
168+
for _ in range(diff + 1):
169+
statuses.pop()
156170

157171
# Test setting values
158172
start_time = time.time()
159-
for i in range(size):
160-
token_list[i] = VALID if i % 2 == 0 else INVALID
173+
for i, status in enumerate(statuses):
174+
status_list[i] = status
161175
end_time = time.time()
162-
print(f"Time to set {size} tokens: {end_time - start_time} seconds")
176+
print(f"Time to set {size} indices: {end_time - start_time:.3f} seconds")
163177

164178
# Test getting values
165179
start_time = time.time()
166180
for i in range(size):
167-
status = token_list[i]
181+
status = status_list[i]
168182
end_time = time.time()
169-
print(f"Time to get {size} tokens: {end_time - start_time} seconds")
183+
print(f"Time to get {size} indices: {end_time - start_time:.3f} seconds")
170184

171185
# Test compression
172186
start_time = time.time()
173-
compressed_data = token_list.compressed()
187+
compressed_data = status_list.compressed()
174188
end_time = time.time()
175-
print(f"Time to compress: {end_time - start_time} seconds")
176-
print(f"Original length: {len(token_list.lst)} bytes")
189+
print(f"Time to compress: {end_time - start_time:.3f} seconds")
190+
print(f"Original length: {len(status_list.lst)} bytes")
177191
print(f"Compressed length: {len(compressed_data)} bytes")
178-
print(f"Compression ratio: {len(compressed_data) / len(token_list.lst) * 100:.3f}%")
192+
print(f"Compression ratio: {len(compressed_data) / len(status_list.lst) * 100:.3f}%")
193+
# print(f"List in hex: {status_list.lst.hex()}")
179194

180195

181196
def test_serde():
@@ -200,3 +215,65 @@ def test_invalid_to_valid():
200215
assert status[7] == INVALID
201216
with pytest.raises(ValueError):
202217
status[7] = 0x00
218+
219+
220+
def test_of_size():
221+
with pytest.raises(ValueError):
222+
status = TokenStatusList.of_size(1, 3)
223+
with pytest.raises(ValueError):
224+
status = TokenStatusList.of_size(2, 21)
225+
with pytest.raises(ValueError):
226+
status = TokenStatusList.of_size(4, 31)
227+
228+
# Lists with bits 8 can have arbitrary size since there's no byte
229+
# boundaries to worry about
230+
status = TokenStatusList.of_size(8, 31)
231+
assert len(status.lst) == 31
232+
233+
status = TokenStatusList.of_size(1, 8)
234+
assert len(status.lst) == 1
235+
status = TokenStatusList.of_size(1, 16)
236+
assert len(status.lst) == 2
237+
status = TokenStatusList.of_size(1, 24)
238+
assert len(status.lst) == 3
239+
status = TokenStatusList.of_size(8, 24)
240+
assert len(status.lst) == 24
241+
242+
243+
def test_with_at_least():
244+
status = TokenStatusList.with_at_least(1, 3)
245+
assert len(status.lst) == 1
246+
status = TokenStatusList.with_at_least(2, 21)
247+
assert len(status.lst) == 6
248+
status = TokenStatusList.with_at_least(4, 31)
249+
assert len(status.lst) == 16
250+
251+
status = TokenStatusList.with_at_least(1, 8)
252+
assert len(status.lst) == 1
253+
status = TokenStatusList.with_at_least(2, 24)
254+
assert len(status.lst) == 6
255+
status = TokenStatusList.with_at_least(4, 32)
256+
assert len(status.lst) == 16
257+
258+
259+
def test_sign_payload():
260+
status = TokenStatusList(1, b"\xb9\xa3")
261+
payload = status.sign_payload(
262+
alg="ES256",
263+
kid="12",
264+
iss="https://example.com",
265+
sub="https://example.com/statuslists/1",
266+
iat=1686920170,
267+
exp=2291720170,
268+
)
269+
headers, payload = payload.split(b".")
270+
headers = json.loads(b64url_decode(headers))
271+
payload = json.loads(b64url_decode(payload))
272+
assert headers == {"alg": "ES256", "kid": "12", "typ": "statuslist+jwt"}
273+
assert payload == {
274+
"exp": 2291720170,
275+
"iat": 1686920170,
276+
"iss": "https://example.com",
277+
"status_list": {"bits": 1, "lst": "eNrbuRgAAhcBXQ"},
278+
"sub": "https://example.com/statuslists/1",
279+
}

0 commit comments

Comments
 (0)