Dynamic mask block sizes during inference #109
We are planning to release a blog post on some of the best practices for FlexAttention + inference. At a high level, the BlockMask is really just two tensors plus your mask_mod function, so depending on your block mask you might be able to precompute all the possible variations up to, say, MAX_SEQ_LEN and reuse them. From a quick look at your code: similar to training, it's quite common to run a transformer with N transformer blocks where all N blocks can share the same block mask. It is important to amortize the cost of creating the block mask and share it among all the layers. Also be sure to torch.compile(create_block_mask).
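For concreteness, here is a minimal sketch of that recipe: compile create_block_mask, build the BlockMask once, and hand the same object to every layer. The names MAX_SEQ_LEN, NUM_LAYERS, and causal_mask are illustrative assumptions, not taken from the original code.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

MAX_SEQ_LEN = 4096   # illustrative maximum sequence length
NUM_LAYERS = 12      # illustrative number of transformer blocks

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Compile the mask construction itself, as suggested above.
create_block_mask_c = torch.compile(create_block_mask)

# Build the BlockMask once (B=None, H=None broadcasts over batch and heads)...
block_mask = create_block_mask_c(causal_mask, B=None, H=None,
                                 Q_LEN=MAX_SEQ_LEN, KV_LEN=MAX_SEQ_LEN)

# ...and share the *same* object across all NUM_LAYERS attention layers
# instead of rebuilding it per layer.
flex_attention_c = torch.compile(flex_attention)

def layer_attention(q, k, v):
    # q, k, v: (batch, heads, MAX_SEQ_LEN, head_dim)
    return flex_attention_c(q, k, v, block_mask=block_mask)
```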
Thanks! That was very helpful. What is the estimated time before the blog post is released? I'm really looking forward to reading it. Also, a bit of a novice question: if I torch.compile my entire model, or layer by layer, would I still have to torch.compile(create_block_mask), or does the PyTorch compiler pick that up at the top level?
@windsornguyen No, it just needs to be within the compiled region ("needs" in the sense that it's more efficient if it's within the compiled region).
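A hedged sketch of what "within the compiled region" means in practice, with a made-up module and shapes: because the whole forward is compiled, the create_block_mask call inside it is already captured, and no separate torch.compile(create_block_mask) is needed.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

class AttentionOnly(torch.nn.Module):
    def forward(self, q, k, v):
        # This forward is compiled below, so create_block_mask already runs
        # inside the compiled region; compiling it separately is unnecessary.
        block_mask = create_block_mask(causal_mask, B=None, H=None,
                                       Q_LEN=q.shape[-2], KV_LEN=k.shape[-2])
        return flex_attention(q, k, v, block_mask=block_mask)

model = torch.compile(AttentionOnly())
q = k = v = torch.randn(1, 8, 1024, 64, device="cuda")
out = model(q, k, v)
```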
Does Flex Attention support changing block masks, for example during inference? Reconstructing the mask on each forward pass is costly, but initializing it once at the start also runs into trouble during inference because the block sizes change dynamically as the sequence length increases.
Here's what I currently have:
Might someone be able to point out any mistakes in my code, or provide an example / some insight for how Flex Attention can still be used during inference?
Thank you!
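For reference, here is a hypothetical sketch of the pattern being asked about; it is not the original snippet, and the names and shapes are assumptions. During incremental decoding the KV length grows every step, so a BlockMask built once at the start no longer matches, and rebuilding it on every forward pass is the cost in question.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def decode_step(q_new, k_cache, v_cache):
    # q_new: (B, H, q_len, D) newest token(s); caches grow along dim -2.
    q_len = q_new.shape[-2]
    kv_len = k_cache.shape[-2]
    offset = kv_len - q_len  # position of the first new query in the full sequence

    def causal_with_offset(b, h, q_idx, kv_idx):
        return q_idx + offset >= kv_idx

    # Rebuilding the BlockMask every step as kv_len grows is what gets expensive.
    block_mask = create_block_mask(causal_with_offset, B=None, H=None,
                                   Q_LEN=q_len, KV_LEN=kv_len)
    return flex_attention(q_new, k_cache, v_cache, block_mask=block_mask)
```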