Description
For all practical purposes, NumPy arrays with dtype `float0` (let's call them Z-arrays) are JAX arrays. However, `array_api_compat.get_namespace` treats them as NumPy arrays, which leads to crashes when getting the namespace of several arrays at once (e.g. a Z-array together with ordinary JAX arrays). This suggests that something is fundamentally broken.
Background
Z-arrays originate in various ways, e.g.:
- When the gradient is zero, `grad` produces a Z-array for the gradient.
- Z-arrays are required to be passed in for some cotangent values.
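One documented way to produce a Z-array is to differentiate with respect to an integer input via `jax.grad(..., allow_int=True)`; the JAX documentation states that the gradient of an integer input has the trivial `float0` dtype. A minimal sketch (the function `f` is only an illustration):

```python
import jax
import jax.numpy as jnp

def f(x, n):
    # n is an integer array; it scales x but has no meaningful derivative.
    return jnp.sum(x * n)

# Differentiate with respect to the integer argument n (argnums=1).
g = jax.grad(f, argnums=1, allow_int=True)(2.0, jnp.arange(3))

print(g.dtype == jax.dtypes.float0)  # True
print(g.shape)                       # (3,)
print(type(g))  # per this issue, a NumPy ndarray rather than a jax.Array
```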
Does it really make sense to consider an array of dtype `float0` to be a NumPy array? The NumPy namespace doesn't even know about `float0`. It seems to me that an array with a given library's dtype should be considered to belong to that library. Thus, I expect:
```python
import jax
import jax.numpy as jnp
import numpy as np

isinstance(np.zeros(10, np.float32), np.ndarray)    # True
isinstance(np.zeros(10, jax.float0), np.ndarray)    # expected False, but currently True
isinstance(jnp.zeros(10, jnp.float32), np.ndarray)  # False
isinstance(np.zeros(10, np.float32), jax.Array)     # False
isinstance(np.zeros(10, jax.float0), jax.Array)     # expected True, but currently False
isinstance(jnp.zeros(10, jnp.float32), jax.Array)   # True
```
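From NumPy's side, `jax.dtypes.float0` is just a zero-width structured "void" dtype, so a Z-array is an ordinary `np.ndarray` that stores no data. A quick check, assuming `float0` is defined as a zero-itemsize structured dtype as in current JAX releases:

```python
import jax
import numpy as np

z = np.zeros(10, jax.dtypes.float0)  # a Z-array built directly with NumPy

# NumPy only sees a structured void dtype with zero bytes per element.
print(z.dtype.kind)      # V
print(z.dtype.itemsize)  # 0
print(z.nbytes)          # 0

# It is a plain NumPy array, not a JAX array.
print(isinstance(z, np.ndarray), isinstance(z, jax.Array))  # True False
```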
Also, it is very odd that `jnp.zeros_like` applied to a Z-array (purportedly a NumPy array, according to `get_namespace`) returns a Z-array, i.e. a NumPy array! Surely an Array API function from one namespace should never return an array belonging to another namespace?
Solutions
Some possible solutions:
- Reconsider #4433 ("jnp.zeros does not support float0 as a dtype") and make a Z-array a JAX array of some new type, so that `isinstance(z, jax.Array)` returns `True`. This would fix both the `isinstance` problem above and the crash.
- Fix `array_api_compat.get_namespace` so that it considers Z-arrays to belong to the JAX namespace. This is the easier fix, and it should at least resolve the crash.
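The second option amounts to one extra check before deciding that an array belongs to NumPy. A minimal sketch of that classification rule, using hypothetical helpers (`is_zero_gradient_array`, `owning_library`, `common_library`) rather than the actual `array_api_compat` internals:

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_zero_gradient_array(x) -> bool:
    # A "Z-array": a NumPy ndarray whose dtype is jax.dtypes.float0.
    return isinstance(x, np.ndarray) and x.dtype == jax.dtypes.float0

def owning_library(x) -> str:
    # Hypothetical rule: Z-arrays are attributed to JAX, not NumPy.
    if isinstance(x, jax.Array) or is_zero_gradient_array(x):
        return "jax"
    if isinstance(x, (np.ndarray, np.generic)):
        return "numpy"
    raise TypeError(f"unsupported array type: {type(x)!r}")

def common_library(*xs) -> str:
    # Mirrors the single-namespace requirement: mixing libraries is an error.
    libs = {owning_library(x) for x in xs}
    if len(libs) != 1:
        raise TypeError(f"multiple libraries for array inputs: {sorted(libs)}")
    return libs.pop()

z = np.zeros(3, jax.dtypes.float0)
print(common_library(jnp.ones(3), z))  # jax
```

Unlike the `fix_zero` workaround in this issue, this approach leaves the inputs untouched and only changes how ownership is decided.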
In my code, the easiest workaround was a modified `get_namespace`, but I would prefer one of the solutions above so that I don't run into this problem with other libraries:
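```python
from dataclasses import fields
from typing import Any
import jax.numpy as jnp
from array_api_compat import get_namespace
from jax.dtypes import float0

def fix_zero(x: Any) -> Any:
    # Swap a float0 (Z-)array for a JAX array of the same shape.
    if x.dtype == float0:
        return jnp.zeros(x.shape)
    return x

def get_namespace_fixed(self, *x):  # method on a dataclass whose fields are arrays
    values = [fix_zero(getattr(self, field.name))
              for field in fields(self)] + list(x)
    return get_namespace(*values)
```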
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.8 (main, Feb 22 2024, 17:25:49) [GCC 11.4.0]