Skip to content

Commit 45534d0

Browse files
authored
Add plots modify (#581)
* add plots modify * plots modify * update plots mmodify * add plots options to zn.plots * add global option * extend tests * extend tests * test fails * code for change to outs * poetry update * lint * copy template it not in the cwd * use outs instead of plots * fix save
1 parent ff14c26 commit 45534d0

File tree

6 files changed

+583
-279
lines changed

6 files changed

+583
-279
lines changed

poetry.lock

+272-273
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ readme = "README.md"
1010

1111
[tool.poetry.dependencies]
1212
python = ">=3.8,<4.0.0"
13-
dvc = "^2.53.0"
13+
dvc = "^2.52.0, !=2.53.0"
1414
pyyaml = "^6.0"
1515
tqdm = "^4.64.0"
1616
pandas = "^1.4.3"

tests/integration/test_dvc_plots.py

+188
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import zntrack
2+
import pandas as pd
3+
import pytest
4+
import pathlib
5+
import yaml
6+
from zntrack.utils import run_dvc_cmd
7+
8+
9+
class NodeWithPlotsDVC(zntrack.Node):
10+
data = zntrack.dvc.plots(
11+
zntrack.nwd / "data.csv",
12+
x="x",
13+
y="y",
14+
x_label="x",
15+
y_label="y",
16+
title="title",
17+
template="linear",
18+
use_global_plots=False,
19+
)
20+
21+
def run(self) -> None:
22+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
23+
df.to_csv(self.data)
24+
25+
def get_data(self):
26+
return pd.read_csv(self.data)
27+
28+
29+
class NodeWithPlotsZn(zntrack.Node):
30+
data = zntrack.zn.plots(
31+
x="x",
32+
y="y",
33+
x_label="x",
34+
y_label="y",
35+
title="title",
36+
template="linear",
37+
use_global_plots=False,
38+
)
39+
40+
def run(self) -> None:
41+
self.data = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
42+
43+
def get_data(self):
44+
return self.data
45+
46+
47+
class NodeWithPlotsZnGlobal(zntrack.Node):
48+
data = zntrack.zn.plots(
49+
x="x",
50+
y=["y", "z"],
51+
x_label="x",
52+
y_label="y",
53+
title="title",
54+
template="linear",
55+
)
56+
metrics = zntrack.zn.metrics()
57+
58+
def run(self) -> None:
59+
self.data = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6], "z": [7, 8, 9]})
60+
self.metrics = {"a": 1, "b": 2}
61+
62+
def get_data(self):
63+
return self.data
64+
65+
66+
class NodeWithPlotsDVCGlobal(zntrack.Node):
67+
data = zntrack.dvc.plots(
68+
zntrack.nwd / "data.csv",
69+
x="x",
70+
y=["y", "z"],
71+
x_label="x",
72+
y_label="y",
73+
title="title",
74+
template="linear",
75+
)
76+
metrics = zntrack.dvc.metrics(zntrack.nwd / "metrics.json")
77+
78+
def run(self) -> None:
79+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6], "z": [7, 8, 9]})
80+
df.to_csv(self.data)
81+
self.metrics.write_text("{'a': 1, 'b': 2}")
82+
83+
def get_data(self):
84+
return pd.read_csv(self.data)
85+
86+
87+
class NodeRemoteTemplate(zntrack.Node):
88+
data = zntrack.dvc.plots(
89+
zntrack.nwd / "data.csv",
90+
template=zntrack.__file__,
91+
)
92+
93+
def run(self) -> None:
94+
df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6], "z": [7, 8, 9]})
95+
df.to_csv(self.data)
96+
97+
98+
@pytest.mark.parametrize("cls", [NodeWithPlotsDVC, NodeWithPlotsZn])
99+
@pytest.mark.parametrize("eager", [True, False])
100+
def test_NodeWithPlots(proj_path, eager, cls):
101+
with zntrack.Project() as project:
102+
node = cls()
103+
project.run(eager=eager)
104+
105+
if not eager:
106+
node.load()
107+
dvc_dict = yaml.safe_load((proj_path / "dvc.yaml").read_text())
108+
109+
plots = {
110+
"x": "x",
111+
"y": "y",
112+
"x_label": "x",
113+
"y_label": "y",
114+
"title": "title",
115+
"template": "linear",
116+
}
117+
118+
assert (
119+
dvc_dict["stages"][node.name]["plots"][0][f"nodes/{node.name}/data.csv"]
120+
== plots
121+
)
122+
run_dvc_cmd(["plots", "show"])
123+
124+
if isinstance(node, NodeWithPlotsDVC):
125+
assert NodeWithPlotsDVC.from_rev().data == pathlib.Path(
126+
"nodes", node.name, "data.csv"
127+
)
128+
129+
assert node.get_data()["x"].tolist() == [1, 2, 3]
130+
131+
132+
@pytest.mark.parametrize("cls", [NodeWithPlotsZnGlobal, NodeWithPlotsDVCGlobal])
133+
@pytest.mark.parametrize("eager", [True, False])
134+
def test_NodeWithPlotsGlobal(proj_path, eager, cls):
135+
with zntrack.Project() as project:
136+
node = cls()
137+
project.run(eager=eager)
138+
139+
if not eager:
140+
node.load()
141+
dvc_dict = yaml.safe_load((proj_path / "dvc.yaml").read_text())
142+
143+
plots = {
144+
"x": "x",
145+
"y": ["y", "z"],
146+
"x_label": "x",
147+
"y_label": "y",
148+
"title": "title",
149+
"template": "linear",
150+
}
151+
152+
assert dvc_dict["plots"][0][f"nodes/{node.name}/data.csv"] == plots
153+
154+
run_dvc_cmd(["plots", "show"])
155+
156+
assert node.get_data()["x"].tolist() == [1, 2, 3]
157+
158+
159+
def test_multiple_plots_nodes(proj_path):
160+
with zntrack.Project(automatic_node_names=True) as project:
161+
NodeWithPlotsZnGlobal()
162+
NodeWithPlotsZnGlobal()
163+
NodeWithPlotsDVCGlobal()
164+
a = NodeWithPlotsDVCGlobal()
165+
project.run()
166+
run_dvc_cmd(["plots", "show"])
167+
168+
with project:
169+
NodeWithPlotsZnGlobal()
170+
b = NodeWithPlotsZnGlobal()
171+
NodeWithPlotsDVCGlobal()
172+
NodeWithPlotsDVCGlobal()
173+
project.run()
174+
run_dvc_cmd(["plots", "show"])
175+
# run_dvc_cmd(["repro", "-f"])
176+
177+
# check loading
178+
NodeWithPlotsDVCGlobal.from_rev(name=a.name)
179+
NodeWithPlotsZnGlobal.from_rev(name=b.name)
180+
181+
182+
def test_NodeRemoteTemplate(proj_path):
183+
with zntrack.Project(automatic_node_names=True) as project:
184+
NodeRemoteTemplate()
185+
NodeRemoteTemplate()
186+
project.run(repro=False)
187+
188+
assert pathlib.Path("dvc_plots", "templates", "__init__.py").exists()

