Skip to content

Commit c6d8eba

Browse files
author
OptaxDev
committed
Merge pull request #1321 from carlosgmartin:tree_utils
PiperOrigin-RevId: 773082934
2 parents 315521a + cadb2bc commit c6d8eba

File tree

5 files changed

+60
-0
lines changed

5 files changed

+60
-0
lines changed

docs/api/utilities.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,15 @@ Tree
110110
tree_norm
111111
tree_map_params
112112
tree_max
113+
tree_min
113114
tree_mul
114115
tree_ones_like
115116
tree_random_like
116117
tree_real
117118
tree_split_key_like
118119
tree_scale
119120
tree_set
121+
tree_size
120122
tree_sub
121123
tree_sum
122124
tree_vdot
@@ -183,6 +185,10 @@ Tree max
183185
~~~~~~~~
184186
.. autofunction:: tree_max
185187

188+
Tree min
189+
~~~~~~~~
190+
.. autofunction:: tree_min
191+
186192
Tree multiply
187193
~~~~~~~~~~~~~
188194
.. autofunction:: tree_mul
@@ -211,6 +217,10 @@ Set values in a tree
211217
~~~~~~~~~~~~~~~~~~~~
212218
.. autofunction:: tree_set
213219

220+
Tree size
221+
~~~~~~~~~
222+
.. autofunction:: tree_size
223+
214224
Tree subtract
215225
~~~~~~~~~~~~~
216226
.. autofunction:: tree_sub

optax/tree/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@
4040
div = _tree_math.tree_div
4141
full_like = _tree_math.tree_full_like
4242
max = _tree_math.tree_max # pylint: disable=redefined-builtin
43+
min = _tree_math.tree_min # pylint: disable=redefined-builtin
4344
mul = _tree_math.tree_mul
4445
norm = _tree_math.tree_norm
4546
ones_like = _tree_math.tree_ones_like
4647
real = _tree_math.tree_real
4748
scale = _tree_math.tree_scale
49+
size = _tree_math.tree_size
4850
sub = _tree_math.tree_sub
4951
sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
5052
update_infinity_moment = _tree_math.tree_update_infinity_moment

optax/tree_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@
3939
from optax.tree_utils._tree_math import tree_div
4040
from optax.tree_utils._tree_math import tree_full_like
4141
from optax.tree_utils._tree_math import tree_max
42+
from optax.tree_utils._tree_math import tree_min
4243
from optax.tree_utils._tree_math import tree_mul
4344
from optax.tree_utils._tree_math import tree_norm
4445
from optax.tree_utils._tree_math import tree_ones_like
4546
from optax.tree_utils._tree_math import tree_real
4647
from optax.tree_utils._tree_math import tree_scale
48+
from optax.tree_utils._tree_math import tree_size
4749
from optax.tree_utils._tree_math import tree_sub
4850
from optax.tree_utils._tree_math import tree_sum
4951
from optax.tree_utils._tree_math import tree_update_infinity_moment

optax/tree_utils/_tree_math.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,32 @@ def tree_max(tree: Any) -> chex.Numeric:
183183
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
184184

185185

186+
def tree_min(tree: Any) -> chex.Numeric:
187+
"""Compute the min of all the elements in a pytree.
188+
189+
Args:
190+
tree: pytree.
191+
192+
Returns:
193+
a scalar value.
194+
"""
195+
mins = jax.tree.map(jnp.min, tree)
196+
# initializer=jnp.inf should work but pytype wants a jax.Array.
197+
return jax.tree.reduce(jnp.minimum, mins, initializer=jnp.array(jnp.inf))
198+
199+
200+
def tree_size(tree: Any) -> int:
201+
r"""Total size of a pytree.
202+
203+
Args:
204+
tree: pytree
205+
206+
Returns:
207+
the total size of the pytree.
208+
"""
209+
return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))
210+
211+
186212
def tree_conj(tree: Any) -> Any:
187213
"""Compute the conjugate of a pytree.
188214

optax/tree_utils/_tree_math_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,26 @@ def test_tree_max(self, key):
152152
got = tu.tree_max(tree)
153153
np.testing.assert_allclose(expected, got)
154154

155+
@parameterized.parameters(
156+
'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict'
157+
)
158+
def test_tree_min(self, key):
159+
tree = self.data[key]
160+
values, _ = flatten_util.ravel_pytree(tree)
161+
expected = jnp.min(values)
162+
got = tu.tree_min(tree)
163+
np.testing.assert_allclose(expected, got)
164+
165+
@parameterized.parameters(
166+
'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict'
167+
)
168+
def test_tree_size(self, key):
169+
tree = self.data[key]
170+
values, _ = flatten_util.ravel_pytree(tree)
171+
expected = jnp.size(values)
172+
got = tu.tree_size(tree)
173+
np.testing.assert_allclose(expected, got)
174+
155175
def test_tree_conj(self):
156176
expected = jnp.conj(self.array_a)
157177
got = tu.tree_conj(self.array_a)

0 commit comments

Comments
 (0)