
Error when using flex attention and F.sdpa together. #115

Open
wuyushuwys opened this issue Feb 8, 2025 · 8 comments

@wuyushuwys
Hello team!

I encountered an issue when trying to use flex attention and F.sdpa together, even during inference. For example, I use flex_attention for masked self-attention and afterwards apply F.sdpa for cross-attention; however, this incurs RuntimeError: CUDA error: an illegal memory access in F.sdpa.
By the way, I have to compile flex_attention, otherwise it leads to an OOM error inside flex_attention, since I am using a long sequence length (about 8k).

Previously, I was using two F.sdpa calls and they worked fine, but I would like to benefit from the sparse computation in flex attention.
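
For reference, my setup looks roughly like this (a simplified sketch with illustrative shapes and a plain causal mask, not my actual model code):

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_attention = torch.compile(flex_attention)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 1, 32, 8192, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# masked (sparse) self-attention with compiled flex_attention
block_mask = create_block_mask(causal, B, H, S, S, device="cuda")
x = flex_attention(q, k, v, block_mask=block_mask)

# dense cross-attention against a short encoder sequence with F.sdpa
enc = torch.randn(B, H, 128, D, device="cuda", dtype=torch.bfloat16)
y = F.scaled_dot_product_attention(x, enc, enc)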

@drisspg
Contributor

drisspg commented Feb 8, 2025

Can you provide some sort of repro?

@wuyushuwys
Copy link
Author

Emmm, I just fixed the error by reimplementing it; it looks like it was related to a typo I wasn't aware of.
By the way, it looks like training with flex attention does not support dynamic sequence lengths due to torch.compile. Will that be possible to support, or is it better to pad to the same sequence length (rough sketch of the padding idea below)?
Thanks!
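
For reference, the padding alternative I have in mind just right-pads each batch's sequence dimension up to a fixed bucket before attention, so torch.compile only ever sees one shape (a rough sketch; the function name and bucket size are made up):

import torch
import torch.nn.functional as F

def pad_to_bucket(x: torch.Tensor, bucket: int = 8192) -> torch.Tensor:
    # x: [B, S, C]; right-pad the sequence dimension (dim 1) up to `bucket`.
    pad = bucket - x.size(1)
    return F.pad(x, (0, 0, 0, pad)) if pad > 0 else x

# Note: padded key/value positions would still need to be masked out in the
# attention mask / block mask so they do not contribute to the output.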

@drisspg
Contributor

drisspg commented Feb 8, 2025

Dynamic shapes should be supported. Do you have a repro?

@wuyushuwys
Author

I was getting an error like this. Does dynamic shape support the batch dimension?
The torch version is 2.5.1+cu121.

LoweringException: AssertionError: Batch dimension must match
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.bfloat16, size=[s2, 32, s3, 64], stride=[2048*s3, 64, 2048, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.bfloat16, size=[s6, 32, s7, 64], stride=[2048*s7, 64, 2048, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='primals_9', layout=FixedLayout('cuda', torch.bfloat16, size=[s9, 32, s10, 64], stride=[2048*s10, 64, 2048, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    InputBuffer(name='primals_12', layout=FixedLayout('cuda', torch.int32, size=[s11, 32, s12], stride=[32*s12, s12, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_16', layout=FixedLayout('cuda', torch.int32, size=[s13, 32, s14, s15], stride=[32*s14*s15, s14*s15, s15, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_19', layout=FixedLayout('cuda', torch.int32, size=[s16, 32, s17], stride=[32*s17, s17, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_23', layout=FixedLayout('cuda', torch.int32, size=[s18, 32, s19, s20], stride=[32*s19*s20, s19*s20, s20, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_26', layout=FixedLayout('cuda', torch.int32, size=[s21, 32, s22], stride=[32*s22, s22, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_30', layout=FixedLayout('cuda', torch.int32, size=[s23, 32, s24, s25], stride=[32*s24*s25, s24*s25, s25, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_33', layout=FixedLayout('cuda', torch.int32, size=[s26, 32, s27], stride=[32*s27, s27, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_37', layout=FixedLayout('cuda', torch.int32, size=[s28, 32, s29, s30], stride=[32*s29*s30, s29*s30, s30, 1]))
  )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Here is a minimal reproduction:

from functools import lru_cache
from typing import Any, Dict, List, Optional, Literal, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)

from diffusers.models.normalization import RMSNorm
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention

flex_attention = torch.compile(flex_attention)


@lru_cache
def create_causal_block_mask_cached(block_size, B, H, Q_LEN, KV_LEN, device="cuda"):
    def causal_block_mask(b, h, q_idx, kv_idx):
        mask = q_idx // block_size >= kv_idx // block_size
        return mask

    return create_block_mask(causal_block_mask, B, H, Q_LEN, KV_LEN, device=device, _compile=True)


class CausalAttentionProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
            self,
            attn: Attention,
            hidden_states: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            image_rotary_emb: Optional[torch.Tensor] = None,
            frame_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        print("causal attention")
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        assert frame_seq_len is not None
        B, H, q_len, _ = query.shape
        _, _, kv_len, _ = key.shape

        # block_mask = create_block_mask_cached(prefix_block_causal_mask, B, None, q_len, kv_len, device=query.device)
        block_mask = create_causal_block_mask_cached(
            block_size=frame_seq_len,
            B=B,
            H=None,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=query.device
        )
        hidden_states = flex_attention(query=query, key=key, value=value, score_mod=None, block_mask=block_mask)

        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class AttentionProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
            self,
            attn: Attention,
            hidden_states: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        print("simple attention")
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class TransformerBlock(nn.Module):

    def __init__(
            self,
            dim: int,
            num_attention_heads: int,
            attention_head_dim: int,
            cross_attention_dim: int,
            qk_norm: str = "rms_norm_across_heads",
            activation_fn: str = "gelu-approximate",
            attention_bias: bool = True,
            attention_out_bias: bool = True,
            eps: float = 1e-6,
            elementwise_affine: bool = False,
            attn1_processor: CausalAttentionProcessor2_0 = CausalAttentionProcessor2_0(),
            attn2_processor: AttentionProcessor2_0 = AttentionProcessor2_0(),
    ):
        super().__init__()

        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            processor=attn1_processor,
        )

        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            processor=attn2_processor,
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)

        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)

    def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            temb: torch.Tensor,
            image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
            frame_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.size(0)
        norm_hidden_states = self.norm1(hidden_states)

        num_ada_params = self.scale_shift_table.shape[0]
        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        if isinstance(self.attn1.processor, CausalAttentionProcessor2_0):
            attn1_kwargs = dict(frame_seq_len=frame_seq_len)
        else:
            attn1_kwargs = {}
        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            image_rotary_emb=image_rotary_emb,
            **attn1_kwargs,
        )
        hidden_states = hidden_states + attn_hidden_states * gate_msa

        attn_hidden_states = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = hidden_states + attn_hidden_states
        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states


def apply_rotary_emb(x, freqs):
    cos, sin = freqs
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
    return out


if __name__ == "__main__":
    model = TransformerBlock(
        dim=2048,
        num_attention_heads=32,
        attention_head_dim=64,
        cross_attention_dim=2048,
        qk_norm="rms_norm_across_heads",
        activation_fn="gelu-approximate",
        attention_bias=True,
        attention_out_bias=True,
        attn1_processor=CausalAttentionProcessor2_0(),
        attn2_processor=AttentionProcessor2_0(),
    )

    model.cuda()
    batch_size = 2
    h = 32
    w = 18
    t = 16
    hidden_states = torch.randn(batch_size, h * w * t, 2048).cuda()
    encoder_hidden_states = torch.randn(batch_size, 128, 2048).cuda()
    temb = torch.randn(batch_size, 1, 2048 * 6).cuda()
    encoder_attention_mask = torch.zeros(batch_size, 1, 128).cuda()
    encoder_attention_mask[:, :, :12] = 1
    output = model(
        hidden_states,
        encoder_hidden_states,
        temb=temb,
        encoder_attention_mask=encoder_attention_mask,
        frame_seq_len=h * w,
    )
    print(output.shape)

    # Second forward pass with a different batch size (2 -> 4) and a different
    # sequence length; on 2.5.1 this call raises the LoweringException
    # ("Batch dimension must match") shown above.
    batch_size = 4
    h = 16
    w = 18
    t = 8
    hidden_states = torch.randn(batch_size, h * w * t, 2048).cuda()
    encoder_hidden_states = torch.randn(batch_size, 128, 2048).cuda()
    temb = torch.randn(batch_size, 1, 2048 * 6).cuda()
    encoder_attention_mask = torch.zeros(batch_size, 1, 128).cuda()
    encoder_attention_mask[:, :, :12] = 1
    output = model(
        hidden_states,
        encoder_hidden_states,
        temb=temb,
        encoder_attention_mask=encoder_attention_mask,
        frame_seq_len=h * w,
    )
    print(output.shape)

@drisspg
Contributor

drisspg commented Feb 12, 2025

Yeah, we fixed a number of dynamic shape issues more recently that were not included in 2.5.1; they need a newer PyTorch version.

@wuyushuwys
Author

Do you have a recommendation for a torch version that is suitable for dynamic shapes, 2.6.0 or nightly?
Currently we can use the same batch size, but there is a limitation on the cache size in 2.5.1 (the workaround we use is sketched below).
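
For reference, a rough sketch of the workaround we use (assuming the limit being hit is dynamo's recompile cache; the value below is arbitrary):

import torch

# Allow more shape specializations to be cached before dynamo stops
# recompiling the frame and falls back to eager.
torch._dynamo.config.cache_size_limit = 64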

@drisspg
Contributor

drisspg commented Feb 26, 2025

If 2.6 supports what you need, then great. In particular, we had some more dynamic shape issues for max-autotune that are fixed on nightly but not in 2.6.

@wuyushuwys
Author

I noticed that I was changing both the batch size and the sequence length, which might be the cause. Currently I fix the batch size and only use a dynamic sequence length, and it works fine (rough sketch below).
Thanks!
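
For anyone hitting the same thing, the idea is roughly the following (an illustrative sketch with made-up shapes; torch._dynamo.mark_dynamic is one way to hint the dynamic dimension to torch.compile, not necessarily required):

import torch

# Keep the batch size fixed across calls and mark only the sequence
# dimension (dim 2 of [B, H, S, D]) as dynamic, so torch.compile does not
# specialize a new kernel for every sequence length.
q = torch.randn(2, 32, 4096, 64, device="cuda", dtype=torch.bfloat16)
torch._dynamo.mark_dynamic(q, 2)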
