JAX exceptions raised inside scan give no indication of where the actual error occurred. It would be better if we could catch them and produce a more useful stack trace.
Something like:
scan(self.foo, ...)

def foo(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], layer_idx, *, key):
    k1, k2, k3, k4 = haliax.jax_utils.maybe_rng_split(key, 4)
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
    attn_output = self.resid_dropout(attn_output, key=k2)
    x = x + attn_output
    ff_output = self.mlp(self.ln_2(x), key=k3)
    ff_output = self.resid_dropout(ff_output, key=k4)
    x = x + ff_output
    #import ipdb; ipdb.set_trace()
    if jnp.equal(layer_idx.array, 4):
        #x = x + 0.01*jnp.sin(x*1e2)
        x = x + 0.01*hax.sin(x*1e2)
    return x
produced:
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 83, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 124, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/jax_utils.py", line 69, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:357 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[0][0][3][<flat index 0>].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
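For reference, the error itself comes from the Python `if` on `layer_idx`, which is a traced value inside the scanned, checkpointed function. Below is a minimal sketch of the same pattern, assuming plain `jax.lax.scan` and `jax.checkpoint` instead of the haliax wrappers; `bad_layer`, `good_layer`, and `run` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bad_layer(x, layer_idx):
    # Python `if` on a traced value: boolean conversion fails during tracing.
    if jnp.equal(layer_idx, 4):
        x = x + 0.01 * jnp.sin(x * 1e2)
    return x


def good_layer(x, layer_idx):
    # Trace-safe variant: compute the perturbed value and select with jnp.where.
    perturbed = x + 0.01 * jnp.sin(x * 1e2)
    return jnp.where(jnp.equal(layer_idx, 4), perturbed, x)


def run(layer, x, num_layers=6):
    # Hypothetical stand-in for the haliax scan: scan over layer indices,
    # rematerializing each layer with jax.checkpoint.
    def step(carry, layer_idx):
        return jax.checkpoint(layer)(carry, layer_idx), None

    out, _ = lax.scan(step, x, jnp.arange(num_layers))
    return out


x = jnp.ones((3,))

caught = None
try:
    run(bad_layer, x)
except jax.errors.ConcretizationTypeError as e:
    # Same error family as in the report above.
    caught = e

out = run(good_layer, x)
```

When the branches are expensive, `jax.lax.cond` may be a better fit than `jnp.where`, since `jnp.where` evaluates both sides.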