Tracer escaping in linalg.solve with ensure_compile_time_eval as of jax 0.4.36
#25847
Labels: bug
Description
I am seeing an unexpected JAX tracer escape when using jax.linalg.solve inside the jax.ensure_compile_time_eval context manager. This seems to occur for jax >= 0.4.36. Below is a simple reproduction. This gives the following error:
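For context, jax.ensure_compile_time_eval is meant to let operations on concrete (non-traced) values run eagerly while a function is being traced under jax.jit, so their results are ordinary concrete arrays that can even drive Python control flow. The sketch below illustrates that intended behavior with simple elementwise ops. It is not the reproduction from this report, it does not call the solve function, and the function name f is purely illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Inside this block, ops whose inputs are all concrete (not traced) are
    # evaluated eagerly at trace time instead of being staged into the jaxpr.
    with jax.ensure_compile_time_eval():
        y = jnp.sin(3.0)    # concrete value, computed while tracing
        y_positive = y > 0  # concrete boolean array, not a tracer
    # Because y_positive is concrete, Python control flow on it works under jit.
    if y_positive:
        return x + y  # x is traced, so this addition is staged as usual
    return x - y


result = f(jnp.float32(1.0))
print(result)
```

The report describes a linalg.solve call placed inside a block like this one leaking a tracer on jax >= 0.4.36, rather than producing a concrete result.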
I tried using the jax.checking_leaks context manager, but it does not yield any additional info.
System info (python version, jaxlib version, accelerator, etc.)