The spec seems to imply that `xp.finfo(xp.float32).eps`
is a Python `float`, but NumPy and `jax.numpy` return NumPy scalars instead:
```python
In [1]: import torch

In [2]: torch.finfo(torch.float32).eps
Out[2]: 1.1920928955078125e-07

In [3]: type(torch.finfo(torch.float32).eps)
Out[3]: float

In [4]: import numpy as np

In [5]: type(np.finfo(np.float32).eps)
Out[5]: numpy.float32

In [6]: import jax.numpy as jnp

In [7]: type(jnp.finfo(jnp.float32).eps)
Out[7]: numpy.float32
```
This is not hard to work around at the -compat level, even if having a wrapper for `finfo` feels a bit cheesy.
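A minimal sketch of what such a wrapper might look like, assuming we only need to coerce the float-valued attributes; the `_FInfo` container and the attribute list are hypothetical, not taken from any existing -compat implementation:

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class _FInfo:
    # Hypothetical container mirroring the spec's finfo attributes,
    # with the float-valued fields stored as Python floats.
    bits: int
    eps: float
    max: float
    min: float
    smallest_normal: float
    dtype: object


def finfo(dtype):
    """Wrap np.finfo, coercing NumPy scalar attributes to Python floats."""
    fi = np.finfo(dtype)
    return _FInfo(
        bits=fi.bits,
        eps=float(fi.eps),
        max=float(fi.max),
        min=float(fi.min),
        smallest_normal=float(fi.smallest_normal),
        dtype=fi.dtype,
    )


info = finfo(np.float32)
assert type(info.eps) is float  # plain float, per the spec's apparent intent
```

The same pattern would apply unchanged on top of `jax.numpy.finfo`, since it also returns NumPy scalars.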