Skip to content

Commit 61e9af4

Browse files
[WIP] Fix jax backend for autograd (#732)
* merge * fix jax autograd
1 parent 8aed2d7 commit 61e9af4

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
1717
- Backend implementation of `ot.dist` for (PR #701)
1818
- Updated documentation Quickstart guide and User guide with new API (PR #726)
19+
- Fix jax version for auto-grad (PR #732)
1920

2021
#### Closed issues
2122
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

ot/backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,7 @@ def set_gradients(self, val, inputs, grads):
15091509
aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
15101510
aux = aux - jax.lax.stop_gradient(aux)
15111511

1512-
(val,) = jax.tree_map(lambda z: z + aux, (val,))
1512+
(val,) = jax.tree_util.tree_map(lambda z: z + aux, (val,))
15131513
return val
15141514

15151515
def _detach(self, a):

0 commit comments

Comments
 (0)