[FEATURE] Upstream VIT FA RDNA3 ROCM #27776
Status: Closed · +77 −12
Commits (10):
- 22de4f2 fa_upstream_detection for rdna3 rocm (JartX)
- 25028eb working_on_fa_rdna3 (JartX)
- 58f0be7 remove is_rocm_aiter (JartX)
- 90d3b7c missing () on on_gfx9() (JartX)
- cf36822 default FLASH_ATTENTION_TRITON_AMD_ENABLE GPU_ARCHS if passed on buil… (JartX)
- 6244384 Merge branch 'main' into feature/upstream_fa_rdna3_rocm (JartX)
- 2af8555 readd code and 7900XTX device id (JartX)
- c03438b readd code and 7900XTX device id (JartX)
- cd442f6 Merge branch 'main' into feature/upstream_fa_rdna3_rocm (JartX)
- 11a441e Merge branch 'main' into feature/upstream_fa_rdna3_rocm (JartX)
Diff hunk `@@ -422,14 +422,21 @@ def forward(`:

```python
q, k = torch.chunk(qk_rotated, 2, dim=0)

if self.is_flash_attn_backend:
    from importlib.util import find_spec

    if not isinstance(max_seqlen, torch.Tensor):
        max_seqlen = torch.tensor(
            max_seqlen, device=q.device, dtype=torch.int32
        )
    self.use_upstream_fa = find_spec("flash_attn") is not None

    context_layer = vit_flash_attn_wrapper(
        q,
        k,
        v,
        cu_seqlens,
        max_seqlen,
        batch_size,
        self.attn_backend == _Backend.ROCM_AITER_FA,
        self.use_upstream_fa,
    )
elif self.attn_backend == _Backend.TORCH_SDPA:
```

Contributor (on the `find_spec` detection above): This entire logic seems unnecessary here because:
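For reference, the `find_spec("flash_attn")` probe could also be hoisted out of the per-call `forward()` path entirely. A minimal sketch, not from this PR; the helper name is hypothetical:

```python
from functools import cache
from importlib.util import find_spec


@cache
def has_upstream_flash_attn() -> bool:
    """Hypothetical helper: probe once whether upstream flash-attn is importable."""
    # find_spec returns None when the package cannot be located, so this
    # answers "is upstream flash_attn installed?" without importing it.
    return find_spec("flash_attn") is not None
```

Because `@cache` memoizes the result, repeated calls cost a dictionary lookup rather than a module search on every forward pass.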
Comment: So I am not sure what the ask is on this PR, but from my end we should feel free to change these shims however we need. These custom ops are purely to preserve traceability of the ViT component, and the signatures are designed like this because we went from selecting attention based on attributes of the model (`self`) to needing an independent function without a `self` parameter.

I do want to voice a design consideration on this backend selection logic as a whole, though: to me, it would seem better if we could just pass `attn_fn` lambdas directly, as opposed to passing a backend enum and doing the function selection later. I wonder what is preventing us from doing this in the code today? (traceability, etc.)

cc @ywang96 who may have more context on this
Reply: From the user perspective it's cleaner to just pass in an enum (e.g. `--mm-encoder-attn-backend TORCH_SDPA`), and it's better for us to control this than to accept an entire free-form attention implementation. But I agree that the `enum -> attn_fn` resolution can be done at init time of `XXXVisionTransformer`, with the resolved `attn_fn` then passed downstream as an input to `XXXVisionAttention`. Does that align with what's on your mind?
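A minimal sketch of that shape, with illustrative names only (`Backend`, `VisionTransformer`, etc. are stand-ins, not vLLM's actual classes): the enum is resolved to a callable once at transformer init, and each attention layer receives the callable.

```python
from enum import Enum, auto
from typing import Callable

import torch

AttnFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


class Backend(Enum):  # stand-in for vLLM's _Backend enum
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def sdpa_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Plain PyTorch fallback; a real implementation would also thread
    # through cu_seqlens / max_seqlen for variable-length batches.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


def resolve_attn_fn(backend: Backend) -> AttnFn:
    # enum -> attn_fn, done exactly once at init time.
    if backend == Backend.TORCH_SDPA:
        return sdpa_attn
    raise NotImplementedError(f"no attn_fn wired up for {backend}")


class VisionAttention(torch.nn.Module):
    def __init__(self, attn_fn: AttnFn):
        super().__init__()
        self.attn_fn = attn_fn  # the layer never sees the enum

    def forward(self, q, k, v):
        return self.attn_fn(q, k, v)


class VisionTransformer(torch.nn.Module):
    def __init__(self, backend: Backend, num_layers: int = 2):
        super().__init__()
        attn_fn = resolve_attn_fn(backend)  # single resolution point
        self.layers = torch.nn.ModuleList(
            VisionAttention(attn_fn) for _ in range(num_layers)
        )
```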
Reply: @Lucaskabela @ywang96 I have an RFC, #27821, proposing the same idea: the `enum -> attn_fn` resolution happens in the `XXXVisionTransformer`. However, since many of the VL models share the same logic as `qwen2_5_vl.py`, the RFC abstracts this out further: the overriding logic should be handled by `platform`, since only the `platform` knows which backends it can support. `maybe_get_vit_flash_attn_backend` would then be solely responsible for the `enum -> attn_fn` mapping rather than also containing overriding logic (and will be renamed to a name matching its new role).
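Sketching the split the RFC describes, reusing the illustrative `Backend`, `AttnFn`, `resolve_attn_fn`, and `has_upstream_flash_attn` from the snippets above (the `Platform` hook below is hypothetical, not vLLM's real interface):

```python
class Platform:
    """Base platform: no override, honor the requested backend."""

    def resolve_vit_attn_backend(self, requested: Backend) -> Backend:
        return requested


class Rdna3RocmPlatform(Platform):
    """Only the platform knows which backends it can actually run."""

    def resolve_vit_attn_backend(self, requested: Backend) -> Backend:
        # Override lives here: fall back when upstream FA is unavailable.
        if requested == Backend.FLASH_ATTN and not has_upstream_flash_attn():
            return Backend.TORCH_SDPA
        return requested


def get_vit_attn_fn(platform: Platform, requested: Backend) -> AttnFn:
    # The mapping step stays a pure enum -> attn_fn lookup; all
    # overriding has already happened inside the platform.
    backend = platform.resolve_vit_attn_backend(requested)
    return resolve_attn_fn(backend)
```

This keeps hardware-specific fallback decisions in one place while the shared ViT code stays backend-agnostic.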