
Commit 34ec177
fix description
1 parent cced066

6 files changed: +61, -66 lines

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

12 additions, 13 deletions

@@ -850,8 +850,7 @@ void FlashMaskV2GradBaseKernel(
     &seqused_k_,  // b. If given, only this many elements of each batch
                   // element's keys are used.
     const paddle::optional<DenseTensor> &startend_row_indices_,
-    const paddle::optional<DenseTensor>
-        &block_mask_indices_,  // ((b,h,s//128,s//128)
+    const paddle::optional<DenseTensor> &block_mask_,  // ((b,h,s//128,s//128)
     int max_seqlen_q_,
     int max_seqlen_k_,
     float const softmax_scale,
@@ -1082,9 +1081,9 @@ void FlashMaskV2GradBaseKernel(
     }
   }

-  bool const is_blockmask = block_mask_indices_.is_initialized();
-  DenseTensor block_mask_indices;
-  if (is_blockmask) block_mask_indices = block_mask_indices_.get();
+  bool const is_blockmask = block_mask_.is_initialized();
+  DenseTensor block_mask;
+  if (is_blockmask) block_mask = block_mask_.get();

   if (is_blockmask) {
     PADDLE_ENFORCE_EQ(
@@ -1093,24 +1092,24 @@ void FlashMaskV2GradBaseKernel(
         common::errors::InvalidArgument(
             "blockmask should be used with flashmask at the same time "));

-    PADDLE_ENFORCE_EQ(block_mask_indices.dims().size(),
+    PADDLE_ENFORCE_EQ(block_mask.dims().size(),
                       4,
                       common::errors::InvalidArgument(
                           "blockmask receive blockmask_indices with dim "
                           "[batch_size, num_heads, blocklen_q, blocklen_k]"));

-    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[2],
+    PADDLE_ENFORCE_EQ(block_mask.dims()[2],
                       (seqlen_q + 127) / 128,
                       common::errors::InvalidArgument(
                           "blockmask only supports blockdim_q = 128 now"));

-    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[3],
+    PADDLE_ENFORCE_EQ(block_mask.dims()[3],
                       (seqlen_k + 127) / 128,
                       common::errors::InvalidArgument(
                           "blockmask only supports blockdim_k = 128 now"));

     PADDLE_ENFORCE_EQ(
-        block_mask_indices.dims()[1],
+        block_mask.dims()[1],
         startend_row_indices.dims()[1],
         common::errors::InvalidArgument(
             "blockmask only supports same dim num_heads with flashmask now"));
@@ -1503,8 +1502,8 @@ void FlashMaskV2GradBaseKernel(
     dynload::flashmaskv2_bwd_params_set_m_block_dim(params_handle, 128);
     dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
     dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
-        params_handle, (block_mask_indices.data<int32_t>()));
-    auto ptr = block_mask_indices.data<int32_t>();
+        params_handle, (block_mask.data<int32_t>()));
+    auto ptr = block_mask.data<int32_t>();
     std::cout << typeid(ptr).name() << std::endl;
   }
 #ifdef FLASHATTENTION_DISABLE_LOCAL
@@ -1554,7 +1553,7 @@ void FlashMaskV2GradKernel(
     const DenseTensor &out,
     const DenseTensor &softmax_lse,
     const DenseTensor &startend_row_indices,  // TODO(xiehaoyang): remove this
-    const paddle::optional<DenseTensor> &block_mask_indices,
+    const paddle::optional<DenseTensor> &block_mask,
     const DenseTensor &out_grad,
     float const softmax_scale,
     bool is_causal,
@@ -1591,7 +1590,7 @@ void FlashMaskV2GradKernel(
       paddle::none,
       paddle::none,
       startend_row_indices,
-      block_mask_indices,
+      block_mask,
       0,  // max_seqlen_q,
       0,  // max_seqlen_k,
       softmax_scale,
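
The four PADDLE_ENFORCE_EQ checks above pin down the shape that the renamed block_mask tensor must have. The same constraint can be written out as a small Python helper; the function and argument names below are ours, purely for illustration, and are not part of this commit:

    def expected_block_mask_shape(batch_size, num_heads, seqlen_q, seqlen_k):
        # Only blockdim_q = blockdim_k = 128 is supported, so the block grid
        # is the ceiling division of each sequence length by 128.
        blocklen_q = (seqlen_q + 127) // 128
        blocklen_k = (seqlen_k + 127) // 128
        # block_mask.dims()[1] must additionally equal
        # startend_row_indices.dims()[1], i.e. the flashmask's num_heads.
        return [batch_size, num_heads, blocklen_q, blocklen_k]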

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

12 additions, 13 deletions

@@ -1241,7 +1241,7 @@ void FlashMaskV2BaseKernel(
     const paddle::optional<DenseTensor>
         &startend_row_indices_,  // (b,h,s_1,[1,2,4])
     const paddle::optional<DenseTensor>
-        &block_mask_indices_,  // ((b,h,s// 128,s // 128)
+        &block_mask_,  // ((b,h,s// 128,s // 128)
     const int
         max_seqlen_q_,  // if max_seqlen_q_ is set to 0, it indicates that it is
                         // uninitialized and should not be referenced
@@ -1441,7 +1441,7 @@ void FlashMaskV2BaseKernel(
   }

   bool const is_flashmask = startend_row_indices_.is_initialized();
-  bool const is_blockmask = block_mask_indices_.is_initialized();
+  bool const is_blockmask = block_mask_.is_initialized();

   // This needs to go before kBlockM & kBlockN since we rely on the correct
   // window_size and is_causal to set kBlockM
@@ -2075,8 +2075,8 @@ void FlashMaskV2BaseKernel(
   // flashmask
   DenseTensor startend_row_indices;
   if (is_flashmask) startend_row_indices = startend_row_indices_.get();
-  DenseTensor block_mask_indices;
-  if (is_blockmask) block_mask_indices = block_mask_indices_.get();
+  DenseTensor block_mask;
+  if (is_blockmask) block_mask = block_mask_.get();
   DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
       ut_start_row_indices, ut_end_row_indices;
   if (is_flashmask) {
@@ -2158,24 +2158,24 @@ void FlashMaskV2BaseKernel(
         common::errors::InvalidArgument(
             "blockmask should be used with flashmask at the same time "));

-    PADDLE_ENFORCE_EQ(block_mask_indices.dims().size(),
+    PADDLE_ENFORCE_EQ(block_mask.dims().size(),
                       4,
                       common::errors::InvalidArgument(
                           "blockmask receive blockmask_indices with dim "
                           "[batch_size, num_heads, blocklen_q, blocklen_k]"));

-    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[2],
+    PADDLE_ENFORCE_EQ(block_mask.dims()[2],
                       (seqlen_q + 127) / 128,
                       common::errors::InvalidArgument(
                           "blockmask is now only support blockdim_q = 128 "));

-    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[3],
+    PADDLE_ENFORCE_EQ(block_mask.dims()[3],
                       (seqlen_k + 127) / 128,
                       common::errors::InvalidArgument(
                           "blockmask is now only support blockdim_k = 128 "));

     PADDLE_ENFORCE_EQ(
-        block_mask_indices.dims()[1],
+        block_mask.dims()[1],
         startend_row_indices.dims()[1],
         common::errors::InvalidArgument("blockmask is now only support same "
                                         "dim num_heads with flashmask "));
@@ -2186,7 +2186,7 @@ void FlashMaskV2BaseKernel(
     dynload::flashmaskv2_fwd_params_set_m_block_dim(params_handle, 128);
     dynload::flashmaskv2_fwd_params_set_n_block_dim(params_handle, 128);
     dynload::flashmaskv2_fwd_params_set_block_mask_ptr(
-        params_handle, (block_mask_indices.data<int32_t>()));
+        params_handle, (block_mask.data<int32_t>()));
   }

   if (is_flashmask) {
@@ -2302,7 +2302,7 @@ void FlashMaskV2Kernel(const Context &dev_ctx,
     const DenseTensor &k,
     const DenseTensor &v,
     const DenseTensor &startend_row_indices,
-    const paddle::optional<DenseTensor> &block_mask_indices,
+    const paddle::optional<DenseTensor> &block_mask,
     const float softmax_scale,
     bool is_causal,
     DenseTensor *out,
@@ -2333,7 +2333,7 @@ void FlashMaskV2Kernel(const Context &dev_ctx,
       paddle::none,  // v_descale_
       paddle::none,  // scheduler_metadata_
       startend_row_indices,
-      block_mask_indices,
+      block_mask,
       0,  // max_seqlen_q_
       0,  // max_seqlen_k_
       softmax_scale,
@@ -2378,6 +2378,5 @@
                    phi::FlashMaskV2Kernel,
                    phi::float16,
                    phi::bfloat16) {
-  kernel->InputAt(4).SetBackend(
-      phi::Backend::ALL_BACKEND);  // block_mask_indices
+  kernel->InputAt(4).SetBackend(phi::Backend::ALL_BACKEND);  // block_mask
 }
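
The forward kernel consumes block_mask as raw int32 data over a grid of 128 x 128 tiles. If a caller starts from a dense boolean keep-pattern at token granularity, one way to derive a block-level mask of the expected shape is to reduce each tile with "any". This is a sketch under our own assumptions (the helper name, the any-reduction rule, and the padding scheme are not part of this commit):

    import paddle

    def dense_to_block_mask(dense_keep, block=128):
        # dense_keep: bool tensor [batch, heads, seqlen_q, seqlen_k],
        # True where attention between a (query, key) pair is allowed.
        b, h, sq, sk = dense_keep.shape
        bq, bk = (sq + block - 1) // block, (sk + block - 1) // block
        # Pad up to full blocks, then view the mask as a grid of 128 x 128 tiles.
        padded = paddle.zeros([b, h, bq * block, bk * block], dtype='bool')
        padded[:, :, :sq, :sk] = dense_keep
        tiles = padded.reshape([b, h, bq, block, bk, block])
        # Keep a block (value 1) if any position inside it is kept.
        return tiles.any(axis=[3, 5]).astype('int32')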

paddle/phi/ops/yaml/backward.yaml

3 additions, 3 deletions

@@ -1228,9 +1228,9 @@
     data_type: q

 - backward_op : flashmask_attention_v2_grad
-  forward : flashmask_attention_v2 (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices,Tensor block_mask_indices, float softmax_scale, bool is_causal) -> Tensor(out), Tensor(softmax_lse)
-  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor startend_row_indices, Tensor block_mask_indices, Tensor out_grad, float softmax_scale, bool is_causal)
-  optional : block_mask_indices
+  forward : flashmask_attention_v2 (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices,Tensor block_mask, float softmax_scale, bool is_causal) -> Tensor(out), Tensor(softmax_lse)
+  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor startend_row_indices, Tensor block_mask, Tensor out_grad, float softmax_scale, bool is_causal)
+  optional : block_mask
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
     func : FlashAttnGradInferMeta

paddle/phi/ops/yaml/ops.yaml

2 additions, 2 deletions

@@ -2153,9 +2153,9 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface

 - op : flashmask_attention_v2
-  args : (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, Tensor block_mask_indices, float softmax_scale, bool is_causal)
+  args : (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, Tensor block_mask, float softmax_scale, bool is_causal)
   output : Tensor(out), Tensor(softmax_lse)
-  optional : block_mask_indices
+  optional : block_mask
   infer_meta :
     func : FlashMaskV2InferMeta
   param : [q, k, v]
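
Because block_mask is declared optional in this op signature, callers can pass None when only the flashmask is used. The internal call in flash_attention.py (last hunk of that file below) follows the positional order given here; schematically it looks like the sketch below, where the _C_ops module path is assumed from Paddle's usual codegen layout and is not shown in this diff:

    from paddle import _C_ops

    # block_mask may be None; the remaining arguments follow the ops.yaml order.
    out, softmax_lse = _C_ops.flashmask_attention_v2(
        q, k, v, startend_row_indices, block_mask, softmax_scale, is_causal
    )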

python/paddle/nn/functional/flash_attention.py

31 additions, 34 deletions

@@ -1575,7 +1575,7 @@ def flashmask_attention(
     training: bool = True,
     name: str | None = None,
     softmax_scale: float | None = None,
-    block_mask_indices: Tensor | None = None,
+    block_mask: Tensor | None = None,
 ):
     r"""
     FlashMask: Official Implementation
@@ -1636,25 +1636,24 @@ def flashmask_attention(
         training (bool): Whether the module is in training mode. Default is True.
         name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property.
             For more information, refer to :ref:`api_guide_Name` .
-        block_mask_indices (tensor, optional):
-        block_mask_indices (Tensor, optional):
-            A 4-D integer mask tensor indicating whether each block in the attention matrix should be kept or masked. Must be used together with flashmask.
-            The shape should be [batch_size, num_heads, blocklen_q, blocklen_k], where:
+        block_mask (tensor, optional):
+            A 4-D integer mask tensor indicating whether each block in the attention matrix should be kept or masked. Must be used together with flashmask.
+            The shape should be [batch_size, num_heads, blocklen_q, blocklen_k], where:

-            blocklen_q = ceil(seqlen_q / 128), i.e., block_mask_indices.shape[2] must be (seqlen_q + 127) // 128
-            blocklen_k = ceil(seqlen_k / 128), i.e., block_mask_indices.shape[3] must be (seqlen_k + 127) // 128
-            block_mask_indices.shape[1] (number of heads) must match the num_heads dimension of the flashmask
-            Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024
-            The dtype should be int32, and each element should be either 0 or 1.
-            A value of 1 indicates that the corresponding block is kept (not masked), while 0 means the block is masked.
+            blocklen_q = ceil(seqlen_q / 128), i.e., block_mask.shape[2] must be (seqlen_q + 127) // 128
+            blocklen_k = ceil(seqlen_k / 128), i.e., block_mask.shape[3] must be (seqlen_k + 127) // 128
+            block_mask.shape[1] (number of heads) must match the num_heads dimension of the flashmask
+            Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024
+            The dtype should be int32, and each element should be either 0 or 1.
+            A value of 1 indicates that the corresponding block is kept (not masked), while 0 means the block is masked.

-            Usage Notes:
+            Usage Notes:

-            Only supported when blockdim_q = blockdim_k = 128 now.
-            Only supported when headdim = 128 now.
-            This argument must be provided together with flashmask.
-            The mask will be applied at the block level: each [i, j] position in block_mask_indices controls whether the corresponding [128 x 128] block in the attention matrix is masked.
-            Any mismatch in expected shape or head dimension will raise an error.
+            Only supported when blockdim_q = blockdim_k = 128 now.
+            Only supported when headdim = 128 now.
+            This argument must be provided together with flashmask.
+            The mask will be applied at the block level: each [i, j] position in block_mask controls whether the corresponding [128 x 128] block in the attention matrix is masked.
+            Any mismatch in expected shape or head dimension will raise an error.

@@ -2228,7 +2227,7 @@ def flashmask_attention(
         startend_row_indices, min=0, max=sq
     ).repeat_interleave(bsz, 0)

-    if block_mask_indices is not None:
+    if block_mask is not None:
         # xhy: can set a full startend_row_indices for block_mask_attn when using block_mask_attn?
         assert startend_row_indices is not None, (
             "must provide startend_row_indices when using block_mask_attn"
@@ -2275,26 +2274,24 @@ def flashmask_attention(
             "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
         )

-    if block_mask_indices is not None:
-        assert block_mask_indices.dtype == paddle.int32, (
-            f"block_mask_indices.dtype must be paddle.int32, but got {block_mask_indices.dtype}"
+    if block_mask is not None:
+        assert block_mask.dtype == paddle.int32, (
+            f"block_mask.dtype must be paddle.int32, but got {block_mask.dtype}"
         )

-        assert block_mask_indices.shape[0] == key.shape[0], (
-            f"block_mask_indices.shape[0] must be equal to batch_size, but got {block_mask_indices.shape[0]} and {key.shape[0]}"
+        assert block_mask.shape[0] == key.shape[0], (
+            f"block_mask.shape[0] must be equal to batch_size, but got {block_mask.shape[0]} and {key.shape[0]}"
         )

-        assert (
-            block_mask_indices.shape[1] == startend_row_indices.shape[1]
-        ), (
-            f"block_mask_indices.shape[1] must be equal to startend_row_indices.shape[1], but got {block_mask_indices.shape[1]} and {key.shape[2]}"
+        assert block_mask.shape[1] == startend_row_indices.shape[1], (
+            f"block_mask.shape[1] must be equal to startend_row_indices.shape[1], but got {block_mask.shape[1]} and {key.shape[2]}"
         )

-        assert (
-            block_mask_indices.shape[2] == (query.shape[1] + 127) // 128
-        ), "block_size must be 128 when using block_mask_attn"
+        assert block_mask.shape[2] == (query.shape[1] + 127) // 128, (
+            "block_size must be 128 when using block_mask_attn"
+        )

-        assert block_mask_indices.shape[3] == (key.shape[1] + 127) // 128, (
+        assert block_mask.shape[3] == (key.shape[1] + 127) // 128, (
             "block_size must be 128 when using block_mask_attn"
         )
@@ -2326,7 +2323,7 @@ def flashmask_attention(
     elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
         "FLAGS_cudnn_deterministic"
     ]:
-        assert block_mask_indices is None, (
+        assert block_mask is None, (
            " blockmask attention no supports deterministic now ."
        )
        fa_version = 2
@@ -2340,7 +2337,7 @@ def flashmask_attention(
            "flashmask_attention does not support setting softmax_scale, use flashmask_attention_v2 instead"
        )

-        assert block_mask_indices is None, (
+        assert block_mask is None, (
            " blockmask attention only supports sm >= 90 now."
        )
@@ -2394,7 +2391,7 @@ def flashmask_attention(
            key,
            value,
            startend_row_indices,
-           block_mask_indices,
+           block_mask,
            softmax_scale,
            causal,
        )
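
Taken together with the docstring constraints above, a call through the public API with the renamed keyword could look like the sketch below. The tensor shapes, dtype, the all-ones block mask, and the full-length startend_row_indices are illustrative assumptions on our part; only the keyword names and the shape rules come from this file:

    import paddle
    import paddle.nn.functional as F

    b, sq, sk, h, d = 2, 1024, 1024, 8, 128   # headdim must be 128 for block_mask
    q = paddle.randn([b, sq, h, d], dtype='float16')
    k = paddle.randn([b, sk, h, d], dtype='float16')
    v = paddle.randn([b, sk, h, d], dtype='float16')

    # Flashmask row indices, shape [b, h, s, 1]; filling with sk here masks
    # nothing beyond the causal pattern (illustrative only).
    startend_row_indices = paddle.full([b, h, sk, 1], sk, dtype='int32')

    # Block-level mask over ceil(sq/128) x ceil(sk/128) tiles: 1 keeps the
    # corresponding 128 x 128 block, 0 masks it out.
    block_mask = paddle.ones(
        [b, h, (sq + 127) // 128, (sk + 127) // 128], dtype='int32'
    )

    out = F.flashmask_attention(
        q, k, v,
        startend_row_indices=startend_row_indices,
        causal=True,
        block_mask=block_mask,
    )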

test/legacy_test/test_flashmask.py

1 addition, 1 deletion

@@ -260,7 +260,7 @@ def test_dot_scale_product(self):
             startend_row_indices=startend_row_indices,
             dropout=self.dropout,
             causal=self.causal,
-            block_mask_indices=blockmask,
+            block_mask=blockmask,
         )
         out_ = attention_naive_with_mask(q_, k_, v_, mask)
         out.backward(ograd)
