Does the Flex attention API accept a customized attention mask? #112

Open
jiagaoxiang opened this issue Feb 6, 2025 · 1 comment

jiagaoxiang commented Feb 6, 2025

Hi, after reading through the documentation, I am quite confused about the score_mod and block_mask arguments of the flex attention API. They seem to be callables. I am wondering: is there a way I can provide a customized attention mask to the flex attention API, just like the attn_mask argument in torch.nn.functional.scaled_dot_product_attention? The reason I am asking is that my attention mask is very irregular (as in the attached image, where white squares are masked positions); it is used in the encoder self-attention of the Llama 3.2 vision models (i.e., the aspect ratio mask).

Thank you!

[Image: irregular attention mask pattern; white squares are masked positions]
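
For reference, a minimal sketch of the dense-mask setup described above, assuming a boolean attn_mask where True means "attend"; the shapes and names are illustrative placeholders, not taken from the issue:

import torch
import torch.nn.functional as F

B, H, S_LEN, D = 2, 8, 1024, 64
q = torch.randn(B, H, S_LEN, D)
k = torch.randn(B, H, S_LEN, D)
v = torch.randn(B, H, S_LEN, D)

# Irregular [S_LEN, S_LEN] boolean mask (e.g. an aspect ratio mask):
# True = attend, False = masked out.
mask = torch.ones(S_LEN, S_LEN, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)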

Chillee (Contributor) commented Feb 7, 2025

If you really want to use a pre-existing attention mask, you can just write a mask_mod like:

from torch.nn.attention.flex_attention import create_block_mask

def mask_mod(b, h, q_idx, kv_idx):
    # mask is the precomputed [S_LEN, S_LEN] boolean tensor; True = attend, False = masked out.
    return mask[q_idx, kv_idx]
block_mask = create_block_mask(mask_mod, None, None, S_LEN, S_LEN)
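
The resulting block_mask is then passed to flex_attention; a hedged usage sketch, where q, k, v are assumed to be [B, H, S_LEN, head_dim] tensors (placeholders, not from the comment):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Compiling flex_attention is the usual way to get good performance.
flex_attention = torch.compile(flex_attention)
out = flex_attention(q, k, v, block_mask=block_mask)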

But generally speaking, it'll be more efficient (and use less memory) to encode the behavior of your mask directly within mask_mod, rather than materializing a dense mask tensor first.
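
For example, here is a hedged sketch of encoding an aspect-ratio-style mask directly in mask_mod, assuming the valid (non-padded) tile tokens of each image sit at the start of the sequence and their count is known; num_valid_tokens and the example values are hypothetical, not from the comment:

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Hypothetical per-image count of valid (non-padded) tokens, shape [B].
# On a GPU run, this tensor should live on the same device create_block_mask targets.
num_valid_tokens = torch.tensor([900, 512])
B, S_LEN = num_valid_tokens.shape[0], 1024

def aspect_ratio_mask_mod(b, h, q_idx, kv_idx):
    # Attend only among positions that fall inside valid (non-padded) tiles of image b.
    return (q_idx < num_valid_tokens[b]) & (kv_idx < num_valid_tokens[b])

block_mask = create_block_mask(aspect_ratio_mask_mod, B, None, S_LEN, S_LEN)

This avoids ever building the dense [S_LEN, S_LEN] mask in memory.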
