Skip to content

Commit fb922c6

Browse files
authored
add remove existing graph arg to zntrack.Project (#544)
* add exception and remove_existing_graph * change check
1 parent 3456e5f commit fb922c6

File tree

5 files changed

+56
-6
lines changed

5 files changed

+56
-6
lines changed

tests/integration/test_project.py

+16
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,19 @@ def test_WriteIO_no_name(tmp_path_2, assert_before_exp):
7171

7272
assert exp2["WriteIO"].inputs == "Lorem Ipsum"
7373
assert exp2["WriteIO"].outputs == "Lorem Ipsum"
74+
75+
76+
def test_project_remove_graph(proj_path):
77+
with zntrack.Project() as project:
78+
node = WriteIO(inputs="Hello World")
79+
project.run()
80+
node.load()
81+
assert node.outputs == "Hello World"
82+
83+
with zntrack.Project(remove_existing_graph=True) as project:
84+
node2 = WriteIO(inputs="Lorem Ipsum", name="node2")
85+
project.run()
86+
node2.load()
87+
assert node2.outputs == "Lorem Ipsum"
88+
with pytest.raises(zntrack.exceptions.NodeNotAvailableError):
89+
node.load()

zntrack/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55
import importlib.metadata
66

7-
from zntrack import tools
7+
from zntrack import exceptions, tools
88
from zntrack.core.node import Node
99
from zntrack.core.nodify import NodeConfig, nodify
1010
from zntrack.fields import Field, FieldGroup, LazyField, dvc, meta, zn
@@ -28,4 +28,5 @@
2828
"nodify",
2929
"NodeConfig",
3030
"tools",
31+
"exceptions",
3132
]

zntrack/core/node.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import zninit
1515
import znjson
1616

17+
from zntrack import exceptions
1718
from zntrack.notebooks.jupyter import jupyter_class_to_file
1819
from zntrack.utils import NodeStatusResults, deprecated, module_handler, run_dvc_cmd
1920
from zntrack.utils.config import config
@@ -162,10 +163,13 @@ def load(self, lazy: bool = None) -> None:
162163

163164
kwargs = {} if lazy is None else {"lazy": lazy}
164165
self.state.loaded = True # we assume loading will be successful.
165-
with config.updated_config(**kwargs):
166-
# TODO: it would be much nicer not to use a global config object here.
167-
for attr in zninit.get_descriptors(Field, self=self):
168-
attr.load(self)
166+
try:
167+
with config.updated_config(**kwargs):
168+
# TODO: it would be much nicer not to use a global config object here.
169+
for attr in zninit.get_descriptors(Field, self=self):
170+
attr.load(self)
171+
except KeyError as err:
172+
raise exceptions.NodeNotAvailableError(self) from err
169173

170174
# TODO: documentation about _post_init and _post_load_ and when they are called
171175
self._post_load_()

zntrack/exceptions/__init__.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""All ZnTrack exceptions."""
2+
3+
4+
class NodeNotAvailableError(Exception):
5+
"""Raised when a node is not available."""
6+
7+
def __init__(self, arg):
8+
"""Initialize the exception.
9+
10+
Parameters
11+
----------
12+
arg : str|Node
13+
Custom Error message or Node that is not available.
14+
"""
15+
if isinstance(arg, str):
16+
super().__init__(arg)
17+
else:
18+
# assume arg is a Node
19+
super().__init__(f"Node {arg.name} is not available.")

zntrack/project/zntrack_project.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,29 @@ def _initalize():
9696
class Project(_ProjectBase):
9797
"""The ZnTrack Project class."""
9898

99-
def __init__(self, initialize: bool = True) -> None:
99+
def __init__(
100+
self, initialize: bool = True, remove_existing_graph: bool = False
101+
) -> None:
100102
"""Initialize the Project.
101103
102104
Attributes
103105
----------
104106
initialize : bool, default = True
105107
If True, initialize a git repository and a dvc repository.
108+
remove_existing_graph : bool, default = False
109+
If True, remove 'dvc.yaml', 'zntrack.json' and 'params.yaml'
110+
before writing new nodes.
106111
"""
107112
# TODO maybe it is not a good idea to base everything on the DiGraph class.
108113
# It seems to call some class methods
109114
super().__init__()
110115
if initialize:
111116
_initalize()
117+
if remove_existing_graph:
118+
# we remove the files that typically contain the graph definition
119+
pathlib.Path("zntrack.json").unlink(missing_ok=True)
120+
pathlib.Path("dvc.yaml").unlink(missing_ok=True)
121+
pathlib.Path("params.yaml").unlink(missing_ok=True)
112122

113123
def create_branch(self, name: str) -> "Branch":
114124
"""Create a branch in the project."""

0 commit comments

Comments
 (0)