When calculating the matrix exponential with jax.scipy.linalg.expm, the results do not match those of scipy.linalg.expm for some sampled matrices (symmetric ones in this case). Computing the exponential via jax.scipy.linalg.eigh works fine, as does scaling down the matrix elements, applying jax.scipy.linalg.expm to the scaled matrix, and then raising the result to the corresponding power with jax.numpy.linalg.matrix_power.
This happens only for some of the sampled matrices, and I cannot identify a clear pattern. It might be related to #17756.
In my case, this deviation caused completely wrong results in a simulation I was running, so it has the potential to cause a lot of trouble for anyone who is not aware of the issue.
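For the real symmetric case, the eigendecomposition route mentioned above can be wrapped in a small helper as a possible stopgap. The following is only a minimal sketch: `expm_i_symmetric` and `unitarity_error` are hypothetical names introduced here for illustration, not JAX APIs, and the example matrix is arbitrary. Because exp(iM) is unitary whenever M is real and symmetric, the deviation of U U^H from the identity also offers a sanity check that needs no reference implementation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh


def expm_i_symmetric(m):
    """Return exp(1j * m) via eigendecomposition.

    Hypothetical helper, not part of JAX; assumes m is real and symmetric.
    """
    w, v = eigh(m)  # m = v @ diag(w) @ v.T with orthogonal v
    # Multiplying the columns of v by exp(1j * w) equals v @ diag(exp(1j * w)).
    return (v * jnp.exp(1j * w)) @ v.T


def unitarity_error(u):
    """Summed absolute deviation of u @ u^H from the identity."""
    eye = jnp.eye(u.shape[0], dtype=u.dtype)
    return jnp.sum(jnp.abs(u @ u.conj().T - eye))


# Arbitrary symmetric float32 matrix with entries of order 1e2.
m = jnp.array([[30.0, 55.0], [55.0, 80.0]], dtype=jnp.float32)
u = expm_i_symmetric(m)
print(unitarity_error(u))
```

Unitarity is a necessary condition only, so a small value does not prove the result is correct, but a large value on the output of any exponential routine for such an input indicates a problem.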
Here is code to reproduce the issue:
import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm
from jax.scipy.linalg import eigh as jeigh
from jax.numpy.linalg import matrix_power as jmatrix_power
import numpy as np
from scipy.linalg import expm, eigh

for _ in range(50):
    # Sample a random symmetric float32 matrix with entries of order 1e2.
    M = np.asarray(np.random.random(size=(2, 2)), dtype='float32')
    M_np = (M + M.T) / 2 * 1e2
    M_jnp = jnp.asarray(M_np)

    # Direct matrix exponential in SciPy and in JAX.
    U_np_exp = expm(1.0j * M_np)
    U_jnp_exp = jexpm(1.0j * M_jnp, max_squarings=32)

    # Same quantity via eigendecomposition in SciPy and in JAX.
    w, v = eigh(M_np)
    U_np_eig = np.dot(np.dot(v, np.diag(np.exp(1.0j * w))), v.T)
    w, v = jeigh(M_jnp)
    U_jnp_eig = jnp.dot(jnp.dot(v, jnp.diag(jnp.exp(1.0j * w))), v.T)

    # Exponentiate the matrix scaled down by N, then raise to the N-th power.
    N = 100
    U_jnp_power = jexpm(1.0j * M_jnp / N, max_squarings=32)
    U_jnp_power = jmatrix_power(U_jnp_power, N)

    # Summed absolute deviations between the different results.
    dev_jnp = np.sum(np.abs(U_jnp_exp - U_jnp_eig))
    dev_np = np.sum(np.abs(U_np_exp - U_np_eig))
    dev_eig_eig = np.sum(np.abs(U_np_eig - U_jnp_eig))
    dev_exp_power = np.sum(np.abs(U_np_exp - U_jnp_power))

    # Report only the samples where any deviation is large.
    if np.max([dev_jnp, dev_np, dev_eig_eig, dev_exp_power]) > 1e-2:
        print(dev_jnp, dev_np, dev_eig_eig, dev_exp_power)
        print(M_np)
        print()
Example output:
System info (python version, jaxlib version, accelerator, etc.)