
Feature Request: Support for Dynamic Bias Tensor in FlexAttention Without Recompilation #123

Open
pengzhangzhi opened this issue Feb 18, 2025 · 0 comments

Hi,

FlexAttention currently requires recompilation whenever the score_mod function changes. This works well for static attention modifications but presents challenges when the modification involves a dynamically computed bias tensor that varies across batches in training.

I am working on a use case where I need to add a bias term of shape [B, H, S, S] to the attention scores before applying the softmax operation. However, since this bias is computed as an intermediate result in the forward pass, embedding it in the score_mod function forces a full kernel recompilation on every forward pass, significantly impacting performance.
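For concreteness, here is a minimal sketch of the pattern described above (shapes are illustrative, and `compute_bias` is a hypothetical stand-in for whatever produces the per-batch bias in the forward pass):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 4, 8, 128, 64  # illustrative sizes
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

compiled_flex_attention = torch.compile(flex_attention)

def forward(q, k, v):
    # Bias is an intermediate result that differs on every forward pass
    bias = compute_bias(q)  # hypothetical helper returning [B, H, S, S]

    # A fresh closure capturing a fresh tensor on each call; this is the
    # pattern that currently triggers kernel recompilation
    def score_mod(score, b, h, q_idx, kv_idx):
        return score + bias[b, h, q_idx, kv_idx]

    return compiled_flex_attention(q, k, v, score_mod=score_mod)
```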

Would it be possible to extend FlexAttention to allow dynamically varying bias tensors without requiring recompilation?

Current Issue
• Problem: FlexAttention’s design requires a static function for score_mod, meaning any modifications passed as closures lead to kernel recompilation whenever the function (or its captured tensors) changes.
• Why This is a Problem: In many real-world applications (e.g., Transformer models with learned attention biases), the bias is dynamically computed during the forward pass. Since recompilation is expensive, this makes FlexAttention impractical for training scenarios where the bias differs across batches.

Proposed Solution

A potential way to address this could be:
1. Allowing a bias argument directly in flex_attention(...) that accepts a tensor of shape [B, H, S, S] and integrates it into the fused kernel computation (see the sketch after this list).
2. Distinguishing between static modifications (handled via score_mod) and dynamic biases (passed as a tensor input) so that frequent recompilation is avoided.
3. Enabling optional dynamic bias support using CUDA/Triton shared memory or efficient in-kernel broadcasting to minimize performance overhead.
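To make point 1 concrete, the call site might look like the following. The `bias=` keyword is hypothetical and does not exist in the current flex_attention API; this is only a sketch of the proposed interface:

```python
# Hypothetical API: the bias is an ordinary tensor input to the fused
# kernel, so its values can change every batch without recompilation
bias = compute_bias(q)                    # [B, H, S, S], recomputed per batch
out = flex_attention(q, k, v, bias=bias)  # no score_mod closure, no recompile
```

Since the bias would be a kernel input rather than a captured constant, only its shape and dtype would need to stay stable across calls, which matches how the existing query/key/value inputs behave.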

Would love to hear thoughts from the maintainers on the feasibility of this and whether such an extension aligns with the design goals of FlexAttention!
