
Arrays having dtype float0 are broken with the Array API #20620

Open
@NeilGirdhar

Description

For all practical purposes, NumPy arrays with dtype float0 (let's call them Z-arrays) are really JAX arrays. However, array_api_compat.get_namespace treats them as NumPy arrays, so getting the namespace of a Z-array together with ordinary JAX arrays crashes. This suggests that something is fundamentally broken.

Background

Z-arrays arise in several ways, for example:

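  • When a gradient is structurally zero, for example the gradient with respect to an integer input under grad(..., allow_int=True), grad produces a Z-array.
  • Some cotangent values must be passed in as Z-arrays, for example cotangents for integer-valued outputs.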
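A small sketch of the first case, assuming a toy function f (hypothetical, not from the original report) whose second argument is an integer. jax.grad's allow_int flag is what permits differentiating with respect to that integer argument, and the resulting gradient has dtype float0:

```python
import jax


def f(x, n):
    # Toy function (hypothetical): float output, integer second argument.
    return x * n


# Gradient with respect to the integer argument n; allow_int=True is required,
# otherwise grad rejects integer inputs.
g = jax.grad(f, argnums=1, allow_int=True)(2.0, 3)

# Inspect what kind of object the Z-array is on your installed versions.
print(type(g), g.dtype, g.shape)
```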

Does it really make sense to treat an array of dtype float0 as a NumPy array? The NumPy namespace knows nothing about float0. It seems to me that an array whose dtype comes from a given library should be considered to belong to that library. Thus, I expect:

import jax
import jax.numpy as jnp
import numpy as np

isinstance(np.zeros(10, np.float32), np.ndarray)    # True
isinstance(np.zeros(10, jax.float0), np.ndarray)    # False (but it's not!)
isinstance(jnp.zeros(10, jnp.float32), np.ndarray)  # False
isinstance(np.zeros(10, np.float32), jax.Array)     # False
isinstance(np.zeros(10, jax.float0), jax.Array)     # True (but it's not!)
isinstance(jnp.zeros(10, jnp.float32), jax.Array)   # True

Also, it is very odd that jnp.zeros_like applied to a Z-array (purportedly a NumPy array, according to get_namespace) returns a Z-array, that is, a NumPy array! Shouldn't an Array API function belonging to one namespace never return an array belonging to another namespace?

Solutions

Some possible solutions:

  • Reconsider jnp.zeros does not support float0 as a dtype #4433, and make a Z-array a JAX array of some new type for which isinstance(z, jax.Array) returns True. This would fix both the isinstance problem above and the crashes.
  • Fix array_api_compat.get_namespace so that it treats Z-arrays as belonging to the JAX namespace. This is the easier fix, and it should at least resolve the crashes.
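The second option can be illustrated with a minimal sketch of a namespace resolver that counts Z-arrays as JAX arrays. The names infer_namespace and is_zero_gradient_array are hypothetical, and this is a toy model of the idea, not array_api_compat's actual implementation; it only distinguishes NumPy and JAX arrays:

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_zero_gradient_array(x) -> bool:
    """True for a NumPy array whose dtype is JAX's float0 (a "Z-array")."""
    return isinstance(x, np.ndarray) and x.dtype == jax.dtypes.float0


def infer_namespace(*xs):
    """Hypothetical resolver: map each array to a namespace, treating Z-arrays as JAX."""
    namespaces = set()
    for x in xs:
        if isinstance(x, jax.Array) or is_zero_gradient_array(x):
            namespaces.add(jnp)
        elif isinstance(x, np.ndarray):
            namespaces.add(np)
        else:
            raise TypeError(f"Unrecognized array input: {type(x)!r}")
    if not namespaces:
        raise TypeError("No array inputs")
    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
    return namespaces.pop()


z = np.zeros(3, dtype=jax.dtypes.float0)  # a Z-array
print(infer_namespace(z, jnp.ones(3)).__name__)  # jax.numpy
```

With this rule, mixing a Z-array with ordinary JAX arrays resolves to jax.numpy instead of raising, while genuinely mixed NumPy/JAX inputs still raise.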

In my own code, the easiest workaround was a modified get_namespace, but I would prefer one of the solutions above so that I don't run into the same problem with other libraries.

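from dataclasses import fields
from typing import Any

import jax.numpy as jnp
from array_api_compat import get_namespace
from jax.dtypes import float0

def fix_zero(x: Any) -> Any:
    # Replace a float0 "Z-array" with an ordinary JAX array of the same shape.
    if x.dtype == float0:
        return jnp.zeros(x.shape)
    return x

# Method of a dataclass whose fields are all arrays.
def get_namespace_fixed(self, *x):
    values = [fix_zero(getattr(self, field.name))
              for field in fields(self)] + list(x)
    return get_namespace(*values)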

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.8 (main, Feb 22 2024, 17:25:49) [GCC 11.4.0]

Metadata

Labels: bug (Something isn't working)
