Skip to content

Commit da7b694

Browse files
committed
Fix tree_min and tree_max to handle zero-size array leafs.
1 parent a02de84 commit da7b694

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

optax/tree_utils/_tree_math.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,14 @@ def tree_max(tree: Any) -> chex.Numeric:
176176
Returns:
177177
a scalar value.
178178
"""
179-
maxes = jax.tree.map(jnp.max, tree)
180-
# initializer=-jnp.inf should work but pytype wants a jax.Array.
181-
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
179+
identity = -float("inf")
180+
def f(array):
181+
if jnp.size(array) == 0:
182+
return identity
183+
else:
184+
return jnp.max(array)
185+
maxes = jax.tree.map(f, tree)
186+
return jax.tree.reduce(jnp.maximum, maxes, initializer=identity)
182187

183188

184189
def tree_min(tree: Any) -> chex.Numeric:
@@ -190,9 +195,14 @@ def tree_min(tree: Any) -> chex.Numeric:
190195
Returns:
191196
a scalar value.
192197
"""
193-
mins = jax.tree.map(jnp.min, tree)
194-
# initializer=jnp.inf should work but pytype wants a jax.Array.
195-
return jax.tree.reduce(jnp.minimum, mins, initializer=jnp.array(jnp.inf))
198+
identity = float("inf")
199+
def f(array):
200+
if jnp.size(array) == 0:
201+
return identity
202+
else:
203+
return jnp.min(array)
204+
mins = jax.tree.map(f, tree)
205+
return jax.tree.reduce(jnp.minimum, mins, initializer=identity)
196206

197207

198208
def tree_size(tree: Any) -> int:

0 commit comments

Comments
 (0)