1
1
from __future__ import annotations
2
2
3
3
import sys
4
- from collections import OrderedDict
4
+ from collections . abc import Iterator , Mapping
5
5
from pathlib import PurePosixPath
6
6
from typing import (
7
7
TYPE_CHECKING ,
8
8
Generic ,
9
- Iterator ,
10
- Mapping ,
11
- Optional ,
12
- Tuple ,
13
9
TypeVar ,
14
- Union ,
15
10
)
16
11
17
12
from xarray .core .utils import Frozen , is_dict_like
@@ -25,7 +20,7 @@ class InvalidTreeError(Exception):
25
20
26
21
27
22
class NotFoundInTreeError (ValueError ):
28
- """Raised when operation can't be completed because one node is part of the expected tree."""
23
+ """Raised when operation can't be completed because one node is not part of the expected tree."""
29
24
30
25
31
26
class NodePath (PurePosixPath ):
@@ -55,8 +50,8 @@ class TreeNode(Generic[Tree]):
55
50
56
51
This class stores no data, it has only parents and children attributes, and various methods.
57
52
58
- Stores child nodes in an Ordered Dictionary, which is necessary to ensure that equality checks between two trees
59
- also check that the order of child nodes is the same .
53
+ Stores child nodes in an dict, ensuring that equality checks between trees
54
+ and order of child nodes is preserved (since python 3.7) .
60
55
61
56
Nodes themselves are intrinsically unnamed (do not possess a ._name attribute), but if the node has a parent you can
62
57
find the key it is stored under via the .name property.
@@ -73,15 +68,16 @@ class TreeNode(Generic[Tree]):
73
68
Also allows access to any other node in the tree via unix-like paths, including upwards referencing via '../'.
74
69
75
70
(This class is heavily inspired by the anytree library's NodeMixin class.)
71
+
76
72
"""
77
73
78
- _parent : Optional [ Tree ]
79
- _children : OrderedDict [str , Tree ]
74
+ _parent : Tree | None
75
+ _children : dict [str , Tree ]
80
76
81
- def __init__ (self , children : Optional [ Mapping [str , Tree ]] = None ):
77
+ def __init__ (self , children : Mapping [str , Tree ] | None = None ):
82
78
"""Create a parentless node."""
83
79
self ._parent = None
84
- self ._children = OrderedDict ()
80
+ self ._children = {}
85
81
if children is not None :
86
82
self .children = children
87
83
@@ -91,7 +87,7 @@ def parent(self) -> Tree | None:
91
87
return self ._parent
92
88
93
89
def _set_parent (
94
- self , new_parent : Tree | None , child_name : Optional [ str ] = None
90
+ self , new_parent : Tree | None , child_name : str | None = None
95
91
) -> None :
96
92
# TODO is it possible to refactor in a way that removes this private method?
97
93
@@ -127,17 +123,15 @@ def _detach(self, parent: Tree | None) -> None:
127
123
if parent is not None :
128
124
self ._pre_detach (parent )
129
125
parents_children = parent .children
130
- parent ._children = OrderedDict (
131
- {
132
- name : child
133
- for name , child in parents_children .items ()
134
- if child is not self
135
- }
136
- )
126
+ parent ._children = {
127
+ name : child
128
+ for name , child in parents_children .items ()
129
+ if child is not self
130
+ }
137
131
self ._parent = None
138
132
self ._post_detach (parent )
139
133
140
- def _attach (self , parent : Tree | None , child_name : Optional [ str ] = None ) -> None :
134
+ def _attach (self , parent : Tree | None , child_name : str | None = None ) -> None :
141
135
if parent is not None :
142
136
if child_name is None :
143
137
raise ValueError (
@@ -167,7 +161,7 @@ def children(self: Tree) -> Mapping[str, Tree]:
167
161
@children .setter
168
162
def children (self : Tree , children : Mapping [str , Tree ]) -> None :
169
163
self ._check_children (children )
170
- children = OrderedDict ( children )
164
+ children = { ** children }
171
165
172
166
old_children = self .children
173
167
del self .children
@@ -242,7 +236,7 @@ def _iter_parents(self: Tree) -> Iterator[Tree]:
242
236
yield node
243
237
node = node .parent
244
238
245
- def iter_lineage (self : Tree ) -> Tuple [Tree , ...]:
239
+ def iter_lineage (self : Tree ) -> tuple [Tree , ...]:
246
240
"""Iterate up the tree, starting from the current node."""
247
241
from warnings import warn
248
242
@@ -254,7 +248,7 @@ def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
254
248
return tuple ((self , * self .parents ))
255
249
256
250
@property
257
- def lineage (self : Tree ) -> Tuple [Tree , ...]:
251
+ def lineage (self : Tree ) -> tuple [Tree , ...]:
258
252
"""All parent nodes and their parent nodes, starting with the closest."""
259
253
from warnings import warn
260
254
@@ -266,12 +260,12 @@ def lineage(self: Tree) -> Tuple[Tree, ...]:
266
260
return self .iter_lineage ()
267
261
268
262
@property
269
- def parents (self : Tree ) -> Tuple [Tree , ...]:
263
+ def parents (self : Tree ) -> tuple [Tree , ...]:
270
264
"""All parent nodes and their parent nodes, starting with the closest."""
271
265
return tuple (self ._iter_parents ())
272
266
273
267
@property
274
- def ancestors (self : Tree ) -> Tuple [Tree , ...]:
268
+ def ancestors (self : Tree ) -> tuple [Tree , ...]:
275
269
"""All parent nodes and their parent nodes, starting with the most distant."""
276
270
277
271
from warnings import warn
@@ -306,7 +300,7 @@ def is_leaf(self) -> bool:
306
300
return self .children == {}
307
301
308
302
@property
309
- def leaves (self : Tree ) -> Tuple [Tree , ...]:
303
+ def leaves (self : Tree ) -> tuple [Tree , ...]:
310
304
"""
311
305
All leaf nodes.
312
306
@@ -315,20 +309,18 @@ def leaves(self: Tree) -> Tuple[Tree, ...]:
315
309
return tuple ([node for node in self .subtree if node .is_leaf ])
316
310
317
311
@property
318
- def siblings (self : Tree ) -> OrderedDict [str , Tree ]:
312
+ def siblings (self : Tree ) -> dict [str , Tree ]:
319
313
"""
320
314
Nodes with the same parent as this node.
321
315
"""
322
316
if self .parent :
323
- return OrderedDict (
324
- {
325
- name : child
326
- for name , child in self .parent .children .items ()
327
- if child is not self
328
- }
329
- )
317
+ return {
318
+ name : child
319
+ for name , child in self .parent .children .items ()
320
+ if child is not self
321
+ }
330
322
else :
331
- return OrderedDict ()
323
+ return {}
332
324
333
325
@property
334
326
def subtree (self : Tree ) -> Iterator [Tree ]:
@@ -341,12 +333,12 @@ def subtree(self: Tree) -> Iterator[Tree]:
341
333
--------
342
334
DataTree.descendants
343
335
"""
344
- from . import iterators
336
+ from xarray . datatree_ . datatree import iterators
345
337
346
338
return iterators .PreOrderIter (self )
347
339
348
340
@property
349
- def descendants (self : Tree ) -> Tuple [Tree , ...]:
341
+ def descendants (self : Tree ) -> tuple [Tree , ...]:
350
342
"""
351
343
Child nodes and all their child nodes.
352
344
@@ -431,7 +423,7 @@ def _post_attach(self: Tree, parent: Tree) -> None:
431
423
"""Method call after attaching to `parent`."""
432
424
pass
433
425
434
- def get (self : Tree , key : str , default : Optional [ Tree ] = None ) -> Optional [ Tree ] :
426
+ def get (self : Tree , key : str , default : Tree | None = None ) -> Tree | None :
435
427
"""
436
428
Return the child node with the specified key.
437
429
@@ -445,7 +437,7 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:
445
437
446
438
# TODO `._walk` method to be called by both `_get_item` and `_set_item`
447
439
448
- def _get_item (self : Tree , path : str | NodePath ) -> Union [ Tree , T_DataArray ] :
440
+ def _get_item (self : Tree , path : str | NodePath ) -> Tree | T_DataArray :
449
441
"""
450
442
Returns the object lying at the given path.
451
443
@@ -488,24 +480,26 @@ def _set(self: Tree, key: str, val: Tree) -> None:
488
480
def _set_item (
489
481
self : Tree ,
490
482
path : str | NodePath ,
491
- item : Union [ Tree , T_DataArray ] ,
483
+ item : Tree | T_DataArray ,
492
484
new_nodes_along_path : bool = False ,
493
485
allow_overwrite : bool = True ,
494
486
) -> None :
495
487
"""
496
488
Set a new item in the tree, overwriting anything already present at that path.
497
489
498
- The given value either forms a new node of the tree or overwrites an existing item at that location.
490
+ The given value either forms a new node of the tree or overwrites an
491
+ existing item at that location.
499
492
500
493
Parameters
501
494
----------
502
495
path
503
496
item
504
497
new_nodes_along_path : bool
505
- If true, then if necessary new nodes will be created along the given path, until the tree can reach the
506
- specified location.
498
+ If true, then if necessary new nodes will be created along the
499
+ given path, until the tree can reach the specified location.
507
500
allow_overwrite : bool
508
- Whether or not to overwrite any existing node at the location given by path.
501
+ Whether or not to overwrite any existing node at the location given
502
+ by path.
509
503
510
504
Raises
511
505
------
@@ -580,9 +574,9 @@ class NamedNode(TreeNode, Generic[Tree]):
580
574
Implements path-like relationships to other nodes in its tree.
581
575
"""
582
576
583
- _name : Optional [ str ]
584
- _parent : Optional [ Tree ]
585
- _children : OrderedDict [str , Tree ]
577
+ _name : str | None
578
+ _parent : Tree | None
579
+ _children : dict [str , Tree ]
586
580
587
581
def __init__ (self , name = None , children = None ):
588
582
super ().__init__ (children = children )
@@ -603,8 +597,14 @@ def name(self, name: str | None) -> None:
603
597
raise ValueError ("node names cannot contain forward slashes" )
604
598
self ._name = name
605
599
600
+ def __repr__ (self , level = 0 ):
601
+ repr_value = "\t " * level + self .__str__ () + "\n "
602
+ for child in self .children :
603
+ repr_value += self .get (child ).__repr__ (level + 1 )
604
+ return repr_value
605
+
606
606
def __str__ (self ) -> str :
607
- return f"NamedNode({ self .name } )" if self .name else "NamedNode()"
607
+ return f"NamedNode(' { self .name } ' )" if self .name else "NamedNode()"
608
608
609
609
def _post_attach (self : NamedNode , parent : NamedNode ) -> None :
610
610
"""Ensures child has name attribute corresponding to key under which it has been stored."""
0 commit comments