Skip to content

Commit 44deaa9

Browse files
authored
Update znflow (#548)
* add eager test * first try in supporting AddedConnections * convert tuple to list * use znflow from pypi * add test for combined dict * add error message * add comments * use dict comprehension * add 'name' to '_protected_'
1 parent fb922c6 commit 44deaa9

File tree

6 files changed

+120
-11
lines changed

6 files changed

+120
-11
lines changed

poetry.lock

+9-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ typer = "^0.7.0"
1818

1919
dot4dict = "^0.1.1"
2020
zninit = "^0.1.9"
21-
znflow = "^0.1.5"
2221
znjson = "^0.2.2"
22+
znflow = "^0.1.6"
2323

2424

2525
[tool.poetry.urls]
+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
3+
import zntrack
4+
5+
6+
class GenerateList(zntrack.Node):
7+
size = zntrack.zn.params(10)
8+
outs = zntrack.zn.outs()
9+
10+
def run(self):
11+
self.outs = list(range(self.size))
12+
13+
14+
class AddOneToList(zntrack.Node):
15+
data = zntrack.zn.deps()
16+
outs = zntrack.zn.outs()
17+
18+
def run(self) -> None:
19+
self.outs = [x + 1 for x in self.data]
20+
21+
22+
class AddOneToDict(zntrack.Node):
23+
data = zntrack.zn.deps()
24+
outs = zntrack.zn.outs()
25+
26+
def run(self) -> None:
27+
self.outs = {k: [x + 1 for x in v] for k, v in self.data.items()}
28+
29+
30+
@pytest.mark.parametrize("eager", [True, False])
31+
def test_combine(proj_path, eager):
32+
with zntrack.Project() as proj:
33+
a = GenerateList(size=1, name="a")
34+
b = GenerateList(size=2, name="b")
35+
c = GenerateList(size=3, name="c")
36+
37+
added = AddOneToList(data=a.outs + b.outs + c.outs)
38+
39+
proj.run(eager=eager)
40+
if not eager:
41+
added.load()
42+
43+
assert added.outs == [1] + [1, 2] + [1, 2, 3]
44+
45+
46+
@pytest.mark.parametrize("eager", [True, False])
47+
def test_combine_dict(proj_path, eager):
48+
with zntrack.Project() as proj:
49+
a = GenerateList(size=1, name="a")
50+
b = GenerateList(size=2, name="b")
51+
c = GenerateList(size=3, name="c")
52+
53+
added = AddOneToDict(data={x.name: x.outs for x in [a, b, c]})
54+
55+
proj.run(eager=eager)
56+
if not eager:
57+
added.load()
58+
59+
assert added.outs == {"a": [1], "b": [1, 2], "c": [1, 2, 3]}

tests/integration/test_misc.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import zntrack
21
import znflow
32

3+
import zntrack
4+
45

56
class NodeWithProperty(zntrack.Node):
67
params = zntrack.zn.params(None)

zntrack/core/node.py

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ class Node(zninit.ZnInit, znflow.Node):
8888
name: str = _NameDescriptor(None)
8989
_name_ = None
9090

91+
_protected_ = znflow.Node._protected_ + ["name"]
92+
9193
def _post_load_(self) -> None:
9294
"""Post load hook.
9395

zntrack/fields/zn/__init__.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pandas as pd
99
import yaml
10+
import znflow
1011
import znflow.utils
1112
import zninit
1213
import znjson
@@ -38,6 +39,36 @@ def decode(self, value: str) -> znflow.Connection:
3839
return znflow.Connection(**value)
3940

4041

42+
class CombinedConnectionsConverter(znjson.ConverterBase):
43+
"""Convert a znflow.Connection object to dict and back."""
44+
45+
level = 100
46+
representation = "znflow.CombinedConnections"
47+
instance = znflow.CombinedConnections
48+
49+
def encode(self, obj: znflow.CombinedConnections) -> dict:
50+
"""Convert the znflow.Connection object to dict."""
51+
if obj.item is not None:
52+
raise NotImplementedError(
53+
"znflow.CombinedConnections getitem is not supported yet."
54+
)
55+
return dataclasses.asdict(obj)
56+
57+
def decode(self, value: str) -> znflow.CombinedConnections:
58+
"""Create znflow.Connection object from dict."""
59+
connections = []
60+
for item in value["connections"]:
61+
if isinstance(item, dict):
62+
# @nodify functions aren't support as 'zn.deps'
63+
# Nodes directly aren't supported because they aren't lists
64+
connections.append(znflow.Connection(**item))
65+
else:
66+
# For the case that item is already a znflow.Connection
67+
connections.append(item)
68+
value["connections"] = connections
69+
return znflow.CombinedConnections(**value)
70+
71+
4172
class SliceConverter(znjson.ConverterBase):
4273
"""Convert a znflow.Connection object to dict and back."""
4374

@@ -269,9 +300,21 @@ def get_files(self, instance) -> list:
269300
files = []
270301

271302
value = getattr(instance, self.name)
303+
# TODO use IterableHandler?
272304

305+
if isinstance(value, dict):
306+
value = list(value.values())
273307
if not isinstance(value, (list, tuple)):
274308
value = [value]
309+
if isinstance(value, tuple):
310+
value = list(value)
311+
312+
others = []
313+
for node in value:
314+
if isinstance(node, znflow.CombinedConnections):
315+
others.extend(node.connections)
316+
317+
value.extend(others)
275318

276319
for node in value:
277320
if node is None:
@@ -298,7 +341,7 @@ def save(self, instance: "Node"):
298341
value,
299342
instance,
300343
encoder=znjson.ZnEncoder.from_converters(
301-
[ConnectionConverter], add_default=True
344+
[ConnectionConverter, CombinedConnectionsConverter], add_default=True
302345
),
303346
)
304347

@@ -313,7 +356,9 @@ def get_data(self, instance: "Node") -> any:
313356

314357
value = json.loads(
315358
json.dumps(value),
316-
cls=znjson.ZnDecoder.from_converters(ConnectionConverter, add_default=True),
359+
cls=znjson.ZnDecoder.from_converters(
360+
[ConnectionConverter, CombinedConnectionsConverter], add_default=True
361+
),
317362
)
318363

319364
# Up until here we have connection objects. Now we need

0 commit comments

Comments
 (0)