zntrack/fields/dvc/__init__.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import znjson
77

8-
from zntrack.fields.field import Field, FieldGroup
8+
from zntrack.fields.field import Field, FieldGroup, PlotsMixin
99
from zntrack.utils import node_wd
1010

1111
if typing.TYPE_CHECKING:
@@ -130,6 +130,10 @@ def __get__(self, instance: "Node", owner=None):
130130
return node_wd.ReplaceNWD()(value, nwd=instance.nwd)
131131

132132

133+
class PlotsOption(PlotsMixin, DVCOption):
134+
"""Field with DVC plots kwargs."""
135+
136+
133137
def outs(*args, **kwargs) -> DVCOption:
134138
"""Create a outs field."""
135139
return DVCOption(*args, dvc_option="outs", **kwargs)
@@ -167,9 +171,9 @@ def metrics_no_cache(*args, **kwargs) -> DVCOption:
167171

168172
def plots(*args, **kwargs) -> DVCOption:
169173
"""Create a plots field."""
170-
return DVCOption(*args, dvc_option="plots", **kwargs)
174+
return PlotsOption(*args, dvc_option="plots", **kwargs)
171175

172176

173177
def plots_no_cache(*args, **kwargs) -> DVCOption:
174178
"""Create a plots_no_cache field."""
175-
return DVCOption(*args, dvc_option="plots-no-cache", **kwargs)
179+
return PlotsOption(*args, dvc_option="plots-no-cache", **kwargs)

zntrack/fields/field.py

+112
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import json
66
import logging
77
import pathlib
8+
import shutil
89
import typing
910

11+
import yaml
1012
import zninit
1113

