Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import decimal
import enum
import functools
import itertools
import json
import logging
Expand Down Expand Up @@ -376,18 +375,6 @@ def _verify_dill_compat():
raise RuntimeError(base_error + f". Found dill version '{dill.__version__}")


dataclass_uses_kw_only: Callable[[Any], bool]
if dataclasses:
# Cache the result to avoid multiple checks for the same dataclass type.
@functools.cache
def dataclass_uses_kw_only(cls) -> bool:
return any(
field.init and field.kw_only for field in dataclasses.fields(cls))

else:
dataclass_uses_kw_only = lambda cls: False


class FastPrimitivesCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
def __init__(
Expand Down Expand Up @@ -518,7 +505,7 @@ def encode_special_deterministic(self, value, stream):
(value, type(value), self.requires_deterministic_step_label))
init_fields = [field for field in dataclasses.fields(value) if field.init]
try:
if dataclass_uses_kw_only(type(value)):
if any(field.kw_only for field in init_fields):
stream.write_byte(DATACLASS_KW_ONLY_TYPE)
self.encode_type(type(value), stream)
stream.write_var_int64(len(init_fields))
Expand Down
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ class UnFrozenDataClass:
x: int
y: int

@dataclasses.dataclass(frozen=True, kw_only=True)
class FrozenUnInitKwOnlyDataClass:
side: int
area: int = dataclasses.field(init=False)

def __post_init__(self):
# Hack to update an attribute in a frozen dataclass.
object.__setattr__(self, 'area', self.side**2)


# These tests need to all be run in the same process due to the asserts
# in tearDownClass.
Expand Down Expand Up @@ -309,6 +318,8 @@ def test_deterministic_coder(self, compat_version):
if dataclasses is not None:
self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2))
self.check_coder(
deterministic_coder, FrozenUnInitKwOnlyDataClass(side=11))

with self.assertRaises(TypeError):
self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
Expand Down Expand Up @@ -750,6 +761,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
from apache_beam.coders.coders_test_common import DefinesGetAndSetState
from apache_beam.coders.coders_test_common import FrozenDataClass
from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass
from apache_beam.coders.coders_test_common import FrozenUnInitKwOnlyDataClass


from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
Expand Down Expand Up @@ -786,7 +798,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic(
("frozen_dataclass", FrozenDataClass(1, 2)),
("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)),
("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]),
("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenUnInitKwOnlyDataClass(side=3)]),
])

compat_version = {'"'+ compat_version +'"' if compat_version else None}
Expand Down
Loading