Skip to content

Commit 8579921

Browse files
authored
Merge pull request #684 from tclose/special-form-hash
Added handling of hashing of types with args and typing special forms
2 parents 31aea01 + 4809dfe commit 8579921

File tree

3 files changed

+110
-15
lines changed

3 files changed

+110
-15
lines changed

pydra/utils/hash.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
# import stat
55
import struct
6+
import typing as ty
67
from collections.abc import Mapping
78
from functools import singledispatch
89
from hashlib import blake2b
10+
import logging
911

1012
# from pathlib import Path
1113
from typing import (
@@ -14,10 +16,11 @@
1416
NewType,
1517
Sequence,
1618
Set,
17-
_SpecialForm,
1819
)
1920
import attrs.exceptions
2021

22+
logger = logging.getLogger("pydra")
23+
2124
try:
2225
from typing import Protocol
2326
except ImportError:
@@ -88,7 +91,8 @@ def hash_single(obj: object, cache: Cache) -> Hash:
8891
h = blake2b(digest_size=16, person=b"pydra-hash")
8992
for chunk in bytes_repr(obj, cache):
9093
h.update(chunk)
91-
cache[objid] = Hash(h.digest())
94+
hsh = cache[objid] = Hash(h.digest())
95+
logger.debug("Hash of %s object is %s", obj, hsh)
9296
return cache[objid]
9397

9498

@@ -102,15 +106,14 @@ def __bytes_repr__(self, cache: Cache) -> Iterator[bytes]:
102106
def bytes_repr(obj: object, cache: Cache) -> Iterator[bytes]:
103107
cls = obj.__class__
104108
yield f"{cls.__module__}.{cls.__name__}:{{".encode()
105-
try:
109+
dct: Dict[str, ty.Any]
110+
if attrs.has(type(obj)):
111+
# Drop any attributes that aren't used in comparisons by default
112+
dct = attrs.asdict(obj, recurse=False, filter=lambda a, _: bool(a.eq))
113+
elif hasattr(obj, "__slots__"):
114+
dct = {attr: getattr(obj, attr) for attr in obj.__slots__}
115+
else:
106116
dct = obj.__dict__
107-
except AttributeError as e:
108-
# Attrs creates slots classes by default, so we add this here to handle those
109-
# cases
110-
try:
111-
dct = attrs.asdict(obj, recurse=False) # type: ignore
112-
except attrs.exceptions.NotAnAttrsClassError:
113-
raise TypeError(f"Cannot hash {obj} as it is a slots class") from e
114117
yield from bytes_repr_mapping_contents(dct, cache)
115118
yield b"}"
116119

@@ -224,10 +227,34 @@ def bytes_repr_dict(obj: dict, cache: Cache) -> Iterator[bytes]:
224227
yield b"}"
225228

226229

227-
@register_serializer(_SpecialForm)
230+
@register_serializer(ty._GenericAlias)
231+
@register_serializer(ty._SpecialForm)
228232
@register_serializer(type)
229233
def bytes_repr_type(klass: type, cache: Cache) -> Iterator[bytes]:
230-
yield f"type:({klass.__module__}.{klass.__name__})".encode()
234+
def type_name(tp):
235+
try:
236+
name = tp.__name__
237+
except AttributeError:
238+
name = tp._name
239+
return name
240+
241+
yield b"type:("
242+
origin = ty.get_origin(klass)
243+
if origin:
244+
yield f"{origin.__module__}.{type_name(origin)}[".encode()
245+
for arg in ty.get_args(klass):
246+
if isinstance(
247+
arg, list
248+
): # sometimes (e.g. Callable) one of a type's args is itself a list
249+
yield b"["
250+
yield from (b for t in arg for b in bytes_repr_type(t, cache))
251+
yield b"]"
252+
else:
253+
yield from bytes_repr_type(arg, cache)
254+
yield b"]"
255+
else:
256+
yield f"{klass.__module__}.{type_name(klass)}".encode()
257+
yield b")"
231258

232259

233260
@register_serializer(list)

pydra/utils/tests/test_hash.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import attrs
66
import pytest
7-
7+
import typing as ty
8+
from fileformats.application import Zip, Json
89
from ..hash import Cache, UnhashableError, bytes_repr, hash_object, register_serializer
910

1011

@@ -134,6 +135,17 @@ def __init__(self, x):
134135
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)
135136

136137

138+
def test_bytes_repr_slots_obj():
139+
class MyClass:
140+
__slots__ = ("x",)
141+
142+
def __init__(self, x):
143+
self.x = x
144+
145+
obj_repr = join_bytes_repr(MyClass(1))
146+
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)
147+
148+
137149
def test_bytes_repr_attrs_slots():
138150
@attrs.define
139151
class MyClass:
@@ -143,11 +155,67 @@ class MyClass:
143155
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)
144156

145157

146-
def test_bytes_repr_type():
158+
def test_bytes_repr_attrs_no_slots():
159+
@attrs.define(slots=False)
160+
class MyClass:
161+
x: int
162+
163+
obj_repr = join_bytes_repr(MyClass(1))
164+
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)
165+
166+
167+
def test_bytes_repr_type1():
147168
obj_repr = join_bytes_repr(Path)
148169
assert obj_repr == b"type:(pathlib.Path)"
149170

150171

172+
def test_bytes_repr_type1a():
173+
obj_repr = join_bytes_repr(Zip[Json])
174+
assert obj_repr == rb"type:(fileformats.application.archive.Json__Zip)"
175+
176+
177+
def test_bytes_repr_type2():
178+
T = ty.TypeVar("T")
179+
180+
class MyClass(ty.Generic[T]):
181+
pass
182+
183+
obj_repr = join_bytes_repr(MyClass[int])
184+
assert (
185+
obj_repr == b"type:(pydra.utils.tests.test_hash.MyClass[type:(builtins.int)])"
186+
)
187+
188+
189+
def test_bytes_special_form1():
190+
obj_repr = join_bytes_repr(ty.Union[int, float])
191+
assert obj_repr == b"type:(typing.Union[type:(builtins.int)type:(builtins.float)])"
192+
193+
194+
def test_bytes_special_form2():
195+
obj_repr = join_bytes_repr(ty.Any)
196+
assert re.match(rb"type:\(typing.Any\)", obj_repr)
197+
198+
199+
def test_bytes_special_form3():
200+
obj_repr = join_bytes_repr(ty.Optional[Path])
201+
assert (
202+
obj_repr == b"type:(typing.Union[type:(pathlib.Path)type:(builtins.NoneType)])"
203+
)
204+
205+
206+
def test_bytes_special_form4():
207+
obj_repr = join_bytes_repr(ty.Type[Path])
208+
assert obj_repr == b"type:(builtins.type[type:(pathlib.Path)])"
209+
210+
211+
def test_bytes_special_form5():
212+
obj_repr = join_bytes_repr(ty.Callable[[Path, int], ty.Tuple[float, str]])
213+
assert obj_repr == (
214+
b"type:(collections.abc.Callable[[type:(pathlib.Path)type:(builtins.int)]"
215+
b"type:(builtins.tuple[type:(builtins.float)type:(builtins.str)])])"
216+
)
217+
218+
151219
def test_recursive_object():
152220
a = []
153221
b = [a]

pydra/utils/tests/test_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ...engine.specs import File, LazyOutField
99
from ..typing import TypeParser
1010
from pydra import Workflow
11-
from fileformats.serialization import Json
11+
from fileformats.application import Json
1212
from .utils import (
1313
generic_func_task,
1414
GenericShellTask,

0 commit comments

Comments
 (0)