Skip to content

Commit 28808f0

Browse files
Defer event header encoding
This updates event header encoding to be deferred to until the event is written. This is necessary to allow signing or other post-ser modification to take place.
1 parent 40df093 commit 28808f0

File tree

3 files changed

+47
-121
lines changed

3 files changed

+47
-121
lines changed

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from smithy_core.shapes import ShapeType
2121
from smithy_event_stream.aio.interfaces import AsyncEventPublisher
2222

23-
from ..events import EventHeaderEncoder, EventMessage
23+
from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
2424
from ..exceptions import InvalidHeaderValue
2525
from . import (
2626
INITIAL_REQUEST_EVENT_TYPE,
@@ -100,30 +100,27 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
100100
finally:
101101
return
102102

103-
headers_encoder = EventHeaderEncoder()
103+
headers: dict[str, HEADER_VALUE] = {}
104104

105105
if ErrorTrait in schema:
106-
headers_encoder.encode_string(":message-type", "exception")
107-
headers_encoder.encode_string(
108-
":exception-type", schema.expect_member_name()
109-
)
106+
headers[":message-type"] = "exception"
107+
headers[":exception-type"] = schema.expect_member_name()
110108
else:
111-
headers_encoder.encode_string(":message-type", "event")
109+
headers[":message-type"] = "event"
112110
if schema.member_name is None:
113111
# If there's no member name, that must mean that the structure is
114112
# either an input or output structure, and so this represents the
115113
# initial message.
116-
headers_encoder.encode_string(
117-
":event-type", self._initial_message_event_type
118-
)
114+
headers[":event-type"] = self._initial_message_event_type
119115
else:
120-
headers_encoder.encode_string(":event-type", schema.member_name)
116+
headers[":event-type"] = schema.member_name
121117

122118
payload = BytesIO()
123119
payload_serializer: ShapeSerializer = self._payload_codec.create_serializer(
124120
payload
125121
)
126-
header_serializer = EventHeaderSerializer(headers_encoder)
122+
123+
header_serializer = EventHeaderSerializer(headers)
127124

128125
media_type = self._payload_codec.media_type
129126

@@ -138,11 +135,9 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
138135

139136
payload_bytes = payload.getvalue()
140137
if payload_bytes:
141-
headers_encoder.encode_string(":content-type", media_type)
138+
headers[":content-type"] = media_type
142139

143-
self._result = EventMessage(
144-
headers_bytes=headers_encoder.get_result(), payload=payload_bytes
145-
)
140+
self._result = EventMessage(headers=headers, payload=payload_bytes)
146141

147142
def _get_payload_media_type(self, schema: Schema, default: str) -> str:
148143
if (media_type := schema.get_trait(MediaTypeTrait)) is not None:
@@ -158,8 +153,8 @@ def _get_payload_media_type(self, schema: Schema, default: str) -> str:
158153

159154

160155
class EventHeaderSerializer(SpecificShapeSerializer):
161-
def __init__(self, encoder: EventHeaderEncoder) -> None:
162-
self._encoder = encoder
156+
def __init__(self, headers: dict[str, HEADER_VALUE]) -> None:
157+
self._headers = headers
163158

164159
def _invalid_state(
165160
self, schema: "Schema | None" = None, message: str | None = None
@@ -169,28 +164,28 @@ def _invalid_state(
169164
raise InvalidHeaderValue(message)
170165

171166
def write_boolean(self, schema: "Schema", value: bool) -> None:
172-
self._encoder.encode_boolean(schema.expect_member_name(), value)
167+
self._headers[schema.expect_member_name()] = value
173168

174169
def write_byte(self, schema: "Schema", value: int) -> None:
175-
self._encoder.encode_byte(schema.expect_member_name(), value)
170+
self._headers[schema.expect_member_name()] = Byte(value)
176171

177172
def write_short(self, schema: "Schema", value: int) -> None:
178-
self._encoder.encode_short(schema.expect_member_name(), value)
173+
self._headers[schema.expect_member_name()] = Short(value)
179174

180175
def write_integer(self, schema: "Schema", value: int) -> None:
181-
self._encoder.encode_integer(schema.expect_member_name(), value)
176+
self._headers[schema.expect_member_name()] = value
182177

183178
def write_long(self, schema: "Schema", value: int) -> None:
184-
self._encoder.encode_long(schema.expect_member_name(), value)
179+
self._headers[schema.expect_member_name()] = Long(value)
185180

186181
def write_string(self, schema: "Schema", value: str) -> None:
187-
self._encoder.encode_string(schema.expect_member_name(), value)
182+
self._headers[schema.expect_member_name()] = value
188183

189184
def write_blob(self, schema: "Schema", value: bytes) -> None:
190-
self._encoder.encode_blob(schema.expect_member_name(), value)
185+
self._headers[schema.expect_member_name()] = value
191186

192187
def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None:
193-
self._encoder.encode_timestamp(schema.expect_member_name(), value)
188+
self._headers[schema.expect_member_name()] = value
194189

195190

196191
class RawPayloadSerializer(SpecificShapeSerializer):

packages/aws-event-stream/src/aws_event_stream/events.py

Lines changed: 26 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import uuid
1313
from binascii import crc32
1414
from collections.abc import Callable, Iterator, Mapping
15-
from dataclasses import dataclass
15+
from dataclasses import dataclass, field
1616
from io import BytesIO
1717
from struct import pack, unpack
1818
from types import MappingProxyType
@@ -147,6 +147,7 @@ def __post_init__(self):
147147
raise InvalidPayloadLength(payload_length)
148148

149149

150+
@dataclass(kw_only=True, eq=False)
150151
class EventMessage:
151152
"""A message that may be sent over an event stream.
152153
@@ -186,76 +187,31 @@ class EventMessage:
186187
message.
187188
"""
188189

189-
def __init__(
190-
self,
191-
*,
192-
headers: HEADERS_DICT | None = None,
193-
headers_bytes: bytes | None = None,
194-
payload: bytes = b"",
195-
) -> None:
196-
"""Initialize an EventMessage.
197-
198-
:param headers: The headers present in the event message. If this parameter is
199-
unspecified, the default value will be the decoded value of the
200-
`headers_bytes` parameter.
201-
202-
Sized integer values may be indicated for the purpose of serialization
203-
using the `Byte`, `Short`, or `Long` types. int values of unspecified size
204-
will be assumed to be 32-bit.
205-
206-
:param headers_bytes: The serialized bytes of the headers present in the event
207-
message.
208-
209-
:param payload: The serialized bytes of the message payload.
210-
"""
211-
self._payload = payload
212-
self._headers_bytes = headers_bytes
213-
214-
if len(payload) > MAX_PAYLOAD_LENGTH:
215-
raise InvalidPayloadLength(len(payload))
216-
217-
if headers_bytes is None:
218-
if headers is None:
219-
headers = {}
220-
elif headers is None:
221-
headers = EventHeaderDecoder(headers_bytes).decode_headers()
190+
headers: HEADERS_DICT = field(default_factory=dict)
191+
"""The headers present in the event message.
222192
223-
self._headers = headers
224-
225-
@property
226-
def payload(self) -> bytes:
227-
"""The serialized bytes of the message payload.
228-
229-
These bytes may be in any format or media type. The `:content-type` header, if
230-
present, indicates the media type.
231-
"""
232-
return self._payload
193+
Sized integer values may be indicated for the purpose of serialization
194+
using the `Byte`, `Short`, or `Long` types. int values of unspecified size
195+
will be assumed to be 32-bit.
196+
"""
233197

234-
@property
235-
def headers(self) -> HEADERS_DICT:
236-
"""The headers of the event message.
198+
payload: bytes = b""
199+
"""The serialized bytes of the message payload."""
237200

238-
Headers prefixed with `:` contain metadata by convention.
239-
"""
240-
return self._headers
201+
def __post_init__(
202+
self,
203+
) -> None:
204+
if len(self.payload) > MAX_PAYLOAD_LENGTH:
205+
raise InvalidPayloadLength(len(self.payload))
241206

242207
def _get_headers_bytes(self) -> bytes:
243-
if self._headers_bytes is None:
244-
encoder = EventHeaderEncoder()
245-
encoder.encode_headers(self._headers)
246-
self._headers_bytes = encoder.get_result()
247-
248-
return self._headers_bytes
208+
encoder = EventHeaderEncoder()
209+
encoder.encode_headers(self.headers)
210+
return encoder.get_result()
249211

250212
def encode(self) -> bytes:
251213
return _EventEncoder().encode_bytes(
252-
headers=self._get_headers_bytes(), payload=self._payload
253-
)
254-
255-
def __repr__(self) -> str:
256-
return (
257-
f"EventMessage(payload={self._payload!r}, headers={self.headers!r}, "
258-
f"headers_bytes={self._get_headers_bytes()!r})"
214+
headers=self._get_headers_bytes(), payload=self.payload
259215
)
260216

261217
def __eq__(self, other: object) -> bool:
@@ -325,8 +281,9 @@ def decode(cls, source: BytesReader) -> Self | None:
325281

326282
_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)
327283

284+
headers_bytes = message_bytes[: prelude.headers_length]
328285
message = EventMessage(
329-
headers_bytes=message_bytes[: prelude.headers_length],
286+
headers=EventHeaderDecoder(headers_bytes).decode_headers(),
330287
payload=message_bytes[prelude.headers_length :],
331288
)
332289
return cls(prelude, message, crc)
@@ -369,8 +326,9 @@ async def decode_async(cls, source: AsyncByteStream) -> Self | None:
369326

370327
_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)
371328

329+
headers_bytes = message_bytes[: prelude.headers_length]
372330
message = EventMessage(
373-
headers_bytes=message_bytes[: prelude.headers_length],
331+
headers=EventHeaderDecoder(headers_bytes).decode_headers(),
374332
payload=message_bytes[prelude.headers_length :],
375333
)
376334
return cls(prelude, message, crc)
@@ -647,7 +605,7 @@ def unpack_int8(data: BytesLike):
647605
:returns: A tuple containing the (parsed integer value, bytes consumed)
648606
"""
649607
value = unpack(_DecodeUtils.INT8_BYTE_FORMAT, data[:1])[0]
650-
return value, 1
608+
return Byte(value), 1
651609

652610
@staticmethod
653611
def unpack_int16(data: BytesLike) -> tuple[int, int]:
@@ -657,7 +615,7 @@ def unpack_int16(data: BytesLike) -> tuple[int, int]:
657615
:returns: A tuple containing the (parsed integer value, bytes consumed)
658616
"""
659617
value = unpack(_DecodeUtils.INT16_BYTE_FORMAT, data[:2])[0]
660-
return value, 2
618+
return Short(value), 2
661619

662620
@staticmethod
663621
def unpack_int32(data: BytesLike) -> tuple[int, int]:
@@ -677,7 +635,7 @@ def unpack_int64(data: BytesLike) -> tuple[int, int]:
677635
:returns: A tuple containing the (parsed integer value, bytes consumed)
678636
"""
679637
value = unpack(_DecodeUtils.INT64_BYTE_FORMAT, data[:8])[0]
680-
return value, 8
638+
return Long(value), 8
681639

682640
@staticmethod
683641
def unpack_byte_array(

packages/aws-event-stream/tests/unit/test_events.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -580,33 +580,6 @@ def test_event_message_rejects_long_header_value():
580580
EventMessage(headers=headers).encode()
581581

582582

583-
def test_event_message_rejects_long_headers():
584-
# 5 of these is more than enough to overcome the header size limit.
585-
long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1)
586-
headers = {
587-
"1": long_value,
588-
"2": long_value,
589-
"3": long_value,
590-
"4": long_value,
591-
"5": long_value,
592-
}
593-
with pytest.raises(InvalidHeadersLength):
594-
EventMessage(headers=headers).encode()
595-
596-
# These are correctly encoded, and individually valid, but collectively too long.
597-
long_headers = b""
598-
for i in range(5):
599-
long_headers += b"\x01" + str(i).encode("utf-8") + b"\x06\x7f\xfe" + long_value
600-
601-
with pytest.raises(InvalidHeadersLength):
602-
EventMessage(headers_bytes=long_headers)
603-
604-
605-
def test_event_message_decodes_headers():
606-
message = EventMessage(headers_bytes=b"\x04true\x00")
607-
assert message.headers == {"true": True}
608-
609-
610583
def test_event_encoder_rejects_long_headers():
611584
long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1)
612585
long_headers = b""

0 commit comments

Comments
 (0)