Skip to content

Commit cea2e24

Browse files
committed
Fix tree_min and tree_max to handle zero-size array leafs.
1 parent a02de84 commit cea2e24

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

optax/tree_utils/_tree_math.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,13 @@ def tree_max(tree: Any) -> chex.Numeric:
176176
Returns:
177177
a scalar value.
178178
"""
179-
maxes = jax.tree.map(jnp.max, tree)
180-
# initializer=-jnp.inf should work but pytype wants a jax.Array.
181-
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
179+
def f(array):
180+
if jnp.size(array) == 0:
181+
return None
182+
else:
183+
return jnp.max(array)
184+
maxes = jax.tree.map(f, tree)
185+
return jax.tree.reduce(jnp.maximum, maxes, initializer=-float("inf"))
182186

183187

184188
def tree_min(tree: Any) -> chex.Numeric:
@@ -190,9 +194,13 @@ def tree_min(tree: Any) -> chex.Numeric:
190194
Returns:
191195
a scalar value.
192196
"""
193-
mins = jax.tree.map(jnp.min, tree)
194-
# initializer=jnp.inf should work but pytype wants a jax.Array.
195-
return jax.tree.reduce(jnp.minimum, mins, initializer=jnp.array(jnp.inf))
197+
def f(array):
198+
if jnp.size(array) == 0:
199+
return None
200+
else:
201+
return jnp.min(array)
202+
mins = jax.tree.map(f, tree)
203+
return jax.tree.reduce(jnp.minimum, mins, initializer=float("inf"))
196204

197205

198206
def tree_size(tree: Any) -> int:

optax/tree_utils/_tree_math_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ def test_tree_min(self, key):
162162
got = tu.tree_min(tree)
163163
np.testing.assert_allclose(expected, got)
164164

165+
def test_tree_min_empty(self):
166+
tree = [jnp.ones([2, 3]), jnp.zeros([4, 0, 5])]
167+
got = tu.tree_min(tree)
168+
expected = 1.0
169+
assert expected == got
170+
165171
@parameterized.parameters(
166172
'array_a', 'tree_a', 'tree_a_dict', 'tree_b', 'tree_b_dict'
167173
)

0 commit comments

Comments
 (0)