better stack trace for concretization inside scan #57

Open
dlwh opened this issue Jan 17, 2024 · 0 comments

JAX exceptions raised inside scan give no indication of where in the user's code the error actually occurred. It would be better if we could catch these somehow and produce a more useful stack trace.

For example, something like:

scan(self.foo, ...)


def foo(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], layer_idx, *, key):
        k1, k2, k3, k4 = haliax.jax_utils.maybe_rng_split(key, 4)

        attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
        attn_output = self.resid_dropout(attn_output, key=k2)
        x = x + attn_output 

        ff_output = self.mlp(self.ln_2(x), key=k3)
        ff_output = self.resid_dropout(ff_output, key=k4)
        x = x + ff_output

        #import ipdb; ipdb.set_trace()
        if jnp.equal(layer_idx.array, 4):
            #x = x + 0.01*jnp.sin(x*1e2)
            x = x + 0.01*hax.sin(x*1e2)

        return x

produced:


  carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
 File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 83, in wrapped_fn
  carry, y = f(carry, *args, **kwargs)
 File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 124, in scan_compatible_fn
  return fn(carry, *args, **kwargs), None
 File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/jax_utils.py", line 69, in wrapper
  dynamic_out, static_out = checkpointed_fun(static, dynamic)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:357 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[0][0][3][<flat index 0>].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:
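
One possible direction, sketched below under stated assumptions: wrap the function passed to scan so that a concretization error is caught while the user frame is still on the traceback, then re-raised with the file and line inside that function. The names `annotate_trace_errors`, `ScanBodyError`, and `body` are hypothetical and not part of haliax or JAX; the sketch relies only on `jax.errors.ConcretizationTypeError`, `jax.lax.scan`, and standard Python traceback objects.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


class ScanBodyError(Exception):
    """Hypothetical error type that carries the user-code location of a tracing failure."""


def annotate_trace_errors(fn):
    """Wrap fn so concretization errors name the line in fn that triggered them."""
    code = getattr(fn, "__code__", None)  # bound methods forward this to __func__
    name = getattr(fn, "__qualname__", repr(fn))

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except jax.errors.ConcretizationTypeError as e:
            # Walk the traceback and look for the frame that is executing fn itself.
            where = "unknown location"
            tb = e.__traceback__
            while tb is not None:
                if tb.tb_frame.f_code is code:
                    where = f"{code.co_filename}:{tb.tb_lineno}"
                tb = tb.tb_next
            first_line = str(e).split("\n", 1)[0]
            raise ScanBodyError(
                f"{name} failed during tracing at {where}: {first_line}"
            ) from e

    return wrapped


def body(carry, x):
    if x > 0:  # a traced value used as a Python bool, as in the example above
        carry = carry + x
    return carry, None


try:
    lax.scan(annotate_trace_errors(body), jnp.float32(0.0), jnp.arange(3.0))
except ScanBodyError as err:
    message = str(err)
    print(message)
```

In the haliax case the wrapper would presumably have to be applied to the innermost user function (here `self.foo`), before the checkpoint and scan wrappers, so that the frame it looks for is the one containing the offending `if`. Whether this is the right place to hook in is an open design question.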