Commit 863184d (forked from pydata/xarray)
Add close() method to DataTree and use it to clean-up open files in tests (pydata#9651)
* Add close() method to DataTree and clean up open files in tests. This removes a bunch of warnings that were previously issued in unit tests.
* Unit tests for closing functionality.
1 parent ed32ba7 commit 863184d
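For context, the user-visible effect of this change is that a DataTree opened from disk now owns the closers for its underlying files, so it can be closed explicitly or used as a context manager. A minimal sketch, not taken from the commit: the file name "example.nc" is a placeholder, and it assumes the top-level xr.open_datatree entry point is available.

import xarray as xr

# "example.nc" is a hypothetical file previously written with DataTree.to_netcdf().
with xr.open_datatree("example.nc") as tree:
    # __exit__ calls the new DataTree.close(), releasing every node's file handle.
    print(tree)

# Equivalently, close explicitly:
tree = xr.open_datatree("example.nc")
tree.close()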

File tree

7 files changed: +194 -77 lines changed

xarray/backends/common.py

+15 -2

@@ -4,14 +4,15 @@
 import os
 import time
 import traceback
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
 from glob import glob
 from typing import TYPE_CHECKING, Any, ClassVar

 import numpy as np

 from xarray.conventions import cf_encoder
 from xarray.core import indexing
+from xarray.core.datatree import DataTree
 from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
 from xarray.namedarray.parallelcompat import get_chunked_array_type
 from xarray.namedarray.pycompat import is_chunked_array
@@ -20,7 +21,6 @@
     from io import BufferedIOBase

     from xarray.core.dataset import Dataset
-    from xarray.core.datatree import DataTree
     from xarray.core.types import NestedSequence

 # Create a logger object, but don't add any handlers. Leave that to user code.
@@ -149,6 +149,19 @@ def find_root_and_group(ds):
     return ds, group


+def datatree_from_dict_with_io_cleanup(groups_dict: Mapping[str, Dataset]) -> DataTree:
+    """DataTree.from_dict with file clean-up."""
+    try:
+        tree = DataTree.from_dict(groups_dict)
+    except Exception:
+        for ds in groups_dict.values():
+            ds.close()
+        raise
+    for path, ds in groups_dict.items():
+        tree[path].set_close(ds._close)
+    return tree
+
+
 def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500):
     """
     Robustly index an array, using retry logic with exponential backoff if any
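The helper above is what the backends below switch to. A minimal sketch of its behaviour, not part of the commit: the in-memory datasets and the `closed` list are invented stand-ins for real file handles, and the import path assumes the internal module shown in this diff. On success each dataset's closer is handed to the matching tree node so a single DataTree.close() releases everything; if DataTree.from_dict raises, every already-opened dataset is closed before the error propagates.

import xarray as xr
from xarray.backends.common import datatree_from_dict_with_io_cleanup

closed = []
groups = {
    "/": xr.Dataset({"a": ("x", [1, 2])}),
    "/child": xr.Dataset({"b": ("y", [3, 4])}),
}
for path, ds in groups.items():
    # Stand-in for the file-handle closer a backend would normally attach.
    ds.set_close(lambda p=path: closed.append(p))

tree = datatree_from_dict_with_io_cleanup(groups)
tree.close()  # the closers were transferred to the tree nodes
assert sorted(closed) == ["/", "/child"]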

xarray/backends/h5netcdf_.py

+2 -4

@@ -13,6 +13,7 @@
     BackendEntrypoint,
     WritableCFDataStore,
     _normalize_path,
+    datatree_from_dict_with_io_cleanup,
     find_root_and_group,
 )
 from xarray.backends.file_manager import CachingFileManager, DummyFileManager
@@ -474,8 +475,6 @@ def open_datatree(
         driver_kwds=None,
         **kwargs,
     ) -> DataTree:
-        from xarray.core.datatree import DataTree
-
         groups_dict = self.open_groups_as_dict(
             filename_or_obj,
             mask_and_scale=mask_and_scale,
@@ -495,8 +494,7 @@ def open_datatree(
             driver_kwds=driver_kwds,
             **kwargs,
         )
-
-        return DataTree.from_dict(groups_dict)
+        return datatree_from_dict_with_io_cleanup(groups_dict)

     def open_groups_as_dict(
         self,

xarray/backends/netCDF4_.py

+2 -4

@@ -16,6 +16,7 @@
     BackendEntrypoint,
     WritableCFDataStore,
     _normalize_path,
+    datatree_from_dict_with_io_cleanup,
     find_root_and_group,
     robust_getitem,
 )
@@ -710,8 +711,6 @@ def open_datatree(
         autoclose=False,
         **kwargs,
     ) -> DataTree:
-        from xarray.core.datatree import DataTree
-
         groups_dict = self.open_groups_as_dict(
             filename_or_obj,
             mask_and_scale=mask_and_scale,
@@ -730,8 +729,7 @@ def open_datatree(
             autoclose=autoclose,
             **kwargs,
         )
-
-        return DataTree.from_dict(groups_dict)
+        return datatree_from_dict_with_io_cleanup(groups_dict)

     def open_groups_as_dict(
         self,

xarray/backends/zarr.py

+2 -4

@@ -17,6 +17,7 @@
     BackendEntrypoint,
     _encode_variable_name,
     _normalize_path,
+    datatree_from_dict_with_io_cleanup,
 )
 from xarray.backends.store import StoreBackendEntrypoint
 from xarray.core import indexing
@@ -1290,8 +1291,6 @@ def open_datatree(
         zarr_version=None,
         **kwargs,
     ) -> DataTree:
-        from xarray.core.datatree import DataTree
-
         filename_or_obj = _normalize_path(filename_or_obj)
         groups_dict = self.open_groups_as_dict(
             filename_or_obj=filename_or_obj,
@@ -1312,8 +1311,7 @@ def open_datatree(
             zarr_version=zarr_version,
             **kwargs,
         )
-
-        return DataTree.from_dict(groups_dict)
+        return datatree_from_dict_with_io_cleanup(groups_dict)

     def open_groups_as_dict(
         self,

xarray/core/datatree.py

+33 -1

@@ -266,6 +266,15 @@ def update(self, other) -> NoReturn:
             "use `.copy()` first to get a mutable version of the input dataset."
         )

+    def set_close(self, close: Callable[[], None] | None) -> None:
+        raise AttributeError("cannot modify a DatasetView()")
+
+    def close(self) -> None:
+        raise AttributeError(
+            "cannot close a DatasetView(). Close the associated DataTree node "
+            "instead"
+        )
+
     # FIXME https://github.com/python/mypy/issues/7328
     @overload  # type: ignore[override]
     def __getitem__(self, key: Mapping) -> Dataset:  # type: ignore[overload-overlap]
@@ -633,7 +642,7 @@ def to_dataset(self, inherit: bool = True) -> Dataset:
             None if self._attrs is None else dict(self._attrs),
             dict(self._indexes if inherit else self._node_indexes),
             None if self._encoding is None else dict(self._encoding),
-            self._close,
+            None,
         )

     @property
@@ -796,6 +805,29 @@ def _repr_html_(self):
             return f"<pre>{escape(repr(self))}</pre>"
         return datatree_repr_html(self)

+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.close()
+
+    # DatasetView does not support close() or set_close(), so we reimplement
+    # these methods on DataTree.
+
+    def _close_node(self) -> None:
+        if self._close is not None:
+            self._close()
+        self._close = None
+
+    def close(self) -> None:
+        """Close any files associated with this tree."""
+        for node in self.subtree:
+            node._close_node()
+
+    def set_close(self, close: Callable[[], None] | None) -> None:
+        """Set the closer for this node."""
+        self._close = close
+
     def _replace_node(
         self: DataTree,
         data: Dataset | Default = _default,
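A small sketch of the protocol added here, not from the commit (the in-memory tree is invented for illustration): DataTree now supports the with statement via __enter__/__exit__, close() cascades over every node in the subtree, and the immutable DatasetView deliberately refuses to be closed.

import xarray as xr
from xarray.core.datatree import DataTree

tree = DataTree.from_dict({"/": xr.Dataset({"a": ("x", [1, 2])})})
with tree:  # __enter__ returns the tree; __exit__ calls close()
    view = tree.dataset  # immutable DatasetView of this node's data
    try:
        view.close()  # raises AttributeError: close the DataTree node instead
    except AttributeError:
        pass
# After the with block, every node's closer (if any) has been called once.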

xarray/tests/test_backends_datatree.py

+68 -62

@@ -115,8 +115,9 @@ def test_to_netcdf(self, tmpdir, simple_datatree):
         original_dt = simple_datatree
         original_dt.to_netcdf(filepath, engine=self.engine)

-        roundtrip_dt = open_datatree(filepath, engine=self.engine)
-        assert_equal(original_dt, roundtrip_dt)
+        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
+            assert roundtrip_dt._close is not None
+            assert_equal(original_dt, roundtrip_dt)

     def test_to_netcdf_inherited_coords(self, tmpdir):
         filepath = tmpdir / "test.nc"
@@ -128,10 +129,10 @@ def test_to_netcdf_inherited_coords(self, tmpdir):
         )
         original_dt.to_netcdf(filepath, engine=self.engine)

-        roundtrip_dt = open_datatree(filepath, engine=self.engine)
-        assert_equal(original_dt, roundtrip_dt)
-        subtree = cast(DataTree, roundtrip_dt["/sub"])
-        assert "x" not in subtree.to_dataset(inherit=False).coords
+        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
+            assert_equal(original_dt, roundtrip_dt)
+            subtree = cast(DataTree, roundtrip_dt["/sub"])
+            assert "x" not in subtree.to_dataset(inherit=False).coords

     def test_netcdf_encoding(self, tmpdir, simple_datatree):
         filepath = tmpdir / "test.nc"
@@ -142,14 +143,13 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
         enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}

         original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
-        roundtrip_dt = open_datatree(filepath, engine=self.engine)
+        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
+            assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
+            assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]

-        assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
-        assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]
-
-        enc["/not/a/group"] = {"foo": "bar"}  # type: ignore[dict-item]
-        with pytest.raises(ValueError, match="unexpected encoding group.*"):
-            original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
+            enc["/not/a/group"] = {"foo": "bar"}  # type: ignore[dict-item]
+            with pytest.raises(ValueError, match="unexpected encoding group.*"):
+                original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)


 @requires_netCDF4
@@ -179,18 +179,17 @@ def test_open_groups(self, unaligned_datatree_nc) -> None:
         assert "/Group1" in unaligned_dict_of_datasets.keys()
         assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
         # Check that group name returns the correct datasets
-        assert_identical(
-            unaligned_dict_of_datasets["/"],
-            xr.open_dataset(unaligned_datatree_nc, group="/"),
-        )
-        assert_identical(
-            unaligned_dict_of_datasets["/Group1"],
-            xr.open_dataset(unaligned_datatree_nc, group="Group1"),
-        )
-        assert_identical(
-            unaligned_dict_of_datasets["/Group1/subgroup1"],
-            xr.open_dataset(unaligned_datatree_nc, group="/Group1/subgroup1"),
-        )
+        with xr.open_dataset(unaligned_datatree_nc, group="/") as expected:
+            assert_identical(unaligned_dict_of_datasets["/"], expected)
+        with xr.open_dataset(unaligned_datatree_nc, group="Group1") as expected:
+            assert_identical(unaligned_dict_of_datasets["/Group1"], expected)
+        with xr.open_dataset(
+            unaligned_datatree_nc, group="/Group1/subgroup1"
+        ) as expected:
+            assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)
+
+        for ds in unaligned_dict_of_datasets.values():
+            ds.close()

     def test_open_groups_to_dict(self, tmpdir) -> None:
         """Create an aligned netCDF4 with the following structure to test `open_groups`
@@ -234,8 +233,10 @@ def test_open_groups_to_dict(self, tmpdir) -> None:

         aligned_dict_of_datasets = open_groups(filepath)
         aligned_dt = DataTree.from_dict(aligned_dict_of_datasets)
-
-        assert open_datatree(filepath).identical(aligned_dt)
+        with open_datatree(filepath) as opened_tree:
+            assert opened_tree.identical(aligned_dt)
+        for ds in aligned_dict_of_datasets.values():
+            ds.close()


 @requires_h5netcdf
@@ -252,8 +253,8 @@ def test_to_zarr(self, tmpdir, simple_datatree):
         original_dt = simple_datatree
         original_dt.to_zarr(filepath)

-        roundtrip_dt = open_datatree(filepath, engine="zarr")
-        assert_equal(original_dt, roundtrip_dt)
+        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
+            assert_equal(original_dt, roundtrip_dt)

     def test_zarr_encoding(self, tmpdir, simple_datatree):
         import zarr
@@ -264,14 +265,14 @@ def test_zarr_encoding(self, tmpdir, simple_datatree):
         comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
         enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}
         original_dt.to_zarr(filepath, encoding=enc)
-        roundtrip_dt = open_datatree(filepath, engine="zarr")

-        print(roundtrip_dt["/set2/a"].encoding)
-        assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]
+        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
+            print(roundtrip_dt["/set2/a"].encoding)
+            assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]

-        enc["/not/a/group"] = {"foo": "bar"}  # type: ignore[dict-item]
-        with pytest.raises(ValueError, match="unexpected encoding group.*"):
-            original_dt.to_zarr(filepath, encoding=enc, engine="zarr")
+            enc["/not/a/group"] = {"foo": "bar"}  # type: ignore[dict-item]
+            with pytest.raises(ValueError, match="unexpected encoding group.*"):
+                original_dt.to_zarr(filepath, encoding=enc, engine="zarr")

     def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
         from zarr.storage import ZipStore
@@ -281,8 +282,8 @@ def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
         store = ZipStore(filepath)
         original_dt.to_zarr(store)

-        roundtrip_dt = open_datatree(store, engine="zarr")
-        assert_equal(original_dt, roundtrip_dt)
+        with open_datatree(store, engine="zarr") as roundtrip_dt:
+            assert_equal(original_dt, roundtrip_dt)

     def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
         filepath = tmpdir / "test.zarr"
@@ -295,8 +296,8 @@ def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
         assert not s1zmetadata.exists()

         with pytest.warns(RuntimeWarning, match="consolidated"):
-            roundtrip_dt = open_datatree(filepath, engine="zarr")
-        assert_equal(original_dt, roundtrip_dt)
+            with open_datatree(filepath, engine="zarr") as roundtrip_dt:
+                assert_equal(original_dt, roundtrip_dt)

     def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree):
         import zarr
@@ -317,10 +318,10 @@ def test_to_zarr_inherited_coords(self, tmpdir):
         filepath = tmpdir / "test.zarr"
         original_dt.to_zarr(filepath)

-        roundtrip_dt = open_datatree(filepath, engine="zarr")
-        assert_equal(original_dt, roundtrip_dt)
-        subtree = cast(DataTree, roundtrip_dt["/sub"])
-        assert "x" not in subtree.to_dataset(inherit=False).coords
+        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
+            assert_equal(original_dt, roundtrip_dt)
+            subtree = cast(DataTree, roundtrip_dt["/sub"])
+            assert "x" not in subtree.to_dataset(inherit=False).coords

     def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None:
         """Test `open_groups` opens a zarr store with the `simple_datatree` structure."""
@@ -331,7 +332,11 @@ def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None:
         roundtrip_dict = open_groups(filepath, engine="zarr")
         roundtrip_dt = DataTree.from_dict(roundtrip_dict)

-        assert open_datatree(filepath, engine="zarr").identical(roundtrip_dt)
+        with open_datatree(filepath, engine="zarr") as opened_tree:
+            assert opened_tree.identical(roundtrip_dt)
+
+        for ds in roundtrip_dict.values():
+            ds.close()

     def test_open_datatree(self, unaligned_datatree_zarr) -> None:
         """Test if `open_datatree` fails to open a zarr store with an unaligned group hierarchy."""
@@ -353,21 +358,22 @@ def test_open_groups(self, unaligned_datatree_zarr) -> None:
         assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
         assert "/Group2" in unaligned_dict_of_datasets.keys()
         # Check that group name returns the correct datasets
-        assert_identical(
-            unaligned_dict_of_datasets["/"],
-            xr.open_dataset(unaligned_datatree_zarr, group="/", engine="zarr"),
-        )
-        assert_identical(
-            unaligned_dict_of_datasets["/Group1"],
-            xr.open_dataset(unaligned_datatree_zarr, group="Group1", engine="zarr"),
-        )
-        assert_identical(
-            unaligned_dict_of_datasets["/Group1/subgroup1"],
-            xr.open_dataset(
-                unaligned_datatree_zarr, group="/Group1/subgroup1", engine="zarr"
-            ),
-        )
-        assert_identical(
-            unaligned_dict_of_datasets["/Group2"],
-            xr.open_dataset(unaligned_datatree_zarr, group="/Group2", engine="zarr"),
-        )
+        with xr.open_dataset(
+            unaligned_datatree_zarr, group="/", engine="zarr"
+        ) as expected:
+            assert_identical(unaligned_dict_of_datasets["/"], expected)
+        with xr.open_dataset(
+            unaligned_datatree_zarr, group="Group1", engine="zarr"
+        ) as expected:
+            assert_identical(unaligned_dict_of_datasets["/Group1"], expected)
+        with xr.open_dataset(
+            unaligned_datatree_zarr, group="/Group1/subgroup1", engine="zarr"
+        ) as expected:
+            assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)
+        with xr.open_dataset(
+            unaligned_datatree_zarr, group="/Group2", engine="zarr"
+        ) as expected:
+            assert_identical(unaligned_dict_of_datasets["/Group2"], expected)
+
+        for ds in unaligned_dict_of_datasets.values():
+            ds.close()
