Skip to content

Commit

Permalink
improve groups (#96)
Browse files Browse the repository at this point in the history
* improve groups

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* replace nodes with uuids

* len, contains and iter

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* getitem and nodes-property

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allow group import

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Feb 26, 2024
1 parent 9724783 commit d1dc83c
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 57 deletions.
141 changes: 99 additions & 42 deletions tests/test_graph_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(self):
def test_empty_grp_name():
graph = znflow.DiGraph()

with pytest.raises(TypeError):
with pytest.raises(ValueError):
with graph.group(): # name required
pass

Expand All @@ -24,19 +24,19 @@ def test_grp():

assert graph.active_group is None

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with graph.group("my_grp") as grp:
assert graph.active_group == grp

node = PlainNode(1)

assert graph.active_group is None
graph.run()

assert grp_name == "my_grp"
assert grp.names == ("my_grp",)
assert node.value == 2
assert node.uuid in graph.nodes
assert grp_name in graph._groups
assert graph.get_group(grp_name) == [node.uuid]
assert grp.names in graph._groups
assert graph.get_group("my_grp").uuids == [node.uuid]

assert len(graph._groups) == 1
assert len(graph) == 1
Expand All @@ -47,36 +47,36 @@ def test_muliple_grps():

assert graph.active_group is None

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with graph.group("my_grp") as grp:
assert graph.active_group == grp

node = PlainNode(1)

assert graph.active_group is None

with graph.group("my_grp2") as grp_name2:
assert graph.active_group == grp_name2
with graph.group("my_grp2") as grp2:
assert graph.active_group == grp2

node2 = PlainNode(2)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"
assert grp_name2 == "my_grp2"
assert grp.names == ("my_grp",)
assert grp2.names == ("my_grp2",)

assert node.value == 2
assert node2.value == 3

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp_name in graph._groups
assert grp_name2 in graph._groups
assert grp.names in graph._groups
assert grp2.names in graph._groups

assert graph.get_group(grp_name) == [node.uuid]
assert graph.get_group(grp_name2) == [node2.uuid]
assert graph.get_group(*grp.names).uuids == [node.uuid]
assert graph.get_group(*grp2.names).uuids == [node2.uuid]

assert len(graph._groups) == 2
assert len(graph) == 2
Expand All @@ -85,8 +85,8 @@ def test_muliple_grps():
def test_nested_grps():
graph = znflow.DiGraph()

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with graph.group("my_grp") as grp:
assert graph.active_group == grp
with pytest.raises(TypeError):
with graph.group("my_grp2"):
pass
Expand All @@ -96,26 +96,26 @@ def test_grp_with_existing_nodes():
with znflow.DiGraph() as graph:
node = PlainNode(1)

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with graph.group("my_grp") as grp:
assert graph.active_group == grp

node2 = PlainNode(2)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"
assert grp.names == ("my_grp",)

assert node.value == 2
assert node2.value == 3

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp_name in graph._groups
assert grp.names in graph._groups

assert graph.get_group(grp_name) == [node2.uuid]
assert graph.get_group(*grp.names).uuids == [node2.uuid]

assert len(graph._groups) == 1
assert len(graph) == 2
Expand All @@ -126,8 +126,8 @@ def test_grp_with_multiple_nodes():
node = PlainNode(1)
node2 = PlainNode(2)

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with graph.group("my_grp") as grp:
assert graph.active_group == grp

node3 = PlainNode(3)
node4 = PlainNode(4)
Expand All @@ -136,7 +136,7 @@ def test_grp_with_multiple_nodes():

graph.run()

assert grp_name == "my_grp"
assert grp.names == ("my_grp",)

assert node.value == 2
assert node2.value == 3
Expand All @@ -148,42 +148,42 @@ def test_grp_with_multiple_nodes():
assert node3.uuid in graph.nodes
assert node4.uuid in graph.nodes

assert grp_name in graph._groups
assert grp.names in graph._groups

assert graph.get_group(grp_name) == [node3.uuid, node4.uuid]
assert graph.get_group(*grp.names).uuids == [node3.uuid, node4.uuid]

assert len(graph._groups) == 1
assert len(graph) == 4


def test_reopen_grps():
with znflow.DiGraph() as graph:
with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with graph.group("my_grp") as grp:
assert graph.active_group == grp

node = PlainNode(1)

with graph.group("my_grp") as grp_name2:
assert graph.active_group == grp_name2
with graph.group("my_grp") as grp2:
assert graph.active_group == grp2

node2 = PlainNode(2)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"
assert grp_name2 == grp_name
assert grp.names == ("my_grp",)
assert grp.names == grp2.names

assert node.value == 2
assert node2.value == 3

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp_name in graph._groups
assert grp.names in graph._groups

assert graph.get_group(grp_name) == [node.uuid, node2.uuid]
assert graph.get_group(*grp.names).uuids == [node.uuid, node2.uuid]

assert len(graph._groups) == 1
assert len(graph) == 2
Expand All @@ -193,19 +193,19 @@ def test_tuple_grp_names():
graph = znflow.DiGraph()

assert graph.active_group is None
with graph.group(("grp", "1")) as grp_name:
assert graph.active_group == grp_name
with graph.group("grp", "1") as grp:
assert graph.active_group == grp

node = PlainNode(1)

assert graph.active_group is None
graph.run()

assert grp_name == ("grp", "1")
assert grp.names == ("grp", "1")
assert node.value == 2
assert node.uuid in graph.nodes
assert grp_name in graph._groups
assert graph.get_group(grp_name) == [node.uuid]
assert grp.names in graph._groups
assert graph.get_group(*grp.names).uuids == [node.uuid]


def test_grp_nodify():
Expand All @@ -218,4 +218,61 @@ def compute_mean(x, y):
with graph.group("grp1"):
n1 = compute_mean(2, 4)

assert n1.uuid in graph.get_group("grp1")
assert n1.uuid in graph.get_group("grp1").uuids


def test_grp_iter():
graph = znflow.DiGraph()

with graph.group("grp1") as grp:
n1 = PlainNode(1)
n2 = PlainNode(2)

assert list(grp) == [n1.uuid, n2.uuid]


def test_grp_contains():
graph = znflow.DiGraph()

with graph.group("grp1") as grp:
n1 = PlainNode(1)
n2 = PlainNode(2)

assert n1.uuid in grp
assert n2.uuid in grp
assert "foo" not in grp


def test_grp_len():
graph = znflow.DiGraph()

with graph.group("grp1") as grp:
PlainNode(1)
PlainNode(2)

assert len(grp) == 2


def test_grp_getitem():
graph = znflow.DiGraph()

with graph.group("grp1") as grp:
n1 = PlainNode(1)
n2 = PlainNode(2)

assert grp[n1.uuid] == n1
assert grp[n2.uuid] == n2
with pytest.raises(KeyError):
grp["foo"]


def test_grp_nodes():
graph = znflow.DiGraph()

with graph.group("grp1") as grp:
n1 = PlainNode(1)
n2 = PlainNode(2)

assert grp.nodes == [n1, n2]
assert grp.uuids == [n1.uuid, n2.uuid]
assert grp.names == ("grp1",)
3 changes: 2 additions & 1 deletion znflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from znflow.combine import combine
from znflow.dynamic import resolve
from znflow.graph import DiGraph
from znflow.graph import DiGraph, Group
from znflow.node import Node, nodify
from znflow.visualize import draw

Expand All @@ -39,6 +39,7 @@
"get_graph",
"empty_graph",
"resolve",
"Group",
]

with contextlib.suppress(ImportError):
Expand Down
Loading

0 comments on commit d1dc83c

Please sign in to comment.