JAX getting stuck with asynchronous dispatching (CPU backend) #25861
Comments
This has a similar "flavor" to the JAX hanging issues I reported here: #24255. For example, I notice we are both using callbacks. Although I identified a workaround in that case, I still regularly encounter JAX hanging. In particular, it occurs in the test suite for the invrs-gym package (https://github.com/invrs-io/gym) with jax versions 0.4.31 and newer. A few runs of @roth-jakob's script also suggest that the issue does not occur with jax 0.4.30, at least.
Thanks for the input! Yes, this seems related. As with the code snippet above, the hangs in the example from #24255 also seem to come from JAX asynchronous dispatch. By setting
I also see regular hangs in our pytest suite running on CPU, which are solved by disabling asynchronous dispatch.
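For a pytest suite, the workaround could look roughly like the sketch below. The flag name is the one quoted in the issue description; placing the call in a conftest.py is only an assumption about project layout, and test_example is a hypothetical test added to show that the suite still runs. Setting the flag before any JAX computation has run is probably the safest choice.

```python
# conftest.py (hypothetical location; any module that runs before the
# first JAX computation should work as well)
import jax

# Flag name as quoted in the issue description. With async dispatch
# disabled, CPU computations are dispatched synchronously.
jax.config.update("jax_cpu_enable_async_dispatch", False)

import jax.numpy as jnp


def test_example():
    # Hypothetical test: a plain JAX computation still works as usual.
    y = jnp.sum(jnp.arange(10))
    assert int(y) == 45
```

This trades away the overlap between Python and CPU execution that asynchronous dispatch provides, so it is a workaround rather than a fix.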
Well, what's happening here is that your callback is calling back into JAX to dispatch a JAX computation. It's this one: If you change your code to do this:
it will not hang. Now I admit this is not great behavior: this should either work or give a loud error.
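As an illustration of that advice (not the reporter's script, and not the exact change suggested above), here is a minimal sketch of a callback that stays in NumPy. The names host_norm and f are hypothetical; the point is only that the function passed to jax.pure_callback does its work with NumPy and never calls jnp or a jitted function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_norm(x):
    # Runs on the host inside the callback. Only NumPy is used here, so
    # no new JAX computation is dispatched from within the callback.
    x = np.asarray(x)
    return np.asarray(np.linalg.norm(x), dtype=x.dtype)


@jax.jit
def f(x):
    # Declare the shape and dtype of the callback result up front.
    out_spec = jax.ShapeDtypeStruct((), x.dtype)
    return jax.pure_callback(host_norm, out_spec, x)


x = jnp.arange(4.0)
result = f(x)
```

Writing the same callback with jnp.linalg.norm would call back into JAX, which is the pattern described above as the cause of the hang.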
Many thanks for the explanation! An error message would of course be great. Otherwise, a warning in the docstring of the callback not to call back into JAX could also help.
Thanks for the ping, @hawkinsp! This is definitely on my radar, and I've been poking at this in the background since #24255, but I haven't come up with any great solutions yet. Adding a comment to the docstrings (and/or the callbacks tutorial) as a temporary solution sounds like a great idea. @roth-jakob, would you be interested in opening a PR to add that?
Yes, I'm happy to add warnings and open a PR!
@hawkinsp is there a relatively easy way to identify this kind of error in a large codebase?
FYI, I am finding that JAX can hang even in cases where callbacks are entirely avoided. Specifically, it occurs when running the invrs-gym test suite via pytest. So, there may be another issue besides the callback-related one.
It would probably be helpful to have a reproducer for that one, ideally the smallest you can. The repro in this issue is great because it's very short and easy to run.
I think the easiest thing would be if we added an error if you called into the jit or pmap dispatch paths from within a callback. It would then be obvious to you!
Ok, I may have actually spoken too soon. There is a callback in a dependency that I overlooked (in agjax), which at a minimum is making use of
Callback functions should not call into JAX. This information was missing in the docs of the callbacks. This commit adds this information to the docs. See: jax-ml#25861, jax-ml#24255
Description
In a larger software project, I noticed that JAX occasionally gets stuck in computations when asynchronous dispatching is enabled for the CPU. Finding a short reproducer is not so easy, but I came up with the script below (which doesn't do any useful computation but should run through). On my machine, but also on the Google Colab CPU runtime, the script almost always hangs.
Note: When running in Colab or IPython, it appears as if the script runs through, but adding an additional input cell below, such as
print(post_mean)
shows that the execution got stuck.
Disabling asynchronous dispatch with
jax.config.update('jax_cpu_enable_async_dispatch', False)
makes the script run to the end.
Interestingly, the script will also run if you wrap the vmap_static_cg call in a jax.block_until_ready( ).
System info (python version, jaxlib version, accelerator, etc.)