Commit 9caf80a

Fix argument order in scale_by_distance_over_gradients
The helper function _tx is defined with parameters (g, d, g_sos), but jax.tree.map called it with (max_dist, g_sos, updates). As a result, _tx treated max_dist as g, the accumulated gradient sum-of-squares as d, and the raw updates (which can be negative) as g_sos. Whenever a raw gradient entry was below -eps, jnp.sqrt(g_sos + eps) received a negative argument and produced NaN. This fix corrects the argument order to (updates, max_dist, g_sos) to match the function signature.
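
The failure mode described above can be reproduced in isolation. The following is a minimal sketch, not the optax implementation: the values of global_scale and eps and the three single-leaf pytrees are made up for illustration, and only the body of _tx and the two jax.tree.map calls are taken from the diff.

```python
import jax
import jax.numpy as jnp

# Hypothetical values chosen for illustration; not taken from the commit.
global_scale = 1.0
eps = 1e-8


def _tx(g, d, g_sos):
  # Same body as the helper shown in the diff.
  eta = global_scale * (d / jnp.sqrt(g_sos + eps))
  return eta * g


# Hypothetical pytrees that share one structure (a single leaf named "w").
updates = {"w": jnp.array([0.5, -2.0])}  # raw gradients, one entry negative
max_dist = {"w": jnp.array([1.0])}       # stand-in for the tracked distance
g_sos = {"w": jnp.array([4.0])}          # stand-in for the gradient sum-of-squares

# Old call: jax.tree.map passes leaves positionally, so _tx receives
# g=max_dist, d=g_sos, g_sos=updates and takes the sqrt of a negative number.
buggy = jax.tree.map(_tx, max_dist, g_sos, updates)

# Fixed call: the tree order matches the (g, d, g_sos) signature.
fixed = jax.tree.map(_tx, updates, max_dist, g_sos)

print("buggy has NaN:", bool(jnp.isnan(buggy["w"]).any()))
print("fixed has NaN:", bool(jnp.isnan(fixed["w"]).any()))
```

Because jax.tree.map matches trees to parameters purely by position, a misordered call raises no error when the trees share a structure, which is likely why the problem only showed up as NaN values at run time.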
1 parent dcff838 commit 9caf80a

1 file changed, +1 -1 lines changed

optax/_src/transform.py

Lines changed: 1 addition & 1 deletion
@@ -1406,7 +1406,7 @@ def _tx(g, d, g_sos):
       eta = global_scale * (d / jnp.sqrt(g_sos + eps))
       return eta * g
 
-    updates = jax.tree.map(_tx, max_dist, g_sos, updates)
+    updates = jax.tree.map(_tx, updates, max_dist, g_sos)
 
     # new state
     state = ScaleByDistanceOverGradientsState(
