JAX getting stuck with asynchronous dispatching (CPU backend) #25861

Open
roth-jakob opened this issue Jan 13, 2025 · 12 comments
Labels
bug Something isn't working


@roth-jakob
Contributor

roth-jakob commented Jan 13, 2025

Description

In a larger software project, I noticed that JAX occasionally gets stuck in computations when asynchronous dispatching is enabled for the CPU. Finding a short reproducer is not so easy, but I came up with the script below (which doesn't do any useful computation but should run through). On my machine, but also on the Google Colab CPU runtime, the script almost always hangs.
Note: When running in Colab or IPython, it appears as if the script runs to completion, but adding another input cell below, such as print(post_mean), shows that execution got stuck.

Disabling asynchronous dispatch with jax.config.update('jax_cpu_enable_async_dispatch', False) makes the script run to the end.
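The workaround can be sketched as a standalone program. This assumes, as in the reproducer (where the line sits at the top of the script), that the flag is set once at startup before any JAX computation runs. The jitted step function is only a placeholder:

```python
import jax
import jax.numpy as jnp

# Set once at startup, before any JAX computation runs. This mirrors the
# commented-out line at the top of the reproducer.
jax.config.update('jax_cpu_enable_async_dispatch', False)

@jax.jit
def step(x):
    # Placeholder computation standing in for the real workload.
    return x + 1

y = step(jnp.zeros(3))
print(y)
```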

import jax
from jax import numpy as jnp
from jax import random

# jax.config.update('jax_cpu_enable_async_dispatch', False)

seed = 42
key = random.PRNGKey(seed)
key, subkey = random.split(key)

def cg_pretty_print_it(
    name,
    i,
    maxiter=None,
):
    if maxiter is not None and i == maxiter:
        i_str = f"maxiter: ({i})"
    else:
        i_str = str(i)
    msg = f"{name}: Iteration {i_str}"
    print(msg)

def static_cg(pos):
    from jax.debug import callback
    from jax.lax import while_loop

    name = 'loop'
    maxiter = 100

    def pp(arg):
        cg_pretty_print_it(name, **arg)

    def continue_condition(v):
        return v["info"] < -1

    def cg_single_step(v):
        info = v["info"]
        pos, i = v["pos"], v["iteration"]
        i += 1

        pos = 1 + pos
        info = jnp.where(pos[0,0] > 500, 0, info)
        if name is not None:
            printable_state = {
                "i": i,
                "maxiter": maxiter,
            }
            callback(pp, printable_state)
        ret = {
            "info": info,
            "pos": pos,
            "iteration": i,
        }
        return ret

    val = {
        "info": jnp.array(-2, dtype=int),
        "pos": pos,
        "iteration": jnp.array(0),
    }
    val = while_loop(continue_condition, cg_single_step, val)

    info = val["info"]
    return val["pos"], info


vmap_static_cg = jax.vmap(static_cg, in_axes = 0)

n_samples = 30
inputs = 10*random.normal(subkey, shape=(n_samples, 128, 128))

smpls, smpls_info = vmap_static_cg(inputs)
# smpls, smpls_info = jax.block_until_ready(vmap_static_cg(inputs)) # works

print('done with vmap function')


def mean(forest):
    from functools import reduce

    norm = 1.0 / len(forest)
    add = lambda a,b: a+b
    m = norm * reduce(add, forest)
    return m

post_mean = mean(tuple(s for s in smpls))

Interestingly, the script also runs to completion if you wrap the vmap_static_cg call in jax.block_until_ready(...).
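The block_until_ready pattern can be sketched on a much smaller, hypothetical function (double stands in for static_cg). This only illustrates the call pattern from the commented-out line in the reproducer. It does not reproduce the hang:

```python
from functools import reduce

import jax
import jax.numpy as jnp

def double(x):
    # Hypothetical stand-in for static_cg.
    return 2 * x

batched = jax.vmap(double)
inputs = jnp.arange(6.0).reshape(3, 2)

# Wait for the dispatched computation to finish before the Python-level
# post-processing below issues further JAX operations on its outputs.
outputs = jax.block_until_ready(batched(inputs))

# Same reduction as mean() in the reproducer: sum the samples, then normalize.
post_mean = reduce(lambda a, b: a + b, tuple(outputs)) / len(outputs)
print(post_mean)
```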

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.12.8 | packaged by conda-forge | (main, Dec  5 2024, 14:24:40) [GCC 13.3.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='tux', release='6.12.9-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Fri, 10 Jan 2025 00:39:41 +0000', machine='x86_64')
@roth-jakob roth-jakob added the bug Something isn't working label Jan 13, 2025
@mfschubert

This has a similar "flavor" to the jax hanging issues I reported here: #24255. For example, I notice we are both using callbacks.

Although I identified a workaround in that case, I still regularly encounter JAX hanging. In particular, it occurs in the test suite of the invrs-gym package (https://github.com/invrs-io/gym) with jax 0.4.31 and newer.

A few runs of @roth-jakob's script also suggest that the issue does not occur with jax 0.4.30.

@roth-jakob
Contributor Author

Thanks for the input! Yes, this seems related. As with the code snippet above, the hanging in the example from #24255 seems to come from JAX's asynchronous dispatch: after setting jax.config.update('jax_cpu_enable_async_dispatch', False), the example from #24255 also runs without hanging.

@PhilipVinc
Contributor

I also see regular hangs in our pytest suite running on CPU, which are solved by disabling asynchronous dispatch.
I never managed to identify the root cause, however.

@hawkinsp
Copy link
Collaborator

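Well, what's happening here is that your callback is calling back into JAX to dispatch a JAX computation. It's the comparison i == maxiter in cg_pretty_print_it: inside the callback, i is a JAX array, so == dispatches a new computation.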

If you change that check to compare NumPy values instead (this needs import numpy as np):

    if maxiter is not None and np.asarray(i) == np.asarray(maxiter):

it will not hang.
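A reduced, hypothetical version of the reporter's cg_pretty_print_it that applies this fix might look like the sketch below. It assumes a simple counting while_loop in place of static_cg and appends messages to a list instead of printing, so the result can be checked. jax.effects_barrier() waits for pending callbacks before the list is read:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

messages = []  # host-side log; stands in for print()

def pretty_print_it(name, i, maxiter=None):
    # Convert to NumPy first: comparing NumPy values stays on the host,
    # whereas comparing a jax.Array would dispatch a new JAX computation
    # from inside the callback.
    i = np.asarray(i)
    if maxiter is not None and i == np.asarray(maxiter):
        i_str = f"maxiter: ({int(i)})"
    else:
        i_str = str(int(i))
    messages.append(f"{name}: Iteration {i_str}")

def count_up(n, maxiter):
    def cond(i):
        return i < n

    def body(i):
        i = i + 1
        jax.debug.callback(lambda j: pretty_print_it("loop", j, maxiter), i)
        return i

    return lax.while_loop(cond, body, jnp.array(0))

result = jax.jit(count_up, static_argnums=(0, 1))(5, 5)
jax.effects_barrier()  # make sure all debug callbacks have run
print(int(result), len(messages))
```

Because jax.debug.callback is unordered by default, the messages are not guaranteed to arrive in iteration order.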

Now, I admit this is not great behavior. This should either work or raise a loud error.

@dfm @danielsuo ?

@roth-jakob
Copy link
Contributor Author

Many thanks for the explanation!

An error message would of course be great. Otherwise, a warning in the docstring of the callback not to call back into JAX could also help.

@dfm dfm self-assigned this Jan 17, 2025
@dfm
Collaborator

dfm commented Jan 17, 2025

Thanks for the ping, @hawkinsp! This is definitely on my radar, and I've been poking at this in the background since #24255, but I haven't come up with any great solutions yet.

Adding a comment to the docstrings (and/or the callbacks tutorial) as a temporary solution sounds like a great idea. @roth-jakob, would you be interested in opening a PR to add that?

@roth-jakob
Contributor Author

Yes, I'm happy to add warnings and open a PR!

@PhilipVinc
Contributor

@hawkinsp is there a relatively easy way to identify this kind of error in a large codebase?
Is there an easy way to throw an error if I'm using jnp.asarray or similar functions inside a callback?

@mfschubert

FYI, I am finding that Jax can hang even in cases where callback is entirely avoided. Specifically, it occurs when running the invrs-gym test suite via pytest. So, there may be another issue besides the callback-related one.

@hawkinsp
Collaborator

> FYI, I am finding that Jax can hang even in cases where callback is entirely avoided. Specifically, it occurs when running the invrs-gym test suite via pytest. So, there may be another issue besides the callback-related one.

It would probably be helpful to have a reproducer for that one, ideally the smallest you can. The repro in this issue is great because it's very short and easy to run.

@hawkinsp
Collaborator

> @hawkinsp is there a relatively easy way to identify this kind of error in a large codebase? Is there an easy way to throw an error if I'm using jnp.asarray or similar functions inside a callback?

I think the easiest thing would be if we added an error if you called into the jit or pmap dispatch paths from within a callback. It would then be obvious to you!
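Until such an error exists, one user-side mitigation (a sketch, not a JAX API) is to wrap callbacks in a hypothetical host_only decorator that converts every argument to NumPy before the callback body sees it. Operations on the inputs then run as plain NumPy instead of dispatching JAX computations. It does not stop the body from explicitly calling jax or jnp functions, and it walks dicts, lists, and tuples by hand rather than with jax.tree_util:

```python
import functools

import numpy as np
import jax
import jax.numpy as jnp

def _to_numpy(x):
    # Recursively convert containers of arrays to NumPy without calling
    # back into JAX (no jax.tree_util, no jnp inside the callback).
    if isinstance(x, dict):
        return {k: _to_numpy(v) for k, v in x.items()}
    if isinstance(x, tuple):
        return tuple(_to_numpy(v) for v in x)
    if isinstance(x, list):
        return [_to_numpy(v) for v in x]
    return np.asarray(x)

def host_only(fn):
    """Hypothetical decorator: the wrapped callback only ever sees NumPy."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*_to_numpy(args), **_to_numpy(kwargs))
    return wrapper

seen = []

@host_only
def log_state(state):
    # state["i"] is a NumPy array here, so == stays on the host.
    seen.append((type(state["i"]), bool(state["i"] == state["maxiter"])))

@jax.jit
def step(i):
    jax.debug.callback(log_state, {"i": i, "maxiter": jnp.array(3)})
    return i + 1

out = step(jnp.array(3))
jax.effects_barrier()  # wait for the callback before reading `seen`
print(int(out), seen)
```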

@mfschubert

>> FYI, I am finding that Jax can hang even in cases where callback is entirely avoided. Specifically, it occurs when running the invrs-gym test suite via pytest. So, there may be another issue besides the callback-related one.

> It would probably be helpful to have a reproducer for that one, ideally the smallest you can. The repro in this issue is great because it's very short and easy to run.

Ok, I may have actually spoken too soon. There is a callback in a dependency (agjax) that I overlooked, which at a minimum makes use of jax.tree_util. I suppose this would also cause the issue.

roth-jakob added a commit to roth-jakob/jax that referenced this issue Jan 19, 2025
Callback functions should not call into JAX. This information was
missing in the docs of the callbacks. This commit adds this information
to the docs.

See: jax-ml#25861, jax-ml#24255

5 participants