Skip to content

Commit

Permalink
move group updates (#106)
Browse files Browse the repository at this point in the history
* move group updates

* add tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Aug 29, 2024
1 parent 1d431b0 commit 74b1cd3
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znflow"
version = "0.2.0"
version = "0.2.1"
description = "A general purpose framework for building and running computational graphs."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
40 changes: 40 additions & 0 deletions tests/test_graph_build_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Test for raising an error when building a graph."""

import dataclasses

import znflow


@dataclasses.dataclass
class MyNode(znflow.Node):
value: int

def run(self):
pass


def test_graph_build_exception():
graph = znflow.DiGraph()

try:
with graph:
node = MyNode(value=42)
raise ValueError("This is a test")
except ValueError:
pass

assert node.uuid in graph


def test_group_build_exception():
graph = znflow.DiGraph()

try:
with graph.group("group") as grp:
node = MyNode(value=42)
raise ValueError("This is a test")
except ValueError:
pass

assert node.uuid in graph
assert node.uuid in grp
24 changes: 12 additions & 12 deletions tests/test_graph_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_grp():
assert grp.names == ("my_grp",)
assert node.value == 2
assert node.uuid in graph.nodes
assert grp.names in graph._groups
assert grp.names in graph.groups
assert graph.get_group("my_grp").uuids == [node.uuid]

assert len(graph._groups) == 1
assert len(graph.groups) == 1
assert len(graph) == 1


Expand Down Expand Up @@ -72,13 +72,13 @@ def test_muliple_grps():
assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp.names in graph._groups
assert grp2.names in graph._groups
assert grp.names in graph.groups
assert grp2.names in graph.groups

assert graph.get_group(*grp.names).uuids == [node.uuid]
assert graph.get_group(*grp2.names).uuids == [node2.uuid]

assert len(graph._groups) == 2
assert len(graph.groups) == 2
assert len(graph) == 2


Expand Down Expand Up @@ -113,11 +113,11 @@ def test_grp_with_existing_nodes():
assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp.names in graph._groups
assert grp.names in graph.groups

assert graph.get_group(*grp.names).uuids == [node2.uuid]

assert len(graph._groups) == 1
assert len(graph.groups) == 1
assert len(graph) == 2


Expand Down Expand Up @@ -148,11 +148,11 @@ def test_grp_with_multiple_nodes():
assert node3.uuid in graph.nodes
assert node4.uuid in graph.nodes

assert grp.names in graph._groups
assert grp.names in graph.groups

assert graph.get_group(*grp.names).uuids == [node3.uuid, node4.uuid]

assert len(graph._groups) == 1
assert len(graph.groups) == 1
assert len(graph) == 4


Expand Down Expand Up @@ -181,11 +181,11 @@ def test_reopen_grps():
assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp.names in graph._groups
assert grp.names in graph.groups

assert graph.get_group(*grp.names).uuids == [node.uuid, node2.uuid]

assert len(graph._groups) == 1
assert len(graph.groups) == 1
assert len(graph) == 2


Expand All @@ -204,7 +204,7 @@ def test_tuple_grp_names():
assert grp.names == ("grp", "1")
assert node.value == 2
assert node.uuid in graph.nodes
assert grp.names in graph._groups
assert grp.names in graph.groups
assert graph.get_group(*grp.names).uuids == [node.uuid]


Expand Down
2 changes: 1 addition & 1 deletion tests/test_znflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

def test_version():
"""Test the version."""
assert znflow.__version__ == "0.2.0"
assert znflow.__version__ == "0.2.1"
26 changes: 17 additions & 9 deletions znflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
"""
self.disable = disable
self.immutable_nodes = immutable_nodes
self._groups = {}
self.groups = {}
self.active_group: typing.Union[Group, None] = None
self.deployment = deployment or VanillaDeployment()
self.deployment.set_graph(self)
Expand Down Expand Up @@ -244,21 +244,29 @@ def group(self, *names: str) -> typing.Generator[Group, None, None]:

existing_nodes = self.get_sorted_nodes()

group = self._groups.get(names, Group(names=names, uuids=[], graph=self))
group = self.groups.get(names, Group(names=names, uuids=[], graph=self))

def finalize_group():
for node_uuid in self.nodes:
if node_uuid not in existing_nodes:
self.groups[group.names] = group
group.uuids.append(node_uuid)

try:
self.active_group = group
if get_graph() is empty_graph:
with self:
yield group
try:
yield group
finally:
finalize_group()
else:
yield group
try:
yield group
finally:
finalize_group()
finally:
self.active_group = None
for node_uuid in self.nodes:
if node_uuid not in existing_nodes:
self._groups[group.names] = group
group.uuids.append(node_uuid)

def get_group(self, *names: str) -> Group:
return self._groups[names]
return self.groups[names]

0 comments on commit 74b1cd3

Please sign in to comment.