[FEATURE] Upstream VIT FA RDNA3 ROCM #27776
Conversation
Code Review
This pull request introduces support for Vision Transformer Flash Attention on RDNA3 with ROCM. The changes include updating the Dockerfile to conditionally build Flash Attention, and modifying the attention layers to correctly select the ViT attention backend for ROCm platforms. The refactoring to centralize backend selection logic is a good improvement. However, I've found a critical issue where a function is used as a boolean value instead of being called, which will lead to incorrect behavior at runtime. Please see the specific comment for details.
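As an illustration of the class of bug flagged here (a function referenced as a boolean instead of being called), a minimal sketch might look like the following; the helper name is hypothetical, not the PR's actual code:

```python
# Minimal sketch of the flagged bug pattern: a function object used in a
# boolean context is always truthy, so the branch no longer depends on the
# actual check. The helper name below is illustrative only.

def rocm_supports_upstream_fa() -> bool:
    """Stand-in platform probe; always False here for demonstration."""
    return False


# Buggy: the function object itself is truthy, so this branch always executes.
if rocm_supports_upstream_fa:
    backend = "UPSTREAM_FA"

# Fixed: call the function so the branch follows its return value.
backend = "UPSTREAM_FA" if rocm_supports_upstream_fa() else "TORCH_SDPA"
```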
/gemini summary
Summary of Changes
This pull request integrates comprehensive support for upstream Flash Attention on AMD RDNA3 (gfx1100) GPUs within the ROCm platform. It refactors the attention backend selection mechanism to be more robust and dynamic, moving away from a passed boolean parameter to explicit checks of the current platform, environment variables, and GPU architecture. This change centralizes the logic for choosing between the different Flash Attention implementations, including AITER's version for gfx9 and the newly integrated upstream Flash Attention for RDNA3, ensuring optimal performance for the hardware and configuration.
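As a rough illustration of the selection flow described here (platform check, then environment variable, then GPU architecture), a hedged sketch could look like this; the function name, env-var handling, and return strings are assumptions rather than the PR's exact code:

```python
# Hedged sketch of centralized ViT attention backend selection on ROCm.
# Not the actual vLLM implementation; names and strings are illustrative.
import os

import torch


def select_vit_attn_backend() -> str:
    if not torch.version.hip:          # not a ROCm build
        return "TORCH_SDPA"
    # gcnArchName is available on ROCm builds of PyTorch, e.g. "gfx1100".
    arch = torch.cuda.get_device_properties(0).gcnArchName
    if arch.startswith("gfx9") and os.environ.get("VLLM_ROCM_USE_AITER") == "1":
        return "ROCM_AITER_FA"         # AITER path on gfx9 (MI-series)
    if arch.startswith("gfx11"):
        return "FLASH_ATTN"            # upstream flash-attn on RDNA3 (gfx1100)
    return "TORCH_SDPA"                # safe fallback
```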
@zhewenl I can run inference with your latest changes on flash_attn. Can you help me with these points of the wrapper and flash_attn?
Hi @yewentao256, would you be so kind as to take a look at this as well? Perhaps you know of another way to add it implicitly instead of using environment variables.
@JartX can you benchmark which one is faster, the triton flash attention API or torch.SDPA? Let's try to avoid introducing more and more code paths by offering the best one as the default. I also saw that you introduced a new environment variable; we would like to cut down on those.
Moreover, since there are efforts under way to fix the AMD CI, the GPU used in AMD CI is able to test all the code paths: torch.sdpa, CK flash attention varlen, AITER flash attention varlen, and even this new triton flash attention varlen (if it is worth introducing). I will be fixing all of them together.
Lucaskabela left a comment:
I think the change to the custom op signature is fine, but there are some changes in the models file that we shouldn't need.
```python
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    if is_rocm_aiter:
```
So I am not sure what the ask is on this PR, but from my end we should feel free to change these shims however we need.
These custom ops exist purely to preserve traceability of the ViT component, and the signatures are designed like this because we went from selecting attention based on attributes of the model (self) to needing an independent function without a self parameter.
I do want to voice a design consideration on this backend selection logic as a whole, though: to me, it would seem better if we could just pass attn_fn lambdas directly, as opposed to passing some backend enum and then doing the function selection later. I wonder what is preventing us from doing this in the code today? (traceability, etc.)
cc @ywang96 who may have more context on this
it would seem better if we could just pass attn_fn lambdas directly as opposed to some backend enum and then doing the function selection later. I wonder what is preventing us from doing this in the code today?
From the user perspective it's cleaner to just pass in an enum (e.g., --mm-encoder-attn-backend TORCH_SDPA), and it's better for us to control this than to accept an entire free-form attn implementation, but I agree that the enum -> attn_fn mapping can be done at init time of XXXVisionTransformer, and we can pass the resolved attn_fn as an input downstream to XXXVisionAttention. Does that align with what's on your mind?
I have an RFC #27821 that proposes the same idea: the enum -> attn_fn mapping happens in XXXVisionTransformer.
However, taking into account that many of the VL models share the same logic as qwen2_5_vl.py, the RFC abstracts this out further: the overriding logic should be handled by the platform, as only the platform knows which backends it can support.
So maybe_get_vit_flash_attn_backend will be solely responsible for the enum -> attn_fn mapping rather than also including the overriding logic (maybe_get_vit_flash_attn_backend will be renamed to a new name matching its role).
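A hedged sketch of the enum -> attn_fn split being discussed, with a stand-in enum and only the SDPA branch wired up; this is not the actual maybe_get_vit_flash_attn_backend code:

```python
# Sketch of resolving a backend enum to an attention callable once at
# vision-transformer init time, then handing the callable downstream.
# The enum and resolver below are stand-ins, not vLLM's definitions.
from enum import Enum, auto
from typing import Callable

import torch
import torch.nn.functional as F


class _Backend(Enum):
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()


AttnFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def resolve_vit_attn_fn(backend: _Backend) -> AttnFn:
    """Map the user-facing enum to a concrete attention implementation."""
    if backend is _Backend.TORCH_SDPA:
        return lambda q, k, v: F.scaled_dot_product_attention(q, k, v)
    # FLASH_ATTN / ROCM_AITER_FA branches would import their kernels here.
    raise NotImplementedError(f"no ViT attention wired up for {backend}")
```

Under this split, XXXVisionTransformer would call the resolver once in its init and pass the resulting callable to each XXXVisionAttention layer.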
```python
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
    from importlib.util import find_spec
```
This entire logic seems unnecessary here because:
- `max_seqlen` is already a tensor (type hints FTW)
- `use_upstream_fa` is set in `Qwen2_5_VisionTransformer` / upstream in init; we should just modify the logic there if needed, as opposed to here
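For illustration, the kind of init-time check the reviewer suggests could look like the following; the class and attribute names are assumptions, not the actual vLLM fields:

```python
# Hedged sketch: probe for upstream flash-attn once at vision-transformer
# init and store a flag, instead of re-checking inside the per-layer forward.
from importlib.util import find_spec


class VisionTransformerSketch:
    def __init__(self) -> None:
        # Resolved a single time; attention layers only read the flag.
        self.use_upstream_fa = find_spec("flash_attn") is not None
```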
@tjtanaa Benchmarking summaries for FA vs TORCH.SDPA across these configurations: CACHE 32B, WCACHE 32B, CACHE MOE 30B, WCACHE MOE 30B, CACHE EP MMDATA MOE 30B, WCACHE EP MMDATA MOE 30B, CACHE EP MOE 30B, WCACHE EP MOE 30B. [detailed result tables omitted]
@JartX From your data, I think it shows that, in general, TORCH_SDPA is still the better option.
- CACHE MOE 30B: TORCH.SDPA (better total throughput and request throughput)
- CACHE EP MMDATA MOE 30B: TORCH.SDPA (better across all metrics)
- CACHE EP MOE 30B: TORCH.SDPA (better across all metrics)
- WCACHE EP MMDATA MOE 30B: TORCH.SDPA (significantly better across all metrics)
- WCACHE EP MOE 30B: TORCH.SDPA (significantly better across all metrics)
- WCACHE MOE 30B: mixed (FA has better output throughput, TORCH.SDPA has better latency and TTFT)
- CACHE 32B: FA (slightly better output throughput and TTFT)
- WCACHE 32B: TORCH.SDPA (better across all metrics)
@tjtanaa It surprised me too, but I don't understand why xD. Could you try to explain it to me, please? Even if it's just for the time I've spent on the benchmarks, hahaha.
Accuracy test: FA vs TORCH.SDPA. [result tables omitted]
@JartX I think the accuracy change is not sufficient to say there is a huge degradation; maybe we need to evaluate on more datasets.
Accuracy Changes (FA → TORCH.SDPA): [details omitted]
Regarding the speed comparison, in my opinion,
@JartX I would like to suggest an alternative. Since AITER is also installed on Radeon, can you explore the use of AITER's triton implementation and see if it is faster?
@tjtanaa Thanks for the idea. I tried it back in the day, but inference isn't possible due to lack of hardware support, for example:
I've also seen your PR #27919; with it, this PR #27776 could even be removed, or just kept for the Dockerfile along with a wiki update referencing ROCm and RDNA3. And sorry for my ignorance, but with your PR, is it possible to force upstream Flash Attention? If possible, it would be really easy to offer both types of attention in case upstream FA progresses. Thank you so much for your time.
@JartX I am referring to the triton implementation from the AITER repo. Does invoking the triton implementation trigger the asm error?
@tjtanaa Hi, it also fails to start up; it seems to only be supported for X Arch:
Going to add a spoof of the GPU:
@tjtanaa
I'm trying to respect the latest logic implemented with @Lucaskabela's wrapper while also trying to correct the execution. Before the wrapper, I already had this working with minimal logic implemented in rocm.py and layer.py. Right now I can get it to work with FLASH_ATTN but not with torch.SDPA, so I'm asking for your help:
@DarkLight1337 @tjtanaa @lgeiger @Lucaskabela
I am worried about the following part in qwen2_5_vl.py
I would say that there is a lack of coherence in the parameters:
- `is_rocm_aiter: bool` == `ROCM_AITER_FA`
- `self.attn_backend == _Backend.ROCM_AITER_FA`
I have passed the selection, which I'm sure is poorly done, into flash_attn_maxseqlen_wrapper; @tjtanaa will surely hit me for this xD.
Please take a look, and if I need to give everyone access to my repository, I will :)
Thank you so much!
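One possible way to address the coherence concern is to derive the wrapper's boolean arguments from the single backend enum in one helper, so they can never disagree with self.attn_backend. This is only a sketch with stand-in names, not qwen2_5_vl.py's actual code:

```python
# Hedged sketch: compute (is_rocm_aiter, use_upstream_fa) from one backend
# enum in a single place. The enum here is a local stand-in, not an import
# from vLLM, and the helper name is illustrative.
from enum import Enum, auto


class _Backend(Enum):
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()


def resolve_wrapper_flags(
    attn_backend: _Backend, upstream_fa_available: bool
) -> tuple[bool, bool]:
    """Return (is_rocm_aiter, use_upstream_fa) for the attention wrapper."""
    is_rocm_aiter = attn_backend is _Backend.ROCM_AITER_FA
    use_upstream_fa = (
        attn_backend is _Backend.FLASH_ATTN and upstream_fa_available
    )
    return is_rocm_aiter, use_upstream_fa
```

The two flags would then be passed to the wrapper unchanged, keeping the enum as the single source of truth for backend selection.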