Skip to content

Commit a017f4e

Browse files
committed
Optimize tree_sum compile time using tree_reduce_associative
Changed tree_sum implementation to use jax.tree_util.tree_reduce_associative when available (JAX >= 0.6.0). Since addition is an associative operation, tree_reduce_associative can provide better compilation performance. Testing shows runtime is very close but compile time is significantly lower (18s vs 23s in reported cases). For compatibility with older JAX versions (< 0.6.0), the implementation falls back to the original jax.tree.reduce when tree_reduce_associative is not available.
1 parent dcff838 commit a017f4e

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

optax/tree_utils/_tree_math.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import jax
2222
import jax.numpy as jnp
23+
from jax import tree_util as jtu
2324
from optax._src import numerics
2425

2526

@@ -176,7 +177,12 @@ def tree_sum(tree: Any) -> jax.typing.ArrayLike:
176177
a scalar value.
177178
"""
178179
sums = jax.tree.map(jnp.sum, tree)
179-
return jax.tree.reduce(operator.add, sums, initializer=0)
180+
# Use tree_reduce_associative for better compile time performance when
181+
# available (JAX >= 0.6.0), otherwise fall back to tree.reduce.
182+
if hasattr(jtu, 'tree_reduce_associative'):
183+
return jtu.tree_reduce_associative(operator.add, sums)
184+
else:
185+
return jax.tree.reduce(operator.add, sums, initializer=0)
180186

181187

182188
def tree_max(tree: Any) -> jax.typing.ArrayLike:

0 commit comments

Comments
 (0)