Skip to content

Commit 14bbebf

Browse files
[WIP] Upgrade JAX support from 0.4.24 to 0.4.30 (#643)
* fix jax device * fix jax device * update requirements * up * fix? * fix? * fix * complete fix * complete fix
1 parent 3ea24a3 commit 14bbebf

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
- Fix line-search in partial GW and change default init to the interior of partial transport plans (PR #602)
2424
- Fix `ot.da.sinkhorn_lpl1_mm` compatibility with JAX (PR #592)
2525
- Fiw linesearch import error on Scipy 1.14 (PR #642, Issue #641)
26+
- Upgrade supported JAX versions from jax<=0.4.24 to jax<=0.4.30 (PR #643)
2627

2728
## 0.9.3
2829
*January 2024*

ot/backend.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
import jax.scipy.special as jspecial
121121
from jax.lib import xla_bridge
122122
jax_type = jax.numpy.ndarray
123+
jax_new_version = float('.'.join(jax.__version__.split('.')[1:])) > 4.24
123124
except ImportError:
124125
jax = False
125126
jax_type = float
@@ -1439,11 +1440,19 @@ def __init__(self):
14391440
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
14401441
]
14411442

1443+
self.jax_new_version = jax_new_version
1444+
14421445
def _to_numpy(self, a):
14431446
return np.array(a)
14441447

1448+
def _get_device(self, a):
1449+
if self.jax_new_version:
1450+
return list(a.devices())[0]
1451+
else:
1452+
return a.device_buffer.device()
1453+
14451454
def _change_device(self, a, type_as):
1446-
return jax.device_put(a, type_as.device_buffer.device())
1455+
return jax.device_put(a, self._get_device(type_as))
14471456

14481457
def _from_numpy(self, a, type_as=None):
14491458
if isinstance(a, float):
@@ -1688,7 +1697,10 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
16881697
return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
16891698

16901699
def dtype_device(self, a):
1691-
return a.dtype, a.device_buffer.device()
1700+
if self.jax_new_version:
1701+
return a.dtype, list(a.devices())[0]
1702+
else:
1703+
return a.dtype, a.device_buffer.device()
16921704

16931705
def assert_same_dtype_device(self, a, b):
16941706
a_dtype, a_device = self.dtype_device(a)

requirements_all.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master
66
cvxopt
77
scikit-learn
88
torch
9-
jax<=0.4.24
10-
jaxlib<=0.4.24
9+
jax
10+
jaxlib
1111
tensorflow
1212
pytest
1313
torch_geometric

0 commit comments

Comments
 (0)