Commit 9caf80a
committed
Fix argument order in scale_by_distance_over_gradients
The helper function _tx is defined with parameters (g, d, g_sos), but
jax.tree.map was calling it with (max_dist, g_sos, updates). This caused
_tx to interpret the accumulated gradient sum-of-squares as d and the
raw updates (which can be negative) as g_sos. When the raw gradient had
any negative entry, the code would execute jnp.sqrt(g_sos + eps) with a
negative argument and produce NaN.
This fix corrects the argument order to (updates, max_dist, g_sos) to
match the function signature.1 parent dcff838 commit 9caf80a
1 file changed
+1
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1406 | 1406 | | |
1407 | 1407 | | |
1408 | 1408 | | |
1409 | | - | |
| 1409 | + | |
1410 | 1410 | | |
1411 | 1411 | | |
1412 | 1412 | | |
| |||
0 commit comments