Error when using flex attention and F.sdpa together. #115
Comments
Can you provide some sort of repro?
Emmm, I just fixed the error by reimplementing it. It looks like it was related to a typo I wasn't aware of.
Dynamic shapes should be supported; do you have a repro?
I was getting an error like this. Does dynamic shape support the batch dimension?

```
LoweringException: AssertionError: Batch dimension must match
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.bfloat16, size=[s2, 32, s3, 64], stride=[2048*s3, 64, 2048, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.bfloat16, size=[s6, 32, s7, 64], stride=[2048*s7, 64, 2048, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='primals_9', layout=FixedLayout('cuda', torch.bfloat16, size=[s9, 32, s10, 64], stride=[2048*s10, 64, 2048, 1]))
))
args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
args[4]: (TensorBox(StorageBox(
InputBuffer(name='primals_12', layout=FixedLayout('cuda', torch.int32, size=[s11, 32, s12], stride=[32*s12, s12, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_16', layout=FixedLayout('cuda', torch.int32, size=[s13, 32, s14, s15], stride=[32*s14*s15, s14*s15, s15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_19', layout=FixedLayout('cuda', torch.int32, size=[s16, 32, s17], stride=[32*s17, s17, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_23', layout=FixedLayout('cuda', torch.int32, size=[s18, 32, s19, s20], stride=[32*s19*s20, s19*s20, s20, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_26', layout=FixedLayout('cuda', torch.int32, size=[s21, 32, s22], stride=[32*s22, s22, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_30', layout=FixedLayout('cuda', torch.int32, size=[s23, 32, s24, s25], stride=[32*s24*s25, s24*s25, s25, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_33', layout=FixedLayout('cuda', torch.int32, size=[s26, 32, s27], stride=[32*s27, s27, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_37', layout=FixedLayout('cuda', torch.int32, size=[s28, 32, s29, s30], stride=[32*s29*s30, s29*s30, s30, 1]))
)), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
args[5]: 0.125
args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
args[7]: ()
args[8]: ()
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

Here is the minimal reproduction code:

```python
from functools import lru_cache
from typing import Any, Dict, List, Optional, Literal, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)
from diffusers.models.normalization import RMSNorm
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention

flex_attention = torch.compile(flex_attention)


@lru_cache
def create_causal_block_mask_cached(block_size, B, H, Q_LEN, KV_LEN, device="cuda"):
    def causal_block_mask(b, h, q_idx, kv_idx):
        mask = q_idx // block_size >= kv_idx // block_size
        return mask

    return create_block_mask(causal_block_mask, B, H, Q_LEN, KV_LEN, device=device, _compile=True)


class CausalAttentionProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        frame_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        print("causal attention")
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        assert frame_seq_len is not None
        B, H, q_len, _ = query.shape
        _, _, kv_len, _ = key.shape
        # block_mask = create_block_mask_cached(prefix_block_causal_mask, B, None, q_len, kv_len, device=query.device)
        block_mask = create_causal_block_mask_cached(
            block_size=frame_seq_len,
            B=B,
            H=None,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=query.device,
        )

        hidden_states = flex_attention(query=query, key=key, value=value, score_mod=None, block_mask=block_mask)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class AttentionProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        print("simple attention")
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        qk_norm: str = "rms_norm_across_heads",
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        attn1_processor: CausalAttentionProcessor2_0 = CausalAttentionProcessor2_0(),
        attn2_processor: AttentionProcessor2_0 = AttentionProcessor2_0(),
    ):
        super().__init__()

        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            processor=attn1_processor,
        )

        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            processor=attn2_processor,
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        frame_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.size(0)
        norm_hidden_states = self.norm1(hidden_states)

        num_ada_params = self.scale_shift_table.shape[0]
        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        if isinstance(self.attn1.processor, CausalAttentionProcessor2_0):
            attn1_kwargs = dict(frame_seq_len=frame_seq_len)
        else:
            attn1_kwargs = {}

        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            image_rotary_emb=image_rotary_emb,
            **attn1_kwargs,
        )
        hidden_states = hidden_states + attn_hidden_states * gate_msa

        attn_hidden_states = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = hidden_states + attn_hidden_states

        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states


def apply_rotary_emb(x, freqs):
    cos, sin = freqs
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
    return out


if __name__ == "__main__":
    model = TransformerBlock(
        dim=2048,
        num_attention_heads=32,
        attention_head_dim=64,
        cross_attention_dim=2048,
        qk_norm="rms_norm_across_heads",
        activation_fn="gelu-approximate",
        attention_bias=True,
        attention_out_bias=True,
        attn1_processor=CausalAttentionProcessor2_0(),
        attn2_processor=AttentionProcessor2_0(),
    )
    model.cuda()

    batch_size = 2
    h = 32
    w = 18
    t = 16
    hidden_states = torch.randn(batch_size, h * w * t, 2048).cuda()
    encoder_hidden_states = torch.randn(batch_size, 128, 2048).cuda()
    temb = torch.randn(batch_size, 1, 2048 * 6).cuda()
    encoder_attention_mask = torch.zeros(batch_size, 1, 128).cuda()
    encoder_attention_mask[:, :, :12] = 1

    output = model(
        hidden_states,
        encoder_hidden_states,
        temb=temb,
        encoder_attention_mask=encoder_attention_mask,
        frame_seq_len=h * w,
    )
    print(output.shape)

    batch_size = 4
    h = 16
    w = 18
    t = 8
    hidden_states = torch.randn(batch_size, h * w * t, 2048).cuda()
    encoder_hidden_states = torch.randn(batch_size, 128, 2048).cuda()
    temb = torch.randn(batch_size, 1, 2048 * 6).cuda()
    encoder_attention_mask = torch.zeros(batch_size, 1, 128).cuda()
    encoder_attention_mask[:, :, :12] = 1

    output = model(
        hidden_states,
        encoder_hidden_states,
        temb=temb,
        encoder_attention_mask=encoder_attention_mask,
        frame_seq_len=h * w,
    )
    print(output.shape)
```
Yeah, we fixed a number of dynamic shape issues more recently that were not included in 2.5.1; they need a newer PyTorch version.
Do you have a recommendation for which torch version is suitable for dynamic shapes?
If 2.6 supports what you need, then great. In particular, we had some more dynamic shape issues with max-autotune that are fixed on nightly but not in 2.6.
I noticed that I was changing both the batch size and the sequence length; that might be the cause. Currently I keep the batch size fixed and only use a dynamic sequence length, and it works fine.
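For reference, here is a minimal sketch of that setup: a fixed batch size with only the sequence dimension marked dynamic before calling a compiled flex_attention. The helper name, shapes, and plain causal mask below are illustrative assumptions, not taken from the repro above.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile once and reuse; dynamic=True asks the compiler to generalize over varying sizes.
compiled_flex_attention = torch.compile(flex_attention, dynamic=True)

def causal_mask(b, h, q_idx, kv_idx):
    # Simple causal masking; stands in for the block-causal mask in the repro.
    return q_idx >= kv_idx

def run_self_attention(q, k, v):
    # Hint that only the sequence dimension (dim 2) varies between calls;
    # the batch dimension is left static, matching the "fixed batch size" setup.
    for t in (q, k, v):
        torch._dynamo.mark_dynamic(t, 2)
    block_mask = create_block_mask(causal_mask, None, None, q.shape[2], k.shape[2], device=q.device)
    return compiled_flex_attention(q, k, v, block_mask=block_mask)

if __name__ == "__main__":
    B, H, D = 2, 32, 64
    for seq_len in (1024, 2048):  # same batch size, different sequence lengths
        q = torch.randn(B, H, seq_len, D, device="cuda", dtype=torch.bfloat16)
        k, v = torch.randn_like(q), torch.randn_like(q)
        out = run_self_attention(q, k, v)
        print(out.shape)
```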
Hello team!
I encountered an issue when trying to use flex attention and F.sdpa together, even during inference. For example, I use flex_attention for masked self-attention and then apply F.sdpa for cross-attention. However, this triggers

```
RuntimeError: CUDA error: an illegal memory access
```

in F.sdpa. By the way, I have to compile flex_attention, otherwise it leads to an OOM error in flex_attention, since I am using a long sequence length (about 8k).
Previously, I was using two F.sdpa calls and they worked fine, but I would like to benefit from the sparse computation in flex attention.
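For context, the pattern described above reduces to something like the following sketch: masked self-attention through a compiled flex_attention, followed by F.scaled_dot_product_attention for cross-attention. The shapes and the plain causal mask here are illustrative assumptions rather than the exact setup from the report.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_attention = torch.compile(flex_attention)  # compiled to avoid OOM at long sequence lengths

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 1, 32, 8192, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Sparse masked self-attention via flex_attention.
block_mask = create_block_mask(causal_mask, None, None, S, S, device="cuda")
x = flex_attention(q, k, v, block_mask=block_mask)

# Cross-attention over a short encoder context via F.scaled_dot_product_attention;
# this second call is where the illegal memory access was reported.
ctx_k = torch.randn(B, H, 128, D, device="cuda", dtype=torch.bfloat16)
ctx_v = torch.randn_like(ctx_k)
y = F.scaled_dot_product_attention(x, ctx_k, ctx_v)
print(y.shape)
```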