Skip to content

Commit dfdd631

Browse files
authored
Migrate treenode module. (#8757)
* Update the formatting tests PR (#8702) added nbytes representation in DataArrays and Dataset repr, this adds it to the datatree tests. * Migrate treenode module Moves treenode.py and test_treenode.py. Updates some typing. Updates imports from treenode. * Update NotFoundInTreeError description. * Reformat some comments Add test tree structure for easier understanding. * Updates whats-new.rst * mypy typing. (terrible?) There must be a better way, but I don't know it. particularly the list comprehension casts. * Adds __repr__ to NamedNode and updates test This test was broken because only the root node was being tested and none of the previous nodes were represented in the __str__. * Adds quotes to NamedNode __str__ representation. * swaps " for ' in NamedNode __str__ representation. * Adding Tom in so he gets blamed properly. * resolve conflict whats-new.rst Question is I did update below the released line to give Tom some credit. I hope that is allowable. * Moves test_treenode.py to xarray/tests. Integrated tests. * refactors backend tests for datatree IO * Add explicit engine back in test_to_zarr * Removes OrderedDict from treenode * Renames tests/test_io.py -> tests/test_backends_datatree.py * typo * Add types * Pass mypy for 3.9
1 parent e47eb92 commit dfdd631

File tree

13 files changed

+262
-240
lines changed

13 files changed

+262
-240
lines changed

doc/whats-new.rst

+7-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ Documentation
4949

5050
Internal Changes
5151
~~~~~~~~~~~~~~~~
52-
52+
- Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`)
53+
By `Matt Savoie <https://github.com/flamingbear>`_ and `Tom Nicholas
54+
<https://github.com/TomNicholas>`_.
5355

5456

5557
.. _whats-new.2024.02.0:
@@ -145,9 +147,11 @@ Internal Changes
145147
``xarray/namedarray``. (:pull:`8319`)
146148
By `Tom Nicholas <https://github.com/TomNicholas>`_ and `Anderson Banihirwe <https://github.com/andersy005>`_.
147149
- Imports ``datatree`` repository and history into internal location. (:pull:`8688`)
148-
By `Matt Savoie <https://github.com/flamingbear>`_ and `Justus Magin <https://github.com/keewis>`_.
150+
By `Matt Savoie <https://github.com/flamingbear>`_, `Justus Magin <https://github.com/keewis>`_
151+
and `Tom Nicholas <https://github.com/TomNicholas>`_.
149152
- Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`)
150-
By `Matt Savoie <https://github.com/flamingbear>`_.
153+
By `Matt Savoie <https://github.com/flamingbear>`_ and `Tom Nicholas
154+
<https://github.com/TomNicholas>`_.
151155
- Refactor :py:meth:`xarray.core.indexing.DaskIndexingAdapter.__getitem__` to remove an unnecessary
152156
rewrite of the indexer key (:issue:`8377`, :pull:`8758`)
153157
By `Anderson Banihirwe <https://github.com/andersy005>`_.

xarray/backends/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ def _open_datatree_netcdf(
137137
**kwargs,
138138
) -> DataTree:
139139
from xarray.backends.api import open_dataset
140+
from xarray.core.treenode import NodePath
140141
from xarray.datatree_.datatree import DataTree
141-
from xarray.datatree_.datatree.treenode import NodePath
142142

143143
ds = open_dataset(filename_or_obj, **kwargs)
144144
tree_root = DataTree.from_dict({"/": ds})
@@ -159,7 +159,7 @@ def _open_datatree_netcdf(
159159

160160

161161
def _iter_nc_groups(root, parent="/"):
162-
from xarray.datatree_.datatree.treenode import NodePath
162+
from xarray.core.treenode import NodePath
163163

164164
parent = NodePath(parent)
165165
for path, group in root.groups.items():

xarray/backends/zarr.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1048,8 +1048,8 @@ def open_datatree(
10481048
import zarr
10491049

10501050
from xarray.backends.api import open_dataset
1051+
from xarray.core.treenode import NodePath
10511052
from xarray.datatree_.datatree import DataTree
1052-
from xarray.datatree_.datatree.treenode import NodePath
10531053

10541054
zds = zarr.open_group(filename_or_obj, mode="r")
10551055
ds = open_dataset(filename_or_obj, engine="zarr", **kwargs)
@@ -1075,7 +1075,7 @@ def open_datatree(
10751075

10761076

10771077
def _iter_zarr_groups(root, parent="/"):
1078-
from xarray.datatree_.datatree.treenode import NodePath
1078+
from xarray.core.treenode import NodePath
10791079

10801080
parent = NodePath(parent)
10811081
for path, group in root.groups():

xarray/datatree_/datatree/treenode.py xarray/core/treenode.py

+50-50
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
from __future__ import annotations
22

33
import sys
4-
from collections import OrderedDict
4+
from collections.abc import Iterator, Mapping
55
from pathlib import PurePosixPath
66
from typing import (
77
TYPE_CHECKING,
88
Generic,
9-
Iterator,
10-
Mapping,
11-
Optional,
12-
Tuple,
139
TypeVar,
14-
Union,
1510
)
1611

1712
from xarray.core.utils import Frozen, is_dict_like
@@ -25,7 +20,7 @@ class InvalidTreeError(Exception):
2520

2621

2722
class NotFoundInTreeError(ValueError):
28-
"""Raised when operation can't be completed because one node is part of the expected tree."""
23+
"""Raised when operation can't be completed because one node is not part of the expected tree."""
2924

3025

3126
class NodePath(PurePosixPath):
@@ -55,8 +50,8 @@ class TreeNode(Generic[Tree]):
5550
5651
This class stores no data, it has only parents and children attributes, and various methods.
5752
58-
Stores child nodes in an Ordered Dictionary, which is necessary to ensure that equality checks between two trees
59-
also check that the order of child nodes is the same.
53+
Stores child nodes in a dict, ensuring that equality checks between trees
54+
and order of child nodes is preserved (since python 3.7).
6055
6156
Nodes themselves are intrinsically unnamed (do not possess a ._name attribute), but if the node has a parent you can
6257
find the key it is stored under via the .name property.
@@ -73,15 +68,16 @@ class TreeNode(Generic[Tree]):
7368
Also allows access to any other node in the tree via unix-like paths, including upwards referencing via '../'.
7469
7570
(This class is heavily inspired by the anytree library's NodeMixin class.)
71+
7672
"""
7773

78-
_parent: Optional[Tree]
79-
_children: OrderedDict[str, Tree]
74+
_parent: Tree | None
75+
_children: dict[str, Tree]
8076

81-
def __init__(self, children: Optional[Mapping[str, Tree]] = None):
77+
def __init__(self, children: Mapping[str, Tree] | None = None):
8278
"""Create a parentless node."""
8379
self._parent = None
84-
self._children = OrderedDict()
80+
self._children = {}
8581
if children is not None:
8682
self.children = children
8783

@@ -91,7 +87,7 @@ def parent(self) -> Tree | None:
9187
return self._parent
9288

9389
def _set_parent(
94-
self, new_parent: Tree | None, child_name: Optional[str] = None
90+
self, new_parent: Tree | None, child_name: str | None = None
9591
) -> None:
9692
# TODO is it possible to refactor in a way that removes this private method?
9793

@@ -127,17 +123,15 @@ def _detach(self, parent: Tree | None) -> None:
127123
if parent is not None:
128124
self._pre_detach(parent)
129125
parents_children = parent.children
130-
parent._children = OrderedDict(
131-
{
132-
name: child
133-
for name, child in parents_children.items()
134-
if child is not self
135-
}
136-
)
126+
parent._children = {
127+
name: child
128+
for name, child in parents_children.items()
129+
if child is not self
130+
}
137131
self._parent = None
138132
self._post_detach(parent)
139133

140-
def _attach(self, parent: Tree | None, child_name: Optional[str] = None) -> None:
134+
def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
141135
if parent is not None:
142136
if child_name is None:
143137
raise ValueError(
@@ -167,7 +161,7 @@ def children(self: Tree) -> Mapping[str, Tree]:
167161
@children.setter
168162
def children(self: Tree, children: Mapping[str, Tree]) -> None:
169163
self._check_children(children)
170-
children = OrderedDict(children)
164+
children = {**children}
171165

172166
old_children = self.children
173167
del self.children
@@ -242,7 +236,7 @@ def _iter_parents(self: Tree) -> Iterator[Tree]:
242236
yield node
243237
node = node.parent
244238

245-
def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
239+
def iter_lineage(self: Tree) -> tuple[Tree, ...]:
246240
"""Iterate up the tree, starting from the current node."""
247241
from warnings import warn
248242

@@ -254,7 +248,7 @@ def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
254248
return tuple((self, *self.parents))
255249

256250
@property
257-
def lineage(self: Tree) -> Tuple[Tree, ...]:
251+
def lineage(self: Tree) -> tuple[Tree, ...]:
258252
"""All parent nodes and their parent nodes, starting with the closest."""
259253
from warnings import warn
260254

@@ -266,12 +260,12 @@ def lineage(self: Tree) -> Tuple[Tree, ...]:
266260
return self.iter_lineage()
267261

268262
@property
269-
def parents(self: Tree) -> Tuple[Tree, ...]:
263+
def parents(self: Tree) -> tuple[Tree, ...]:
270264
"""All parent nodes and their parent nodes, starting with the closest."""
271265
return tuple(self._iter_parents())
272266

273267
@property
274-
def ancestors(self: Tree) -> Tuple[Tree, ...]:
268+
def ancestors(self: Tree) -> tuple[Tree, ...]:
275269
"""All parent nodes and their parent nodes, starting with the most distant."""
276270

277271
from warnings import warn
@@ -306,7 +300,7 @@ def is_leaf(self) -> bool:
306300
return self.children == {}
307301

308302
@property
309-
def leaves(self: Tree) -> Tuple[Tree, ...]:
303+
def leaves(self: Tree) -> tuple[Tree, ...]:
310304
"""
311305
All leaf nodes.
312306
@@ -315,20 +309,18 @@ def leaves(self: Tree) -> Tuple[Tree, ...]:
315309
return tuple([node for node in self.subtree if node.is_leaf])
316310

317311
@property
318-
def siblings(self: Tree) -> OrderedDict[str, Tree]:
312+
def siblings(self: Tree) -> dict[str, Tree]:
319313
"""
320314
Nodes with the same parent as this node.
321315
"""
322316
if self.parent:
323-
return OrderedDict(
324-
{
325-
name: child
326-
for name, child in self.parent.children.items()
327-
if child is not self
328-
}
329-
)
317+
return {
318+
name: child
319+
for name, child in self.parent.children.items()
320+
if child is not self
321+
}
330322
else:
331-
return OrderedDict()
323+
return {}
332324

333325
@property
334326
def subtree(self: Tree) -> Iterator[Tree]:
@@ -341,12 +333,12 @@ def subtree(self: Tree) -> Iterator[Tree]:
341333
--------
342334
DataTree.descendants
343335
"""
344-
from . import iterators
336+
from xarray.datatree_.datatree import iterators
345337

346338
return iterators.PreOrderIter(self)
347339

348340
@property
349-
def descendants(self: Tree) -> Tuple[Tree, ...]:
341+
def descendants(self: Tree) -> tuple[Tree, ...]:
350342
"""
351343
Child nodes and all their child nodes.
352344
@@ -431,7 +423,7 @@ def _post_attach(self: Tree, parent: Tree) -> None:
431423
"""Method call after attaching to `parent`."""
432424
pass
433425

434-
def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:
426+
def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None:
435427
"""
436428
Return the child node with the specified key.
437429
@@ -445,7 +437,7 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:
445437

446438
# TODO `._walk` method to be called by both `_get_item` and `_set_item`
447439

448-
def _get_item(self: Tree, path: str | NodePath) -> Union[Tree, T_DataArray]:
440+
def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
449441
"""
450442
Returns the object lying at the given path.
451443
@@ -488,24 +480,26 @@ def _set(self: Tree, key: str, val: Tree) -> None:
488480
def _set_item(
489481
self: Tree,
490482
path: str | NodePath,
491-
item: Union[Tree, T_DataArray],
483+
item: Tree | T_DataArray,
492484
new_nodes_along_path: bool = False,
493485
allow_overwrite: bool = True,
494486
) -> None:
495487
"""
496488
Set a new item in the tree, overwriting anything already present at that path.
497489
498-
The given value either forms a new node of the tree or overwrites an existing item at that location.
490+
The given value either forms a new node of the tree or overwrites an
491+
existing item at that location.
499492
500493
Parameters
501494
----------
502495
path
503496
item
504497
new_nodes_along_path : bool
505-
If true, then if necessary new nodes will be created along the given path, until the tree can reach the
506-
specified location.
498+
If true, then if necessary new nodes will be created along the
499+
given path, until the tree can reach the specified location.
507500
allow_overwrite : bool
508-
Whether or not to overwrite any existing node at the location given by path.
501+
Whether or not to overwrite any existing node at the location given
502+
by path.
509503
510504
Raises
511505
------
@@ -580,9 +574,9 @@ class NamedNode(TreeNode, Generic[Tree]):
580574
Implements path-like relationships to other nodes in its tree.
581575
"""
582576

583-
_name: Optional[str]
584-
_parent: Optional[Tree]
585-
_children: OrderedDict[str, Tree]
577+
_name: str | None
578+
_parent: Tree | None
579+
_children: dict[str, Tree]
586580

587581
def __init__(self, name=None, children=None):
588582
super().__init__(children=children)
@@ -603,8 +597,14 @@ def name(self, name: str | None) -> None:
603597
raise ValueError("node names cannot contain forward slashes")
604598
self._name = name
605599

600+
def __repr__(self, level=0):
601+
repr_value = "\t" * level + self.__str__() + "\n"
602+
for child in self.children:
603+
repr_value += self.get(child).__repr__(level + 1)
604+
return repr_value
605+
606606
def __str__(self) -> str:
607-
return f"NamedNode({self.name})" if self.name else "NamedNode()"
607+
return f"NamedNode('{self.name}')" if self.name else "NamedNode()"
608608

609609
def _post_attach(self: NamedNode, parent: NamedNode) -> None:
610610
"""Ensures child has name attribute corresponding to key under which it has been stored."""

xarray/datatree_/datatree/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .datatree import DataTree
33
from .extensions import register_datatree_accessor
44
from .mapping import TreeIsomorphismError, map_over_subtree
5-
from .treenode import InvalidTreeError, NotFoundInTreeError
5+
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError
66

77

88
__all__ = (

xarray/datatree_/datatree/datatree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
MappedDataWithCoords,
5151
)
5252
from .render import RenderTree
53-
from .treenode import NamedNode, NodePath, Tree
53+
from xarray.core.treenode import NamedNode, NodePath, Tree
5454

5555
try:
5656
from xarray.core.variable import calculate_dimensions

xarray/datatree_/datatree/iterators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import abc
33
from typing import Callable, Iterator, List, Optional
44

5-
from .treenode import Tree
5+
from xarray.core.treenode import Tree
66

77
"""These iterators are copied from anytree.iterators, with minor modifications."""
88

xarray/datatree_/datatree/mapping.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from xarray import DataArray, Dataset
1010

1111
from .iterators import LevelOrderIter
12-
from .treenode import NodePath, TreeNode
12+
from xarray.core.treenode import NodePath, TreeNode
1313

1414
if TYPE_CHECKING:
15-
from .datatree import DataTree
15+
from xarray.core.datatree import DataTree
1616

1717

1818
class TreeIsomorphismError(ValueError):

xarray/datatree_/datatree/tests/test_formatting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ def test_diff_node_data(self):
108108
Data in nodes at position '/a' do not match:
109109
110110
Data variables only on the left object:
111-
v int64 1
111+
v int64 8B 1
112112
113113
Data in nodes at position '/a/b' do not match:
114114
115115
Differing data variables:
116-
L w int64 5
117-
R w int64 6"""
116+
L w int64 8B 5
117+
R w int64 8B 6"""
118118
)
119119
actual = diff_tree_repr(dt_1, dt_2, "equals")
120120
assert actual == expected

0 commit comments

Comments
 (0)