
Commit 6c23fd0

Authored by Virginia Fernandez (virginiafdez), KumoLiu, and yiheng-wang-nv
Flash attention (#7977)
Fixes #7944.

### Description

In response to Issue #7944, I added the new functionality scaled_dot_product_attention from PyTorch to re-enable flash attention, present in the original MONAI Generative Models repository. This is allowed for torch >= 2.0 and when the argument save_attn = False. Errors are raised otherwise. I ran quick tests and added some checks to the test_selfattention and test_crossattention scripts to make sure the outputs are the same as when not using flash attention.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Virginia Fernandez <[email protected]>
Co-authored-by: Virginia Fernandez <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
1 parent 56ee32e · commit 6c23fd0

7 files changed: +223 -61 lines

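Before the per-file diffs, here is a hedged usage sketch, added for this write-up and not part of the commit. It shows the new flag on the self-attention block and the constructor check described above; the module path and argument names come from the diff below, the tensor sizes are arbitrary, and PyTorch >= 2.0 is assumed.

```python
# Illustrative only -- not from the commit. Assumes PyTorch >= 2.0 and MONAI with this change applied.
import torch

from monai.networks.blocks.selfattention import SABlock

# Flash attention path: the fused scaled_dot_product_attention kernel replaces the explicit einsum.
block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.0, use_flash_attention=True)
x = torch.randn(2, 16, 128)  # (batch, sequence length, hidden_size)
print(block(x).shape)  # torch.Size([2, 16, 128])

# save_attn needs the explicit attention matrix, which the fused kernel never
# materializes, so combining it with use_flash_attention raises a ValueError.
try:
    SABlock(hidden_size=128, num_heads=4, save_attn=True, use_flash_attention=True)
except ValueError as err:
    print(err)
```
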
monai/networks/blocks/crossattention.py (+53, -20)

@@ -17,7 +17,7 @@
 import torch.nn as nn

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -44,6 +44,7 @@ def __init__(
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -55,13 +56,16 @@ def __init__(
             dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-            causal: whether to use causal attention.
-            sequence_length: if causal is True, it is necessary to specify the sequence length.
-            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
-                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
-            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
-                positional parameter size.
+            causal (bool, optional): whether to use causal attention.
+            sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
+            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
+                "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
+                parameter size.
             attention_dtype: cast attention operations to this dtype.
+            use_flash_attention: if True, use Pytorch's inbuilt
+                flash attention for a memory efficient attention mechanism (see
+                https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
         """

         super().__init__()
@@ -81,6 +85,20 @@ def __init__(
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
+            raise ValueError(
+                "use_flash_attention is only supported for PyTorch versions >= 2.0."
+                "Upgrade your PyTorch or set the flag to False."
+            )
+        if use_flash_attention and save_attn:
+            raise ValueError(
+                "save_attn has been set to True, but use_flash_attention is also set"
+                "to True. save_attn can only be used if use_flash_attention is False"
+            )
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -94,13 +112,15 @@ def __init__(
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate

         self.scale = self.head_dim**-0.5
         self.save_attn = save_attn
         self.attention_dtype = attention_dtype

         self.causal = causal
         self.sequence_length = sequence_length
+        self.use_flash_attention = use_flash_attention

         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -142,26 +162,39 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
             q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)

-        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
+        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs) #
         k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
         v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+        if self.use_flash_attention:
+            x = torch.nn.functional.scaled_dot_product_attention(
+                query=q.transpose(1, 2),
+                key=k.transpose(1, 2),
+                value=v.transpose(1, 2),
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
+            ).transpose(
+                1, 2
+            )  # Back to (b, nh, t, hs)
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+            # apply relative positional embedding if defined
+            if self.rel_positional_embedding is not None:
+                att_mat = self.rel_positional_embedding(x, att_mat, q)

-        if self.causal:
-            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+            if self.causal:
+                att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

-        att_mat = att_mat.softmax(dim=-1)
+            att_mat = att_mat.softmax(dim=-1)

-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()

-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)

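For context, a hedged usage sketch of the updated cross-attention block, not part of the diff; the sizes are arbitrary and the query and context lengths are kept equal here.

```python
# Illustrative only -- not from the commit.
import torch

from monai.networks.blocks.crossattention import CrossAttentionBlock

attn = CrossAttentionBlock(hidden_size=64, num_heads=4, dropout_rate=0.0, use_flash_attention=True)
x = torch.randn(1, 16, 64)        # queries: (batch, sequence, hidden_size)
context = torch.randn(1, 16, 64)  # keys/values are projected from this context tensor
print(attn(x, context).shape)     # torch.Size([1, 16, 64])
```
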
monai/networks/blocks/selfattention.py (+45, -13)

@@ -15,9 +15,10 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -42,6 +43,7 @@ def __init__(
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -59,6 +61,9 @@ def __init__(
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                 positional parameter size.
             attention_dtype: cast attention operations to this dtype.
+            use_flash_attention: if True, use Pytorch's inbuilt
+                flash attention for a memory efficient attention mechanism (see
+                https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

         """

@@ -82,6 +87,20 @@ def __init__(
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
+            raise ValueError(
+                "use_flash_attention is only supported for PyTorch versions >= 2.0."
+                "Upgrade your PyTorch or set the flag to False."
+            )
+        if use_flash_attention and save_attn:
+            raise ValueError(
+                "save_attn has been set to True, but use_flash_attention is also set"
+                "to True. save_attn can only be used if use_flash_attention is False."
+            )
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
@@ -91,12 +110,14 @@ def __init__(
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
         self.scale = self.dim_head**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
         self.attention_dtype = attention_dtype
         self.causal = causal
         self.sequence_length = sequence_length
+        self.use_flash_attention = use_flash_attention

         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -130,23 +151,34 @@ def forward(self, x):
             q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)

-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+        if self.use_flash_attention:
+            x = F.scaled_dot_product_attention(
+                query=q.transpose(1, 2),
+                key=k.transpose(1, 2),
+                value=v.transpose(1, 2),
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
+            ).transpose(1, 2)
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+            # apply relative positional embedding if defined
+            if self.rel_positional_embedding is not None:
+                att_mat = self.rel_positional_embedding(x, att_mat, q)

-        if self.causal:
-            att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
+            if self.causal:
+                att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

-        att_mat = att_mat.softmax(dim=-1)
+            att_mat = att_mat.softmax(dim=-1)

-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()

-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)

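In the spirit of the checks the description says were added to test_selfattention, here is a hedged sketch (not the actual test code) that runs the same weights through the einsum path and the flash path with dropout disabled and prints whether the two outputs agree.

```python
# Illustrative only -- not the unit test from the commit.
import torch

from monai.networks.blocks.selfattention import SABlock

torch.manual_seed(0)
base = SABlock(hidden_size=96, num_heads=3, dropout_rate=0.0)
flash = SABlock(hidden_size=96, num_heads=3, dropout_rate=0.0, use_flash_attention=True)
flash.load_state_dict(base.state_dict())  # identical projection weights in both blocks

x = torch.randn(2, 10, 96)
with torch.no_grad():
    out_base = base(x)    # explicit einsum attention
    out_flash = flash(x)  # fused scaled_dot_product_attention
print(out_base.shape, out_flash.shape)
print(torch.allclose(out_base, out_flash, atol=1e-5))
```
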
monai/networks/blocks/spatialattention.py (+7, -1)

@@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
         num_channels: number of input channels. Must be divisible by num_head_channels.
         num_head_channels: number of channels per head.
         attention_dtype: cast attention operations to this dtype.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.

     """

@@ -44,6 +45,7 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -54,7 +56,11 @@ def __init__(
             raise ValueError("num_channels must be divisible by num_head_channels")
         num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
         self.attn = SABlock(
-            hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+            hidden_size=num_channels,
+            num_heads=num_heads,
+            qkv_bias=True,
+            attention_dtype=attention_dtype,
+            use_flash_attention=use_flash_attention,
         )

     def forward(self, x: torch.Tensor):

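A hedged sketch of the updated spatial attention block; `spatial_dims`, `num_channels` and `num_head_channels` are the block's pre-existing arguments (assumed here from the docstring above), and only `use_flash_attention` is new in this commit.

```python
# Illustrative only -- argument names other than use_flash_attention are assumed pre-existing API.
import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

block = SpatialAttentionBlock(
    spatial_dims=2, num_channels=64, num_head_channels=32, use_flash_attention=True
)
x = torch.randn(1, 64, 32, 32)  # (batch, channels, H, W)
print(block(x).shape)           # spatial attention preserves the input shape
```
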
monai/networks/blocks/transformerblock.py (+11, -2)

@@ -36,15 +36,18 @@ def __init__(
         causal: bool = False,
         sequence_length: int | None = None,
         with_cross_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
             hidden_size (int): dimension of hidden layer.
             mlp_dim (int): dimension of feedforward layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
-            qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
+            qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

         """

@@ -66,13 +69,19 @@ def __init__(
             save_attn=save_attn,
             causal=causal,
             sequence_length=sequence_length,
+            use_flash_attention=use_flash_attention,
         )
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention

         self.norm_cross_attn = nn.LayerNorm(hidden_size)
         self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            dropout_rate=dropout_rate,
+            qkv_bias=qkv_bias,
+            causal=False,
+            use_flash_attention=use_flash_attention,
         )

     def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:

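A hedged sketch showing the transformer block forwarding the flag to both its self-attention and cross-attention sub-blocks; not part of the commit, sizes are arbitrary, and the context length is kept equal to the input length.

```python
# Illustrative only -- not from the commit.
import torch

from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(
    hidden_size=128,
    mlp_dim=256,
    num_heads=4,
    dropout_rate=0.0,
    with_cross_attention=True,
    use_flash_attention=True,  # forwarded to SABlock and CrossAttentionBlock
)
x = torch.randn(2, 16, 128)
context = torch.randn(2, 16, 128)
print(blk(x, context).shape)  # torch.Size([2, 16, 128])
```
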
monai/networks/nets/diffusion_model_unet.py (+5)

@@ -66,6 +66,8 @@ class DiffusionUNetTransformerBlock(nn.Module):
         dropout: dropout probability to use.
         cross_attention_dim: size of the context vector for cross attention.
         upcast_attention: if True, upcast attention operations to full precision.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

     """

@@ -77,6 +79,7 @@ def __init__(
         dropout: float = 0.0,
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.attn1 = SABlock(
@@ -86,6 +89,7 @@ def __init__(
             dim_head=num_head_channels,
             dropout_rate=dropout,
             attention_dtype=torch.float if upcast_attention else None,
+            use_flash_attention=use_flash_attention,
         )
         self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttentionBlock(
@@ -96,6 +100,7 @@ def __init__(
             dim_head=num_head_channels,
             dropout_rate=dropout,
             attention_dtype=torch.float if upcast_attention else None,
+            use_flash_attention=use_flash_attention,
         )
         self.norm1 = nn.LayerNorm(num_channels)
         self.norm2 = nn.LayerNorm(num_channels)

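Finally, a hedged sketch of the flag reaching the diffusion U-Net's transformer block. The first three argument names (`num_channels`, `num_attention_heads`, `num_head_channels`) are not shown in the diff above and are assumed from the MONAI Generative Models version of this block.

```python
# Illustrative only -- num_channels / num_attention_heads / num_head_channels are assumed names.
import torch

from monai.networks.nets.diffusion_model_unet import DiffusionUNetTransformerBlock

blk = DiffusionUNetTransformerBlock(
    num_channels=64,
    num_attention_heads=2,
    num_head_channels=32,
    use_flash_attention=True,  # passed on to the internal SABlock and CrossAttentionBlock
)
x = torch.randn(2, 16, 64)  # (batch, tokens, channels)
print(blk(x).shape)
```
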