|
120 | 120 | import jax.scipy.special as jspecial
|
121 | 121 | from jax.lib import xla_bridge
|
122 | 122 | jax_type = jax.numpy.ndarray
|
| 123 | + jax_new_version = float('.'.join(jax.__version__.split('.')[1:])) > 4.24 |
123 | 124 | except ImportError:
|
124 | 125 | jax = False
|
125 | 126 | jax_type = float
|
@@ -1439,11 +1440,19 @@ def __init__(self):
|
1439 | 1440 | jax.device_put(jnp.array(1, dtype=jnp.float64), d)
|
1440 | 1441 | ]
|
1441 | 1442 |
|
| 1443 | + self.jax_new_version = jax_new_version |
| 1444 | + |
1442 | 1445 | def _to_numpy(self, a):
|
1443 | 1446 | return np.array(a)
|
1444 | 1447 |
|
| 1448 | + def _get_device(self, a): |
| 1449 | + if self.jax_new_version: |
| 1450 | + return list(a.devices())[0] |
| 1451 | + else: |
| 1452 | + return a.device_buffer.device() |
| 1453 | + |
1445 | 1454 | def _change_device(self, a, type_as):
|
1446 |
| - return jax.device_put(a, type_as.device_buffer.device()) |
| 1455 | + return jax.device_put(a, self._get_device(type_as)) |
1447 | 1456 |
|
1448 | 1457 | def _from_numpy(self, a, type_as=None):
|
1449 | 1458 | if isinstance(a, float):
|
@@ -1688,7 +1697,10 @@ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
|
1688 | 1697 | return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
1689 | 1698 |
|
1690 | 1699 | def dtype_device(self, a):
|
1691 |
| - return a.dtype, a.device_buffer.device() |
| 1700 | + if self.jax_new_version: |
| 1701 | + return a.dtype, list(a.devices())[0] |
| 1702 | + else: |
| 1703 | + return a.dtype, a.device_buffer.device() |
1692 | 1704 |
|
1693 | 1705 | def assert_same_dtype_device(self, a, b):
|
1694 | 1706 | a_dtype, a_device = self.dtype_device(a)
|
|
0 commit comments