jax.nn.dot_product_attention CuDNN implementation raises tensor stride error during jit compile #25986

Open
liamclarkza opened this issue Jan 20, 2025 · 4 comments
Labels: bug (Something isn't working), NVIDIA GPU (Issues specific to NVIDIA GPUs)

@liamclarkza

Description

I am getting a cuDNN error about the stride of my K matrix when using jax.nn.dot_product_attention within a Flax model. It happens when jitting, and the error comes from the cuDNN dimension checks. I am not sure what exactly is causing the striding issue with the k tensor. I have checked the shapes and sharding of the inputs, but I am struggling to find a way to debug this further.

With the implementation argument set to 'xla', the model jits and I am able to train with it.

The shapes for q, k and v are all (8, 2048, 40, 128) and all are sharded along the first (batch) dimension, having the following sharding:
NamedSharding(mesh=Mesh('dp': 1, 'fsdp': 8), spec=PartitionSpec('fsdp',), memory_kind=device).
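A quick way to sanity-check the sharding above is to compute the per-device shard shape by hand. The sketch below is plain Python and does not call JAX; the helper `shard_shape` is hypothetical, and it assumes every sharded dimension divides evenly by the product of the mesh axis sizes it is split over. As with `PartitionSpec`, dimensions without an entry are treated as replicated.

```python
def shard_shape(global_shape, spec, mesh_axes):
    """Return the per-device shape of `global_shape` under `spec` (hypothetical helper).

    spec: one entry per leading dimension: a mesh axis name, a tuple of axis
          names, or None (replicated). Missing trailing entries mean None.
    mesh_axes: mapping of mesh axis name -> axis size.
    """
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    local = []
    for dim, entry in zip(global_shape, spec):
        if entry is None:
            local.append(dim)  # replicated: full size on every device
            continue
        names = entry if isinstance(entry, tuple) else (entry,)
        num_shards = 1
        for name in names:
            num_shards *= mesh_axes[name]
        if dim % num_shards != 0:
            raise ValueError(f"dimension {dim} is not divisible by {num_shards} shards")
        local.append(dim // num_shards)
    return tuple(local)


# The configuration described above: batch split over the 8-way 'fsdp' axis.
print(shard_shape((8, 2048, 40, 128), ("fsdp",), {"dp": 1, "fsdp": 8}))
# (1, 2048, 40, 128)
```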

The function is called as below:

import jax
import jax.numpy as jnp

# q, k, v: (8, 2048, 40, 128), sharded along the batch dimension
jax.nn.dot_product_attention(
    q.astype(jnp.bfloat16),
    k.astype(jnp.bfloat16),
    v.astype(jnp.bfloat16),
    mask=None, # I have tested with/without masking but get the same error either way
    scale=float(q.shape[-1] ** -0.5),
    implementation='cudnn',
)

This gives the following error:

*** truncated ***
  File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 427, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 655, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 273, in backend_compile
    raise e
  File "/app/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: The stride for the last dimension corresponding to the embedding size per head should be 1 for input_names::K
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(8221): 'graph_.build_operation_graph(cudnn->handle())'

Please let me know if there is a way to debug the striding of the underlying tensors further and, if possible, how to force a contiguous layout that matches the shape of my tensors.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.0.2
python: 3.11.11 (main, Jan 14 2025, 22:49:08) [Clang 19.1.6 ]
device info: NVIDIA H100 80GB HBM3-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='experiment-2eb4a7d7-dad7-head', release='6.8.0-49-generic', version='#49~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Nov  6 17:42:15 UTC 2', machine='x86_64')


$ nvidia-smi
Mon Jan 20 11:16:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
| N/A   38C    P0            120W /  700W |     550MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
| N/A   36C    P0            118W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   33C    P0            115W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
| N/A   36C    P0            123W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   38C    P0            118W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
| N/A   36C    P0            114W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
| N/A   37C    P0            123W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
| N/A   34C    P0            117W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    1   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    2   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    3   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    4   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    5   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    6   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
|    7   N/A  N/A     23980      C   /app/.venv/bin/python                           0MiB |
+-----------------------------------------------------------------------------------------+
@liamclarkza added the bug label on Jan 20, 2025
@jreiffers
Contributor

Do you have a complete reproducer? I believe the problem here might be that XLA uses an incompatible layout for some intermediate value. If that's the case, the behavior will probably depend on what's around the cudnn call.

@liamclarkza
Author

@jreiffers, unfortunately I can't reproduce this error with a minimal example of our model, which is one of the reasons it is so hard to debug. It only seems to occur together with the rest of our training code (which is part of a fairly large codebase).

That said, regarding the layout, I have an XLA dump containing the HLO from before optimisations (the "before_optimizations" dump). I assume there are no dumps from later passes because compilation fails at this point.

The layouts of the q, k and v tensors look okay to me here. In the layout {3,1,2,0}, the two middle dimensions (sequence and heads) are swapped; however, the embedding dimension that the error message complains about is still listed first (minor-most), so it should have a stride of 1 (unless I am reading something wrong).
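To make this layout reading concrete: XLA prints a layout as a minor_to_major list, so the first entry is the fastest-varying dimension. The sketch below is plain Python. `strides_from_layout` is a hypothetical helper that assumes a dense layout with no tiling or padding, and it converts the list into element strides for the two layouts that appear in this dump.

```python
def strides_from_layout(shape, minor_to_major):
    """Element strides for a dense XLA layout given as minor_to_major (hypothetical helper)."""
    strides = [0] * len(shape)
    step = 1
    for dim in minor_to_major:  # fastest-varying dimension first
        strides[dim] = step
        step *= shape[dim]
    return strides


shape = [32, 2048, 40, 128]  # [batch, seq, heads, head_dim] as in the dump
for layout in ([3, 1, 2, 0], [3, 2, 1, 0]):
    strides = strides_from_layout(shape, layout)
    # The check in the error message: the last (head) dim must have stride 1.
    print(layout, strides, "head_dim stride == 1:", strides[-1] == 1)
# [3, 1, 2, 0] [10485760, 128, 262144, 1] head_dim stride == 1: True
# [3, 2, 1, 0] [10485760, 5120, 128, 1] head_dim stride == 1: True
```

Under these assumptions, the head dimension has stride 1 in both layouts, which matches the reading above. Layouts in a pre-optimisation dump are not necessarily the final ones, though, because XLA's layout assignment runs later in the compilation pipeline.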

# Forward pass

  # q, k, v tensors:
  convert.2768 = bf16[32,2048,40,128]{3,1,2,0} convert(transpose.2767), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/convert_element_type" source_file="/app/waffle/_src/models/esm2.py" source_line=188}
  convert.2803 = bf16[32,2048,40,128]{3,1,2,0} convert(transpose.2802), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/convert_element_type" source_file="/app/waffle/_src/models/esm2.py" source_line=188}
  call.2817 = bf16[1,1,2048,2048]{3,2,1,0} call(slice.2816, Arg_8.2611, Arg_9.2612), to_apply=_where_17.2589
  # mask:
  reshape.2814 = pred[32,1,2048]{2,1,0} reshape(broadcast.2813), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/broadcast_in_dim" source_file="/app/waffle/_src/models/esm2.py" source_line=350}
  # not sure what these args are:
  Arg_10.2613 = bf16[0]{0} parameter(10), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/closed_call"}
  Arg_11.2614 = bf16[0]{0} parameter(11), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/closed_call"}
  # attention call:
  custom-call.2818 = (bf16[32,2048,40,128]{3,2,1,0}, f32[32,40,2048]{2,1,0}) custom-call(convert.2768, convert.2803, reshape.2810, call.2817, Arg_10.2613, /*index=5*/Arg_11.2614), custom_call_target="CustomSPMDPartitioning", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}, backend_config="137314937879632"
  # outputs:
  get-tuple-element.2819 = bf16[32,2048,40,128]{3,2,1,0} get-tuple-element(custom-call.2818), index=0, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}
  get-tuple-element.2820 = f32[32,40,2048]{2,1,0} get-tuple-element(custom-call.2818), index=1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}


# Rematerialised forward for backward

  # q, k, v tensors:
  convert.4080 = bf16[32,2048,40,128]{3,1,2,0} convert(transpose.4079), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/convert_element_type" source_file="/app/waffle/_src/models/esm2.py" source_line=188}
  convert.4130 = bf16[32,2048,40,128]{3,1,2,0} convert(transpose.4129), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/convert_element_type" source_file="/app/waffle/_src/models/esm2.py" source_line=188}
  reshape.4030 = bf16[32,2048,40,128]{3,2,1,0} reshape(add.4029), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/project_value/reshape" source_file="/app/waffle/_src/models/esm2.py" source_line=266}
  # mask:
  call.4137 = bf16[1,1,2048,2048]{3,2,1,0} call(slice.4136, constant.3947, constant.3946), to_apply=_where_27.3755
  # not sure what these args are:
  constant.3867 = bf16[0]{0} constant({})
  # attention call:
  custom-call.4138 = (bf16[32,2048,40,128]{3,2,1,0}, f32[32,40,2048]{2,1,0}) custom-call(convert.4080, convert.4130, reshape.4030, call.4137, constant.3867, /*index=5*/constant.3867), custom_call_target="CustomSPMDPartitioning", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}, backend_config="137314937867472"
  # outputs:
  get-tuple-element.4139 = bf16[32,2048,40,128]{3,2,1,0} get-tuple-element(custom-call.4138), index=0, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}
  get-tuple-element.4140 = f32[32,40,2048]{2,1,0} get-tuple-element(custom-call.4138), index=1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}


 # Backward

  # q, k, v, mask etc. all same as above
  # attention call:
  custom-call.4327 = (bf16[32,2048,40,128]{3,2,1,0}, bf16[32,2048,40,128]{3,2,1,0}, bf16[32,2048,40,128]{3,2,1,0}) custom-call(convert.4080, convert.4130, reshape.4030, call.4137, constant.3867, /*index=5*/constant.3867, get-tuple-element.4140, get-tuple-element.4139, convert.4326), custom_call_target="CustomSPMDPartitioning", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}, backend_config="137314937656976"
  # outputs:
  get-tuple-element.4328 = bf16[32,2048,40,128]{3,2,1,0} get-tuple-element(custom-call.4327), index=0, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}
  get-tuple-element.4329 = bf16[32,2048,40,128]{3,2,1,0} get-tuple-element(custom-call.4327), index=1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}
  get-tuple-element.4330 = bf16[32,2048,40,128]{3,2,1,0} get-tuple-element(custom-call.4327), index=2, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}

@mjsML added the NVIDIA GPU label on Jan 22, 2025
@jreiffers
Contributor

Did you run this just with --xla_dump_to, or also with --xla_dump_hlo_pass_re? The latter should be set to .* to get dumps after every pass.
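For reference, both dump flags are normally passed through the XLA_FLAGS environment variable. The following minimal sketch assumes the flags are set from Python before JAX initialises its backend; setting them before `import jax` is the safe choice. The helper name `enable_xla_dumps` and the dump directory are placeholders.

```python
import os


def enable_xla_dumps(dump_dir, pass_regex=".*"):
    """Append XLA HLO dump flags to XLA_FLAGS and return the new value.

    Call this before JAX initialises its backend (i.e. before `import jax`).
    """
    flags = [f"--xla_dump_to={dump_dir}", f"--xla_dump_hlo_pass_re={pass_regex}"]
    existing = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = " ".join([existing, *flags]).strip()
    return os.environ["XLA_FLAGS"]


# Placeholder directory; dumping after every pass can produce a lot of files.
print(enable_xla_dumps("/tmp/xla_dump"))
# import jax  # only after XLA_FLAGS has been set
```

The shell equivalent prefixes the launch command, e.g. `XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*" python train.py`, where `train.py` stands in for the real entry point.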

@liamclarkza
Author

liamclarkza commented Jan 22, 2025

No, I hadn't initially run with the regex flag, sorry about that! I have done so now and see many more passes in the XLA dump. I have attached the dump from the final pass below.

module_0027.jit__unnamed_wrapped_function_.0176.fusion-dispatch-pipeline.after_pipeline-start.before_fusion-block-level-rewriter.txt

Does this imply that the error occurs within the fusion-block-level-rewriter?
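The attached file name appears to encode where in the pipeline the dump was taken. The sketch below assumes the pattern `module_<id>.<name>.<counter>.<pipeline>.after_<pass>.before_<pass>.txt`, which is inferred only from this one file name, and it picks the per-pass dump with the highest counter. `last_pass_dump` is a hypothetical helper. Note that the last dump written only marks the last point where XLA dumped HLO. It suggests, but does not prove, which pass was running when the error was raised.

```python
import re

# Assumed naming pattern, inferred from the attached dump file name.
DUMP_NAME = re.compile(
    r"module_(?P<module>\d+)\.(?P<name>[^.]+)\.(?P<counter>\d+)\."
    r"(?P<pipeline>[^.]+)\.after_(?P<after>[^.]+)\.before_(?P<before>[^.]+)\.txt$"
)


def last_pass_dump(filenames, module=None):
    """Return the parsed name parts of the highest-counter per-pass dump, or None.

    module: optional module id string such as "0027" to restrict the search.
    """
    best = None
    for filename in filenames:
        m = DUMP_NAME.match(filename)
        if m is None or (module is not None and m["module"] != module):
            continue  # not a per-pass dump, or a different module
        info = m.groupdict()
        info["counter"] = int(info["counter"])
        if best is None or info["counter"] > best["counter"]:
            best = info
    return best


# In practice the list would come from os.listdir() on the dump directory.
example = [
    "module_0027.jit__unnamed_wrapped_function_.0176.fusion-dispatch-pipeline"
    ".after_pipeline-start.before_fusion-block-level-rewriter.txt"
]
info = last_pass_dump(example, module="0027")
print(info["pipeline"], info["after"], info["before"])
# fusion-dispatch-pipeline pipeline-start fusion-block-level-rewriter
```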

I have extracted the relevant lines around the cuDNN calls from the attached file here:

Call 1:

input_concatenate_fusion = bf16[8,2048,5,128]{3,2,1,0} fusion(get-tuple-element.579, get-tuple-element.580, get-tuple-element.520.0, get-tuple-element.109.0), kind=kInput, calls=fused_concatenate, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/vmap(vmap(rotary_pos_emb._apply_rope))/vmap(rotary_pos_emb._apply_rope_1d)/concatenate" source_file="/app/waffle/_src/models/esm2.py" source_line=153 deduplicated_name="input_concatenate_fusion"}

input_concatenate_fusion.1 = bf16[8,2048,5,128]{3,2,1,0} fusion(get-tuple-element.579, get-tuple-element.580, get-tuple-element.521.0, get-tuple-element.111.0), kind=kInput, calls=fused_concatenate.1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/vmap(vmap(rotary_pos_emb._apply_rope))/vmap(rotary_pos_emb._apply_rope_1d)/concatenate" source_file="/app/waffle/_src/models/esm2.py" source_line=153 deduplicated_name="input_concatenate_fusion"}

bitcast.22772.0 = bf16[8,2048,5,128]{3,2,1,0} bitcast(loop_add_fusion.2)

loop_broadcast_fusion = bf16[1,1,2048,2048]{3,2,1,0} fusion(all-reduce-done), kind=kLoop, calls=fused_broadcast, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/broadcast_in_dim" source_file="/app/waffle/_src/models/esm2.py" source_line=350}

custom-call.3.0 = (bf16[8,5,2048,128]{3,1,2,0}, f32[8,5,2048]{2,1,0}, u8[0]{0}) custom-call(input_concatenate_fusion, input_concatenate_fusion.1, bitcast.22772.0, loop_broadcast_fusion), custom_call_target="__cudnn$fmhaScaleBiasSoftmax", operand_layout_constraints={bf16[8,2048,5,128]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}, bf16[1,1,2048,2048]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.08838834764831845, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["8", "5", "2048", "2048"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0.0, "seed": 42, "sliding_window_length": 0}}
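You can check two details of this backend config by hand. The plain-Python sketch below confirms that `fmha_scale` equals `128 ** -0.5`, which is the scale passed in the original call. It then labels the dimensions of the Q operand using the bmm1 dot dimension numbers. `label_dims` is a hypothetical helper, and it assumes the two batch dimensions are listed as (batch, heads).

```python
import math

# Values copied from the backend_config of custom-call.3.0 above.
fmha_scale = 0.08838834764831845
head_dim = 128
bmm1_lhs_batch_dims = [0, 2]
bmm1_lhs_contracting_dims = [3]

# The scale in the config matches scale=float(q.shape[-1] ** -0.5) from the call.
print(math.isclose(fmha_scale, head_dim ** -0.5))  # True


def label_dims(rank, batch_dims, contracting_dims):
    """Label dims as B/N/H/S, assuming batch_dims is ordered (batch, heads)."""
    labels = ["S"] * rank  # whatever is left over is treated as sequence
    labels[batch_dims[0]] = "B"
    labels[batch_dims[1]] = "N"
    for dim in contracting_dims:
        labels[dim] = "H"
    return "".join(labels)


# Q operand: bf16[8,2048,5,128]
print(label_dims(4, bmm1_lhs_batch_dims, bmm1_lhs_contracting_dims))  # BSNH
```

Under that labeling, the operand layout constraint {3,2,1,0} on bf16[8,2048,5,128] places the head dimension (H) in the minor-most position. In other words, the constrained Q/K/V operands have stride 1 on the head dimension.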


Call 2:

input_concatenate_fusion = bf16[8,2048,5,128]{3,2,1,0} fusion(get-tuple-element.579, get-tuple-element.580, get-tuple-element.520.0, get-tuple-element.109.0), kind=kInput, calls=fused_concatenate, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/vmap(vmap(rotary_pos_emb._apply_rope))/vmap(rotary_pos_emb._apply_rope_1d)/concatenate" source_file="/app/waffle/_src/models/esm2.py" source_line=153 deduplicated_name="input_concatenate_fusion"}

input_concatenate_fusion.1 = bf16[8,2048,5,128]{3,2,1,0} fusion(get-tuple-element.579, get-tuple-element.580, get-tuple-element.521.0, get-tuple-element.111.0), kind=kInput, calls=fused_concatenate.1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/vmap(vmap(rotary_pos_emb._apply_rope))/vmap(rotary_pos_emb._apply_rope_1d)/concatenate" source_file="/app/waffle/_src/models/esm2.py" source_line=153 deduplicated_name="input_concatenate_fusion"}

bitcast.22772.0 = bf16[8,2048,5,128]{3,2,1,0} bitcast(loop_add_fusion.2)

get-tuple-element.126.0 = f32[8,5,2048]{2,1,0} get-tuple-element(custom-call.3.0), index=1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/rematted_computation/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}

bitcast.22788.0 = bf16[8,2048,40,128]{3,2,1,0} bitcast(all-reduce-done.1)

bitcast.22794.0 = bf16[8,2048,5,128]{3,2,1,0} bitcast(get-tuple-element.116.0)

custom-call.10.0 = (bf16[8,5,2048,128]{3,1,2,0}, bf16[8,5,2048,128]{3,1,2,0}, bf16[8,5,2048,128]{3,1,2,0}, u8[0]{0}) custom-call(input_concatenate_fusion, input_concatenate_fusion.1, bitcast.22772.0, get-tuple-element.126.0, bitcast.22788.0, /*index=5*/loop_broadcast_fusion, bitcast.22794.0), custom_call_target="__cudnn$fmhaScaleBiasSoftmaxBackward", operand_layout_constraints={bf16[8,2048,5,128]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}, f32[8,5,2048]{2,1,0}, bf16[8,2048,40,128]{3,2,1,0}, bf16[1,1,2048,2048]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(ESM2Model))/while/body/checkpoint/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.08838834764831845, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["8", "5", "2048", "2048"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0.0, "seed": 42, "sliding_window_length": 0}}


Call 3:

input_concatenate_fusion.4 = bf16[8,2048,5,128]{3,2,1,0} fusion(get-tuple-element.967, get-tuple-element.966, get-tuple-element.537.0, get-tuple-element.962, copy.253), kind=kInput, calls=fused_concatenate.4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/vmap(vmap(rotary_pos_emb._apply_rope))/vmap(rotary_pos_emb._apply_rope_1d)/concatenate" source_file="/app/waffle/_src/models/esm2.py" source_line=153 deduplicated_name="input_concatenate_fusion.4"}

input_concatenate_fusion.5 = bf16[8,2048,5,128]{3,2,1,0} fusion(get-tuple-element.967, get-tuple-element.966, get-tuple-element.538.0, get-tuple-element.958, copy.253), kind=kInput, calls=fused_concatenate.5, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/rotary_pos_emb/vmap(vmap(rotary_pos_emb._apply_rope))/vmap(rotary_pos_emb._apply_rope_1d)/concatenate" source_file="/app/waffle/_src/models/esm2.py" source_line=153 deduplicated_name="input_concatenate_fusion.4"}

bitcast.23551.0 = bf16[8,2048,5,128]{3,2,1,0} bitcast(loop_add_fusion.5)

loop_select_fusion.1 = bf16[1,1,2048,2048]{3,2,1,0} fusion(get-tuple-element.969, get-tuple-element.968, all-reduce-done.6), kind=kLoop, calls=fused_select.1, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/jit(_where)/select_n" source_file="/app/waffle/_src/models/esm2.py" source_line=351}

custom-call.5.0 = (bf16[8,5,2048,128]{3,1,2,0}, f32[8,5,2048]{2,1,0}, u8[0]{0}) custom-call(input_concatenate_fusion.4, input_concatenate_fusion.5, bitcast.23551.0, loop_select_fusion.1), custom_call_target="__cudnn$fmhaScaleBiasSoftmax", operand_layout_constraints={bf16[8,2048,5,128]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}, bf16[8,2048,5,128]{3,2,1,0}, bf16[1,1,2048,2048]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jvp(ESM2Model)/while/body/_layers/transformer_block/self_attention/attention_fn/custom_partitioning" source_file="/app/waffle/_src/models/esm2.py" source_line=351}, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.08838834764831845, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["8", "5", "2048", "2048"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0.0, "seed": 42, "sliding_window_length": 0}}