1214
from zntrack.utils import LazyOption, config
@@ -196,3 +198,113 @@ def load(self, instance: "Node", lazy: bool = None):
196198
instance.__dict__[self.name] = LazyOption
197199
else:
198200
super().load(instance)
201+
202+
203+
class PlotsMixin(Field):
204+
"""DVC Plots Option including 'dvc plots modify' command."""
205+
206+
def __init__(
207+
self,
208+
*args,
209+
template=None,
210+
x=None,
211+
y=None,
212+
x_label=None,
213+
y_label=None,
214+
title=None,
215+
use_global_plots: bool = True,
216+
**kwargs,
217+
):
218+
"""Create a DVCOption field.
219+
220+
Attributes
221+
----------
222+
use_global_plots : bool
223+
Save the plots config not in 'stages' but in 'plots' in the dvc.yaml file.
224+
"""
225+
super().__init__(*args, **kwargs)
226+
self.plots_options = {}
227+
self.use_global_plots = use_global_plots
228+
if self.use_global_plots:
229+
if self.dvc_option == "plots":
230+
self.dvc_option = "outs"
231+
elif self.dvc_option == "plots-no-cache":
232+
self.dvc_option = "outs-no-cache"
233+
if template is not None:
234+
self.plots_options["--template"] = pathlib.Path(template).as_posix()
235+
if x is not None:
236+
self.plots_options["-x"] = x
237+
if y is not None:
238+
self.plots_options["-y"] = y
239+
if x_label is not None:
240+
self.plots_options["--x-label"] = x_label
241+
if y_label is not None:
242+
self.plots_options["--y-label"] = y_label
243+
if title is not None:
244+
self.plots_options["--title"] = title
245+
246+
def save(self, instance: "Node"):
247+
"""Save plots options to dvc.yaml, if use_global_plots is True."""
248+
if self.plots_options.get("--template") is not None:
249+
template = pathlib.Path(self.plots_options["--template"]).resolve()
250+
if pathlib.Path.cwd() not in template.parents:
251+
# copy template to dvc_plots/templates if it is not in the cwd
252+
template_dir = pathlib.Path.cwd() / "dvc_plots" / "templates"
253+
template_dir.mkdir(parents=True, exist_ok=True)
254+
shutil.copy(template, template_dir)
255+
self.plots_options["--template"] = (
256+
(template_dir / template.name)
257+
.relative_to(pathlib.Path.cwd())
258+
.as_posix()
259+
)
260+
261+
with contextlib.suppress(NotImplementedError):
262+
super().save(instance=instance)
263+
if not self.use_global_plots:
264+
return
265+
266+
dvc_file = pathlib.Path("dvc.yaml")
267+
if not dvc_file.exists():
268+
dvc_file.write_text(yaml.safe_dump({}))
269+
dvc_config = yaml.safe_load(dvc_file.read_text())
270+
plots = dvc_config.get("plots", [])
271+
272+
# remove leading "-/--"
273+
for key in list(self.plots_options):
274+
if key.startswith("--"):
275+
self.plots_options[key[2:]] = self.plots_options[key]
276+
del self.plots_options[key]
277+
elif key.startswith("-"):
278+
self.plots_options[key[1:]] = self.plots_options[key]
279+
del self.plots_options[key]
280+
# replace "-" with "_"
281+
for key in list(self.plots_options):
282+
if key.replace("-", "_") != key:
283+
self.plots_options[key.replace("-", "_")] = self.plots_options[key]
284+
del self.plots_options[key]
285+
286+
for file in self.get_files(instance):
287+
replaced = False
288+
for entry in plots: # entry: dict{filename: {x:, y:, ...}}
289+
if pathlib.Path(file) == pathlib.Path(next(iter(entry))):
290+
entry = {pathlib.Path(file).as_posix(): self.plots_options}
291+
replaced = True
292+
if not replaced:
293+
plots.append({pathlib.Path(file).as_posix(): self.plots_options})
294+
295+
dvc_config["plots"] = plots
296+
dvc_file.write_text(yaml.dump(dvc_config))
297+
298+
def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[typing.List[str]]:
299+
"""Add 'dvc plots modify' to this option."""
300+
if not self.use_global_plots:
301+
cmds = []
302+
for file in self.get_files(instance):
303+
cmd = ["plots", "modify", pathlib.Path(file).as_posix()]
304+
for key, value in self.plots_options.items():
305+
cmd.append(f"{key}")
306+
cmd.append(pathlib.Path(value).as_posix())
307+
cmds.append(cmd)
308+
return cmds
309+
else:
310+
return []

zntrack/fields/zn/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from znflow import handler
1515

1616
from zntrack import exceptions
17-
from zntrack.fields.field import DataIsLazyError, Field, FieldGroup, LazyField
17+
from zntrack.fields.field import DataIsLazyError, Field, FieldGroup, LazyField, PlotsMixin
1818
from zntrack.utils import module_handler, update_key_val
1919

2020
if typing.TYPE_CHECKING:
@@ -234,7 +234,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
234234
return [(f"--{self.dvc_option}", file.as_posix())]
235235

236236

237-
class Plots(LazyField):
237+
class Plots(PlotsMixin, LazyField):
238238
"""A field that is saved to disk."""
239239

240240
dvc_option: str = "plots"
@@ -246,6 +246,7 @@ def get_files(self, instance) -> list:
246246

247247
def save(self, instance: "Node"):
248248
"""Save the field to disk."""
249+
super().save(instance)
249250
try:
250251
value = self.get_value_except_lazy(instance)
251252
except DataIsLazyError:

0 commit comments

Comments
 (0)