Dynamic mask block sizes during inference #109

Open
windsornguyen opened this issue Jan 30, 2025 · 3 comments
@windsornguyen

Does Flex Attention support changing block masks, for example during inference? Reconstructing the mask on every forward pass is costly, but initializing it once at the start also runs into trouble during inference, since the block mask's dimensions need to change dynamically as the sequence length grows.

Here's what I currently have:

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# apply_rotary_emb is defined elsewhere in the codebase.

class FlexAttention(nn.Module):
    """
    Generalized multihead attention that supports various attention masks
    and rotary positional embeddings.
    """
    def __init__(self, config, mask_mod, score_mod=None):
        """
        Initializes the attention module.

        Args:
            config: Model config providing dim, num_heads, seq_len, and device.
            mask_mod (Callable): Mask to modify attention scores, e.g. causal.
            score_mod (Callable, optional): Function to modify attention scores.
        """
        super().__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)

        self.mask_mod = mask_mod
        self.score_mod = score_mod
        self.block_mask = create_block_mask(
            mask_mod=self.mask_mod,
            B=None, # Broadcast
            H=None, # Broadcast
            Q_LEN=config.seq_len,
            KV_LEN=config.seq_len,
            device=config.device,
        )

        self.c_proj = nn.Linear(config.dim, config.dim)
        self.c_proj.SCALE_INIT = 1

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, _ = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        # Project and reshape to (bsz, num_heads, seq_len, head_dim), the layout flex_attention expects.
        Q = self.wq(q).reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(k).reshape(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(v).reshape(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q, K = apply_rotary_emb(Q, K, freqs_cis=freqs_cis)

        output = flex_attention(Q, K, V, block_mask=self.block_mask, score_mod=self.score_mod)
        output = output.transpose(1, 2).reshape(bsz, q_len, self.dim)
        output = self.c_proj(output)
        return output

Could someone point out any mistakes in my code, or provide an example or some insight into how Flex Attention can still be used during inference?

Thank you!

@drisspg
Contributor

drisspg commented Jan 30, 2025

We are planning to release a blog post on some of the best practices for FlexAttention + inference.
In the interim, this code pointer might be helpful:
pytorch-labs/gpt-fast#196

At a high level, the BlockMask is really just two tensors plus your mask_mod function. So depending on your block mask, you might be able to precompute all the possible variations up to, say, MAX_SEQ_LEN and reuse them during inference.
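
For concreteness, a minimal sketch of the pattern in the linked gpt-fast PR: build one BlockMask at MAX_SEQ_LEN, then at each decode step slice out the query-block row for the current position and rebind an offset mask_mod. The helper names here (get_mask_mod, decode_step_mask, input_pos, MAX_SEQ_LEN) are illustrative, and BlockMask slicing details may differ across PyTorch versions.

from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

MAX_SEQ_LEN = 4096  # illustrative upper bound on sequence length

# Built once, up front, for the longest sequence you will ever decode.
prefill_mask = create_block_mask(
    causal, B=None, H=None, Q_LEN=MAX_SEQ_LEN, KV_LEN=MAX_SEQ_LEN, device="cuda"
)

def get_mask_mod(mask_mod, offset):
    # Shift the query index so a single-token decode step at position `offset`
    # sees the same mask values it would at full sequence length.
    def _mask_mod(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return _mask_mod

def decode_step_mask(block_mask, input_pos):
    # Select the query-block row that contains input_pos, then attach the
    # offset mask_mod; no new mask is materialized during decoding.
    block_idx = input_pos // block_mask.BLOCK_SIZE[0]
    mask = block_mask[:, :, block_idx : block_idx + 1]
    mask.mask_mod = get_mask_mod(block_mask.mask_mod, input_pos)
    return mask

At decode time, the idea is to pass decode_step_mask(prefill_mask, input_pos) as block_mask to flex_attention for the single new query token instead of rebuilding the mask each step.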

From a quick look at your code: similar to training, it's quite common to run a transformer with N transformer blocks where all N blocks can share the same block mask. It is important to amortize the cost of creating the block mask by sharing it among all the layers, rather than constructing one inside each attention module. Also be sure to torch.compile(create_block_mask).
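
To illustrate that advice, one possible sketch (not code from the thread): hoist the mask out of the per-layer attention module, build it once with a compiled create_block_mask, and pass the same object to every layer. The Transformer class, its block list, and the assumption that each block's forward accepts a block_mask argument are all placeholders.

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask

# Compile the mask construction itself; its cost is still worth amortizing.
create_block_mask_compiled = torch.compile(create_block_mask)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

class Transformer(nn.Module):
    def __init__(self, config, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        # One BlockMask for the whole model, built a single time.
        self.block_mask = create_block_mask_compiled(
            causal, B=None, H=None, Q_LEN=config.seq_len, KV_LEN=config.seq_len,
            device=config.device,
        )

    def forward(self, x, freqs_cis=None):
        for block in self.blocks:
            # Every block reuses the same mask instead of creating its own.
            x = block(x, block_mask=self.block_mask, freqs_cis=freqs_cis)
        return x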

@windsornguyen
Author

Thanks! That was very helpful. Roughly when do you expect the blog post to be released? Really looking forward to reading it.

Also a bit of a novice question: if I torch.compile my entire model, or layer-by-layer, would I still have to torch.compile(create_block_mask), or does the PyTorch compiler pick that up at the top-level?

@Chillee
Contributor

Chillee commented Feb 5, 2025

@windsornguyen no, it just needs to be within the compiled region (needs as in "it's more efficient if it's within the compiled region").
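
In other words (my reading, not code from the thread): if create_block_mask is called inside the region handed to torch.compile, it is compiled along with everything else; a standalone torch.compile(create_block_mask) matters when masks are built outside that region. A rough sketch:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

@torch.compile
def attention_step(q, k, v):
    # Because this whole function is compiled, the mask construction below is
    # already inside the compiled region; no separate compile call is needed.
    seq_len = q.shape[-2]
    block_mask = create_block_mask(
        causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=q.device
    )
    return flex_attention(q, k, v, block_mask=block_mask)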
