jax.nn.dot_product_attention CuDNN implementation raises tensor stride error during jit compile #25986
Comments
Do you have a complete reproducer? I believe the problem here might be that XLA uses an incompatible layout for some intermediate value. If that's the case, the behavior will probably depend on what's around the cudnn call.
@jreiffers, unfortunately, I can't seem to reproduce this error with a minimal example of our model, which is one of the reasons it is quite hard to debug. It seems to only occur when used with the rest of our training code (which is part of a fairly large codebase). That said, regarding the layout, I have an XLA dump containing some of the HLO from the "before-optimisation" pass. I assume there are no other passes because the compilation fails at this point. I think the layout for the q, k and v tensors seems okay here. The inner two dimensions are swapped in the layout; however, the embedding dimension that the error message complains about seems like it should have a stride of 1 here (unless I am reading something wrong).
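(As a quick sanity check on that reading, here is a tiny stride calculator — my own illustration, not something taken from the dump — showing that a layout with the two inner dimensions swapped still gives the embedding dimension a stride of 1 for a `(8, 2048, 40, 128)` tensor.)

```python
# Illustrative helper (not from the report): element strides implied by an
# XLA minor-to-major layout for a given shape.
def strides_from_layout(shape, minor_to_major):
    strides = [0] * len(shape)
    running = 1
    for dim in minor_to_major:      # walk from the minor-most dimension outwards
        strides[dim] = running
        running *= shape[dim]
    return strides

shape = (8, 2048, 40, 128)          # (batch, seq, heads, head_dim)
# Layout with the two inner dimensions swapped, as described above:
print(strides_from_layout(shape, minor_to_major=(3, 1, 2, 0)))
# -> [10485760, 128, 262144, 1]: the embedding dimension still has stride 1.
```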
Did you run this just with
No, I hadn't initially run with the regex flag - sorry about that! I have done so now, and I am seeing many more passes in the XLA dump. I have attached the final pass below. Does this imply that the error occurs within the fusion-block-level-rewriter? I have extracted the relevant lines around the CuDNN calls from the attached file here:
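(For anyone trying to reproduce the dump: the flags below are my assumption of what was used, not copied from the report — `--xla_dump_hlo_pass_re=.*` is the regex flag that makes XLA dump HLO around every pass rather than only the pre-optimisation module.)

```python
# Assumed XLA dump flags (not from the original report); these must be set
# before JAX initialises its backends.
import os
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "
    "--xla_dump_hlo_pass_re=.*"     # dump HLO around every matching pass
)

import jax  # imported only after XLA_FLAGS is set
```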
Description
I am currently experiencing an issue where I get a CuDNN error relating to the stride of my K matrix when using `jax.nn.dot_product_attention` within a flax model. This occurs when jitting, and the error stems from the CuDNN dimension checks here. I am not sure what exactly is causing the striding issue with the `k` tensor, and I have checked the shapes and sharding of the inputs; however, I am struggling to find a way to debug this issue further.

When the `implementation` argument is set to `'xla'`, the model jits and I am able to train with it.

The shapes for `q`, `k` and `v` are all `(8, 2048, 40, 128)`, and all are sharded along the first (batch) dimension with the following sharding: `NamedSharding(mesh=Mesh('dp': 1, 'fsdp': 8), spec=PartitionSpec('fsdp',), memory_kind=device)`.

The function is called as below:
This gives the following error:
If there are any ways to further debug the striding of the underlying tensors, and, if possible, to force a contiguous layout that matches the shape of my tensors, please let me know.
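One debugging idea (my suggestion, not something already tried in the thread) is to dump the HLO for the jitted function directly: the pre-optimisation text is available even when compilation fails, and compiling the working `implementation='xla'` variant shows which layouts XLA actually picks for the same computation (e.g. `bf16[8,2048,40,128]{3,1,2,0}`). A rough sketch, reusing the hypothetical names from the call sketch above:

```python
# Debugging sketch (reuses the hypothetical attn, q, k, v from the call sketch).
lowered = attn.lower(q, k, v)

# Pre-optimisation HLO/StableHLO: available even when compilation fails.
print(lowered.as_text())

# Compiling is where the reported cuDNN stride error would surface; for the
# working implementation='xla' variant this succeeds and the optimised HLO
# shows the layouts XLA chose for each tensor.
compiled = lowered.compile()
print(compiled.as_text())
```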
System info (python version, jaxlib version, accelerator, etc.)