Skip to content

Commit a24ee7c

Browse files
authored
better support for properties (#42)
* better support for properties * resolve attribute access issue
1 parent ffb261d commit a24ee7c

File tree

7 files changed

+195
-15
lines changed

7 files changed

+195
-15
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ assert n1.results == 4.0
127127

128128
```
129129

130+
Instead, you can also use the ``znflow.disable_graph`` decorator / context manager to disable the graph for a specific block of code or the ``znflow.Property`` as a drop-in replacement for ``property``.
131+
132+
130133
# Supported Frameworks
131134
ZnFlow includes tests to ensure compatibility with:
132135
- "Plain classes"

tests/test_get_attribute.py

+88-8
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import znflow
44

55

6-
class POW2(znflow.Node):
7-
x_factor: float = 0.5
6+
class POW2Base(znflow.Node):
7+
x_factor: float = 1.0
88
results: float = None
99
_x: float = None
1010

@@ -22,6 +22,11 @@ def x_(self, value):
2222
"""
2323
self._x = value * self.x_factor
2424

25+
def run(self):
26+
self.results = self.x**1
27+
28+
29+
class POW2GetAttr(POW2Base):
2530
@property
2631
def x(self):
2732
return self._x
@@ -30,21 +35,96 @@ def x(self):
3035
def x(self, value):
3136
self._x = value * znflow.get_attribute(self, "x_factor")
3237

33-
def run(self):
34-
self.results = self.x**2
3538

39+
class POW2Decorate(POW2Base):
40+
@property
41+
def x(self):
42+
return self._x
43+
44+
@znflow.disable_graph()
45+
@x.setter
46+
def x(self, value):
47+
self._x = value * self.x_factor
48+
49+
50+
class POW2Decorate2(POW2Base):
51+
@znflow.Property
52+
def x(self):
53+
return self._x
54+
55+
@x.setter
56+
def x(self, value):
57+
self._x = value * self.x_factor
58+
59+
60+
class POW2Context(POW2Base):
61+
@property
62+
def x(self):
63+
return self._x
64+
65+
@x.setter
66+
def x(self, value):
67+
with znflow.disable_graph():
68+
self._x = value * self.x_factor
3669

37-
def test_get_attribute():
70+
71+
@pytest.mark.parametrize("cls", [POW2GetAttr, POW2Decorate, POW2Context, POW2Decorate2])
72+
def test_get_attribute(cls):
3873
with znflow.DiGraph() as graph:
39-
n1 = POW2()
74+
n1 = cls()
4075
n1.x = 4.0 # converted to 2.0
4176

4277
graph.run()
43-
assert n1.x == 2.0
78+
assert n1.x == 4.0
4479
assert n1.results == 4.0
4580

4681
with znflow.DiGraph() as graph:
47-
n1 = POW2()
82+
n1 = cls()
4883
with pytest.raises(TypeError):
4984
# TypeError: unsupported operand type(s) for *: 'float' and 'Connection'
5085
n1.x_ = 4.0
86+
87+
88+
class InvalidAttribute(znflow.Node):
89+
@property
90+
def invalid_attribute(self):
91+
raise ValueError("attribute not available")
92+
93+
94+
def test_invalid_attribute():
95+
node = InvalidAttribute()
96+
with pytest.raises(ValueError):
97+
node.invalid_attribute
98+
99+
with znflow.DiGraph() as graph:
100+
node = InvalidAttribute()
101+
invalid_attribute = node.invalid_attribute
102+
assert isinstance(invalid_attribute, znflow.Connection)
103+
assert invalid_attribute.instance == node
104+
assert invalid_attribute.attribute == "invalid_attribute"
105+
assert node.uuid in graph
106+
107+
108+
class NodeWithInit(znflow.Node):
109+
def __init__(self):
110+
self.x = 1.0
111+
112+
113+
def test_attribute_not_found():
114+
"""Try to access an Attribute which does not exist."""
115+
with pytest.raises(AttributeError):
116+
node = InvalidAttribute()
117+
node.this_does_not_exist
118+
119+
with znflow.DiGraph():
120+
node = POW2GetAttr()
121+
with pytest.raises(AttributeError):
122+
node.this_does_not_exist
123+
124+
with znflow.DiGraph():
125+
node = NodeWithInit()
126+
with pytest.raises(AttributeError):
127+
node.this_does_not_exist
128+
outs = node.x
129+
130+
assert outs.result == 1.0

tests/test_node_node.py

+19
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ def run(self):
2323
self.outputs = sum(self.inputs)
2424

2525

26+
@dataclasses.dataclass
27+
class SumNodesFromDict(znflow.Node):
28+
inputs: dict
29+
outputs: float = None
30+
31+
def run(self):
32+
self.outputs = sum(self.inputs.values())
33+
34+
2635
def test_eager():
2736
node = Node(inputs=1)
2837
node.run()
@@ -102,3 +111,13 @@ def test_graph_multi():
102111
graph.run()
103112

104113
assert node7.outputs == 80
114+
115+
116+
def test_SumNodesFromDict():
117+
with znflow.DiGraph() as graph:
118+
node1 = Node(inputs=5)
119+
node2 = Node(inputs=10)
120+
node3 = SumNodesFromDict(inputs={"a": node1.outputs, "b": node2.outputs})
121+
graph.run()
122+
123+
assert node3.outputs == 30

znflow/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import logging
44
import sys
55

6-
from znflow.base import Connection, FunctionFuture, get_attribute
6+
from znflow.base import (
7+
Connection,
8+
FunctionFuture,
9+
Property,
10+
disable_graph,
11+
get_attribute,
12+
)
713
from znflow.graph import DiGraph
814
from znflow.node import Node, nodify
915
from znflow.visualize import draw
@@ -18,6 +24,8 @@
1824
"FunctionFuture",
1925
"Connection",
2026
"get_attribute",
27+
"disable_graph",
28+
"Property",
2129
]
2230

2331
logger = logging.getLogger(__name__)

znflow/base.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
@contextlib.contextmanager
10-
def disable_graph():
10+
def disable_graph(*args, **kwargs):
1111
"""Temporarily disable set the graph to None.
1212
1313
This can be useful, if you e.g. want to use 'get_attribute'.
@@ -20,6 +20,59 @@ def disable_graph():
2020
set_graph(graph)
2121

2222

23+
class Property:
24+
"""Custom Property with disabled graph.
25+
26+
References
27+
----------
28+
Adapted from https://docs.python.org/3/howto/descriptor.html#properties
29+
"""
30+
31+
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
32+
self.fget = disable_graph()(fget)
33+
self.fset = disable_graph()(fset)
34+
self.fdel = disable_graph()(fdel)
35+
if doc is None and fget is not None:
36+
doc = fget.__doc__
37+
self.__doc__ = doc
38+
self._name = ""
39+
40+
def __set_name__(self, owner, name):
41+
self._name = name
42+
43+
def __get__(self, obj, objtype=None):
44+
if obj is None:
45+
return self
46+
if self.fget is None:
47+
raise AttributeError(f"property '{self._name}' has no getter")
48+
return self.fget(obj)
49+
50+
def __set__(self, obj, value):
51+
if self.fset is None:
52+
raise AttributeError(f"property '{self._name}' has no setter")
53+
self.fset(obj, value)
54+
55+
def __delete__(self, obj):
56+
if self.fdel is None:
57+
raise AttributeError(f"property '{self._name}' has no deleter")
58+
self.fdel(obj)
59+
60+
def getter(self, fget):
61+
prop = type(self)(fget, self.fset, self.fdel, self.__doc__)
62+
prop._name = self._name
63+
return prop
64+
65+
def setter(self, fset):
66+
prop = type(self)(self.fget, fset, self.fdel, self.__doc__)
67+
prop._name = self._name
68+
return prop
69+
70+
def deleter(self, fdel):
71+
prop = type(self)(self.fget, self.fset, fdel, self.__doc__)
72+
prop._name = self._name
73+
return prop
74+
75+
2376
class NodeBaseMixin:
2477
"""A Parent for all Nodes.
2578

znflow/graph.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ def _update_node_attributes(self, node_instance: Node, updater) -> None:
9393
if attribute.startswith("_") or attribute in Node._protected_:
9494
# We do not allow connections to private attributes.
9595
continue
96-
value = getattr(node_instance, attribute)
96+
try:
97+
value = getattr(node_instance, attribute)
98+
except Exception:
99+
# It might be, that the value is currently not available.
100+
# For example, it could be a property that is not yet set.
101+
# In this case we skip updating the attribute, no matter the exception.
102+
continue
97103
value = updater(value)
98104
if updater.updated:
99105
try:

znflow/node.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
import inspect
55
import uuid
66

7-
from znflow.base import Connection, FunctionFuture, NodeBaseMixin, get_graph
7+
from znflow.base import (
8+
Connection,
9+
FunctionFuture,
10+
NodeBaseMixin,
11+
disable_graph,
12+
get_graph,
13+
)
814

915

1016
def _mark_init_in_construction(cls):
@@ -51,14 +57,19 @@ def __new__(cls, *args, **kwargs):
5157
return instance
5258

5359
def __getattribute__(self, item):
54-
value = super().__getattribute__(item)
5560
if get_graph() is not None:
61+
with disable_graph():
62+
if item not in set(dir(self)):
63+
raise AttributeError(
64+
f"'{self.__class__.__name__}' object has no attribute '{item}'"
65+
)
66+
5667
if item not in type(self)._protected_ and not item.startswith("_"):
5768
if self._in_construction:
58-
return value
69+
return super().__getattribute__(item)
5970
connector = Connection(instance=self, attribute=item)
6071
return connector
61-
return value
72+
return super().__getattribute__(item)
6273

6374
def __setattr__(self, item, value) -> None:
6475
super().__setattr__(item, value)

0 commit comments

Comments
 (0)