FlexAttention currently requires recompilation whenever the score_mod function changes. This works well for static attention modifications but presents challenges when the modification involves a dynamically computed bias tensor that varies across batches in training.
I am working on a use case where I need to add a bias term of shape [B, H, S, S] to the attention scores before applying the softmax operation. However, since this bias is computed as an intermediate result in the forward pass, embedding it in the score_mod function forces a full kernel recompilation on every forward pass, significantly impacting performance.
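For concreteness, here is a minimal sketch of the closure-based pattern described above (the shapes, the random bias, and the `forward` wrapper are illustrative only, not the actual model code):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile once; shapes and the random "bias" are illustrative only.
flex_attention = torch.compile(flex_attention)
B, H, S, D = 4, 8, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))

def forward(q, k, v):
    # Bias computed as an intermediate result of the forward pass,
    # so it is a different tensor on every call.
    bias = torch.randn(B, H, S, S, device=device)

    def score_mod(score, b, h, q_idx, kv_idx):
        # score_mod closes over `bias`; a fresh closure is created on each
        # forward pass, which is the recompilation trigger described above.
        return score + bias[b, h, q_idx, kv_idx]

    return flex_attention(q, k, v, score_mod=score_mod)

out = forward(q, k, v)
```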
Would it be possible to extend FlexAttention to allow dynamically varying bias tensors without requiring recompilation?
Current Issue
• Problem: FlexAttention’s design requires a static function for score_mod, meaning any modifications passed as closures lead to kernel recompilation whenever the function (or its captured tensors) changes.
• Why This is a Problem: In many real-world applications (e.g., Transformer models with learned attention biases), the bias is dynamically computed during the forward pass. Since recompilation is expensive, this makes FlexAttention impractical for training scenarios where the bias differs across batches.
Proposed Solution
A potential way to address this:
1. Allowing a bias argument directly in flex_attention(...) that accepts a tensor of shape [B, H, S, S] and integrates it into the fused kernel computation (see the sketch after this list).
2. Distinguishing between static modifications (handled via score_mod) and dynamic biases (passed as a tensor input) so that frequent recompilation is avoided.
3. Enabling optional dynamic bias support using CUDA/Triton shared memory or efficient in-kernel broadcasting to minimize performance overhead.
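For illustration, here is a hypothetical sketch of what (1) and (2) might look like. The `bias=` keyword argument below does not exist in the current flex_attention API; it is shown purely as a proposed/assumed interface:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical sketch only: the `bias=` keyword is NOT part of the current
# FlexAttention API; it illustrates proposals (1) and (2) above.
B, H, S, D = 4, 8, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
attn_bias = torch.randn(B, H, S, S, device=device)  # recomputed per batch

def causal(score, b, h, q_idx, kv_idx):
    # Static modification: compiled once, never changes across batches.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# Dynamic bias passed as a plain tensor argument, so changing its values
# between forward passes would not require recompiling the kernel.
out = flex_attention(q, k, v, score_mod=causal, bias=attn_bias)  # proposed API
```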
Would love to hear thoughts from the maintainers on the feasibility of this and whether such an extension aligns with the design goals of FlexAttention!