jax.scipy.linalg.expm yields wrong values (for symmetric matrices) #25987

Open
Ferlemann opened this issue Jan 20, 2025 · 0 comments
Labels
bug Something isn't working

@Ferlemann

Description

When computing the matrix exponential with jax.scipy.linalg.expm, the results differ from those of scipy.linalg.expm for some sampled matrices (in this case, exp(1j*M) for real symmetric M). Computing the same exponential from the eigendecomposition returned by jax.scipy.linalg.eigh works fine, as does scaling the matrix down by a factor N, exponentiating it, and raising the result to the N-th power with jax.numpy.linalg.matrix_power.
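Both alternative routes rest on standard identities. For a real symmetric M = V diag(w) V^T with orthogonal V and real eigenvalues w, and for any square matrix A and positive integer N:

```latex
e^{iM} = V \operatorname{diag}\left(e^{i w_1}, \dots, e^{i w_n}\right) V^{\mathsf{T}},
\qquad
e^{A} = \left(e^{A/N}\right)^{N}
```

The reproduction code applies the first identity with eigh and the second with A = 1j*M and N = 100.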

As mentioned, this happens only for some of the sampled matrices; I cannot identify a clear pattern. This might be related to #17756.

In my case, this deviation caused completely wrong results in a simulation I was running, so it has the potential to cause a lot of trouble for anyone who is not aware of the issue.

Here is code to reproduce the issue:

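```python
import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm
from jax.scipy.linalg import eigh as jeigh
from jax.numpy.linalg import matrix_power as jmatrix_power

import numpy as np
from scipy.linalg import expm, eigh

for _ in range(50):
    # Random symmetric 2x2 matrix with entries of order 1e2 (float32).
    M = np.asarray(np.random.random(size=(2, 2)), dtype='float32')
    M_np = (M + M.T) / 2 * 1e2
    M_jnp = jnp.asarray(M_np)

    # Direct matrix exponential of 1j*M with SciPy and with JAX.
    U_np_exp = expm(1.0j * M_np)
    U_jnp_exp = jexpm(1.0j * M_jnp, max_squarings=32)

    # Reference via eigendecomposition: exp(1j*M) = V diag(exp(1j*w)) V^T.
    w, v = eigh(M_np)
    U_np_eig = np.dot(np.dot(v, np.diag(np.exp(1.0j * w))), v.T)

    w, v = jeigh(M_jnp)
    U_jnp_eig = jnp.dot(jnp.dot(v, jnp.diag(jnp.exp(1.0j * w))), v.T)

    # Reference via scaling: exp(1j*M) = (exp(1j*M/N))^N.
    N = 100

    U_jnp_power = jexpm(1.0j * M_jnp / N, max_squarings=32)
    U_jnp_power = jmatrix_power(U_jnp_power, N)

    # Summed absolute deviations between the different routes.
    dev_jnp = np.sum(np.abs(U_jnp_exp - U_jnp_eig))         # JAX expm vs JAX eigh
    dev_np = np.sum(np.abs(U_np_exp - U_np_eig))            # SciPy expm vs SciPy eigh
    dev_eig_eig = np.sum(np.abs(U_np_eig - U_jnp_eig))      # SciPy eigh vs JAX eigh
    dev_exp_power = np.sum(np.abs(U_np_exp - U_jnp_power))  # SciPy expm vs JAX scaled power

    # Print only the matrices for which some route disagrees by more than 1e-2.
    if np.max([dev_jnp, dev_np, dev_eig_eig, dev_exp_power]) > 1e-2:
        print(dev_jnp, dev_np, dev_eig_eig, dev_exp_power)
        print(M_np)
        print()
```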
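Since the eigh route gives correct results, one possible workaround for the Hermitian case is to compute exp(1j*H) from the eigendecomposition directly. The following is a minimal sketch, not part of JAX's API: the helper name `expm_i_hermitian` is hypothetical, and the function assumes its input is Hermitian (for example real symmetric), as in this report.

```python
import jax.numpy as jnp


def expm_i_hermitian(H):
    """Hypothetical helper: exp(1j * H) for a Hermitian matrix H.

    Assumes H is Hermitian, so H = V diag(w) V^H with real w and unitary V,
    and exp(1j * H) = V diag(exp(1j * w)) V^H. Not valid for general matrices.
    """
    w, v = jnp.linalg.eigh(H)         # real eigenvalues, unitary eigenvectors
    phases = jnp.exp(1j * w)          # exponentiate each eigenvalue
    return (v * phases) @ v.conj().T  # same as v @ diag(phases) @ v^H


# Check against a closed form: for X = [[0, 1], [1, 0]],
# exp(1j * t * X) = cos(t) I + 1j sin(t) X.
t = 2.0
X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
U = expm_i_hermitian(t * X)
expected = jnp.cos(t) * jnp.eye(2) + 1j * jnp.sin(t) * X
print(jnp.max(jnp.abs(U - expected)))
```

Because only the eigenvalue phases are exponentiated, the result stays unitary up to rounding regardless of the matrix norm. For non-Hermitian matrices this shortcut does not apply, and jax.scipy.linalg.expm remains the general tool.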

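Example output (each entry prints dev_jnp, dev_np, dev_eig_eig and dev_exp_power, followed by the matrix M_np that triggered it):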

```
0.022798402 1.0083479e-05 7.1949174e-08 1.24529515e-05
[[75.79973  49.756954]
 [49.756954 54.81395 ]]

0.020133033 8.905516e-06 2.9802322e-08 1.4769792e-05
[[48.398285 71.38826 ]
 [71.38826  38.422195]]

0.028816395 1.3875822e-05 7.291439e-08 2.0913027e-05
[[56.6334   61.245773]
 [61.245773 57.220978]]

0.02039617 1.4980869e-05 2.4110586e-08 1.9685613e-05
[[36.85694  68.622444]
 [68.622444 54.947674]]

0.0100640785 1.0231617e-05 2.9802322e-08 2.2453856e-05
[[30.75979 71.52322]
 [71.52322 43.80476]]
```
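Because the loop samples random matrices, a deterministic variant can be handy when bisecting or testing a fix. The sketch below reuses the first matrix from the example output (values as printed) and compares JAX's expm with the eigh route for that single case. In the report this matrix gave dev_jnp = 0.022798402 with jax 0.4.38 on CPU; other versions or backends may print a different value, so no expected output is claimed here.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm, eigh as jeigh

# First matrix from the example output (float32, symmetric).
M = jnp.asarray([[75.79973, 49.756954],
                 [49.756954, 54.81395]], dtype=jnp.float32)

# Direct route: jax.scipy.linalg.expm on 1j*M, as in the reproduction loop.
U_exp = jexpm(1j * M, max_squarings=32)

# Reference route: V diag(exp(1j*w)) V^T from the symmetric eigendecomposition.
w, v = jeigh(M)
U_eig = (v * jnp.exp(1j * w)) @ v.T

# Same summed absolute deviation as dev_jnp in the loop.
dev = float(np.sum(np.abs(np.asarray(U_exp - U_eig))))
print(dev)
```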

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.38
jaxlib: 0.4.38
numpy:  1.26.4
python: 3.10.10 (tags/v3.10.10:aad5f6a, Feb  7 2023, 17:20:36) [MSC v.1929 64 bit (AMD64)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Windows', node='Dell-Latitude-7530', release='10', version='10.0.22631', machine='AMD64')
```
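The system information above looks like the output of `jax.print_environment_info()`. Assuming so, the same summary can be collected on another machine, for example when checking whether the deviation depends on the JAX version or platform:

```python
import jax

# Return the version/platform summary as a string instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```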
Ferlemann added the bug label (Something isn't working) on Jan 20, 2025.