
Commit 674eab0

add blockmask
1 parent 6fd1723 commit 674eab0

File tree

7 files changed: +184, -7 lines

paddle/phi/backends/dynload/flashmaskv2.h

Lines changed: 4 additions & 0 deletions

@@ -225,6 +225,10 @@ FLASHMASK_V2_HANDLE_ROUTINE(ut_start_ptr)
 FLASHMASK_V2_HANDLE_ROUTINE(ut_end_ptr)
 FLASHMASK_V2_HANDLE_ROUTINE(flashmask_maxmin_ptr)
 
+FLASHMASK_V2_HANDLE_ROUTINE(m_block_dim)
+FLASHMASK_V2_HANDLE_ROUTINE(n_block_dim)
+FLASHMASK_V2_HANDLE_ROUTINE(block_mask_ptr)
+
 #define FLASHMASK_V2_BWD_HANDLE_ROUTINE(type, member)                          \
   DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_get_##member); \
   DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_set_##member);

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 61 additions & 0 deletions

@@ -850,6 +850,8 @@ void FlashMaskV2GradBaseKernel(
     &seqused_k_,  // b. If given, only this many elements of each batch
                   // element's keys are used.
     const paddle::optional<DenseTensor> &startend_row_indices_,
+    const paddle::optional<DenseTensor>
+        &block_mask_indices_,  // (b, h, s//128, s//128)
     int max_seqlen_q_,
     int max_seqlen_k_,
     float const softmax_scale,

@@ -1080,6 +1082,50 @@ void FlashMaskV2GradBaseKernel(
     }
   }
 
+  bool const is_blockmask = block_mask_indices_.is_initialized();
+  DenseTensor block_mask_indices;
+  if (is_blockmask) block_mask_indices = block_mask_indices_.get();
+
+  if (is_blockmask) {
+    PADDLE_ENFORCE_EQ(
+        is_flashmask,
+        true,
+        common::errors::InvalidArgument(
+            "blockmask must be used together with flashmask"));
+
+    PADDLE_ENFORCE_EQ(block_mask_indices.dims().size(),
+                      4,
+                      common::errors::InvalidArgument(
+                          "blockmask expects block_mask_indices with dims "
+                          "[batch_size, num_heads, blocklen_q, blocklen_k]"));
+
+    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[2],
+                      (seqlen_q + 127) / 128,
+                      common::errors::InvalidArgument(
+                          "blockmask only supports blockdim_q = 128 now"));
+
+    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[3],
+                      (seqlen_k + 127) / 128,
+                      common::errors::InvalidArgument(
+                          "blockmask only supports blockdim_k = 128 now"));
+
+    PADDLE_ENFORCE_EQ(
+        block_mask_indices.dims()[1],
+        startend_row_indices.dims()[1],
+        common::errors::InvalidArgument(
+            "blockmask num_heads must match flashmask num_heads"));
+
+    PADDLE_ENFORCE_LE(seqlen_k,
+                      1024 * 128,
+                      common::errors::InvalidArgument(
+                          "blockmask only supports seqlen <= 128k in bwd now"));
+
+    PADDLE_ENFORCE_LE(seqlen_q,
+                      1024 * 128,
+                      common::errors::InvalidArgument(
+                          "blockmask only supports seqlen <= 128k in bwd now"));
+  }
+
   const bool has_lt_start = lt_start_row_indices.initialized();
   const bool has_lt_end = lt_end_row_indices.initialized();
   const bool has_ut_start = ut_start_row_indices.initialized();

@@ -1284,6 +1330,7 @@ void FlashMaskV2GradBaseKernel(
   if (softmax_lse_log2) {
     dev_ctx.template Alloc<float>(softmax_lse_log2);
   }
+  // std::cout << "dq_accum:" << dq_accum->dims() << std::endl;
   if (dq_accum) {
     if (!is_varlen) {
       dq_accum->Resize(common::make_ddim(

@@ -1292,6 +1339,7 @@ void FlashMaskV2GradBaseKernel(
       dq_accum->Resize(common::make_ddim(
           {num_heads, total_q_padded_rounded * head_size_rounded}));
     }
+    // std::cout << "enter:" << dq_accum->dims() << std::endl;
     dev_ctx.template Alloc<float>(dq_accum);
   }
   if (num_heads_k != num_heads) {  // MQA / GQA

@@ -1457,6 +1505,17 @@ void FlashMaskV2GradBaseKernel(
     dynload::flashmaskv2_bwd_params_set_h_h_flashmask_ratio(params_handle, 0);
   }
 
+  if (is_blockmask) {
+    // xhy: blockmask currently only supports blockdim_q = blockdim_k = 128
+    dynload::flashmaskv2_bwd_params_set_m_block_dim(params_handle, 128);
+    dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
+    dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
+        params_handle,
+        const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
+    // phi::funcs::SetConstant<Context, T> set_dq_zero;
+    // dev_ctx.template Alloc<T>(dq);
+    // set_dq_zero(dev_ctx, dq, T{0});
+  }
 #ifdef FLASHATTENTION_DISABLE_LOCAL
   PADDLE_ENFORCE_EQ(
       !dynload::flashmaskv2_bwd_params_get_is_local(params_handle),

@@ -1504,6 +1563,7 @@ void FlashMaskV2GradKernel(
     const DenseTensor &out,
     const DenseTensor &softmax_lse,
     const DenseTensor &startend_row_indices,  // TODO(xiehaoyang): remove this
+    const paddle::optional<DenseTensor> &block_mask_indices,
    const DenseTensor &out_grad,
    float const softmax_scale,
    bool is_causal,

@@ -1540,6 +1600,7 @@ void FlashMaskV2GradKernel(
      paddle::none,
      paddle::none,
      startend_row_indices,
+     block_mask_indices,
      0,  // max_seqlen_q,
      0,  // max_seqlen_k,
      softmax_scale,
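
Note: the shape contract these PADDLE_ENFORCE checks encode can be sketched in Python for reference. This is a minimal illustration; the helper names below are hypothetical and not part of the Paddle API.

    # Hypothetical sketch of the block-mask shape contract enforced above.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def expected_block_mask_shape(batch_size: int, num_heads: int,
                                  seqlen_q: int, seqlen_k: int,
                                  block_dim: int = 128) -> list:
        # Only blockdim_q = blockdim_k = 128 is supported, num_heads must match
        # the flashmask num_heads, and in the backward pass both sequence
        # lengths must not exceed 128 * 1024 (128k).
        assert seqlen_q <= 1024 * 128 and seqlen_k <= 1024 * 128
        return [batch_size, num_heads,
                ceil_div(seqlen_q, block_dim), ceil_div(seqlen_k, block_dim)]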

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 52 additions & 2 deletions

@@ -1235,7 +1235,10 @@ void FlashMaskV2BaseKernel(
     const paddle::optional<DenseTensor> &k_descale_,  // (b, h_k)
     const paddle::optional<DenseTensor> &v_descale_,  // (b, h_k)
     const paddle::optional<DenseTensor> &scheduler_metadata_,  // (b + 1)
-    const paddle::optional<DenseTensor> &startend_row_indices_,
+    const paddle::optional<DenseTensor>
+        &startend_row_indices_,  // (b, h, s_1, [1,2,4])
+    const paddle::optional<DenseTensor>
+        &block_mask_indices_,  // (b, h, s//128, s//128)
     const int
         max_seqlen_q_,  // if max_seqlen_q_ is set to 0, it indicates that it is
                         // uninitialized and should not be referenced

@@ -1432,6 +1435,7 @@ void FlashMaskV2BaseKernel(
   }
 
   bool const is_flashmask = startend_row_indices_.is_initialized();
+  bool const is_blockmask = block_mask_indices_.is_initialized();
 
   // This needs to go before kBlockM & kBlockN since we rely on the correct
   // window_size and is_causal to set kBlockM

@@ -2068,6 +2072,8 @@ void FlashMaskV2BaseKernel(
   // flashmask
   DenseTensor startend_row_indices;
   if (is_flashmask) startend_row_indices = startend_row_indices_.get();
+  DenseTensor block_mask_indices;
+  if (is_blockmask) block_mask_indices = block_mask_indices_.get();
   DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
       ut_start_row_indices, ut_end_row_indices;
   if (is_flashmask) {

@@ -2142,6 +2148,45 @@ void FlashMaskV2BaseKernel(
     }
   }
 
+  if (is_blockmask) {
+    PADDLE_ENFORCE_EQ(
+        is_flashmask,
+        true,
+        common::errors::InvalidArgument(
+            "blockmask must be used together with flashmask"));
+
+    PADDLE_ENFORCE_EQ(block_mask_indices.dims().size(),
+                      4,
+                      common::errors::InvalidArgument(
+                          "blockmask expects block_mask_indices with dims "
+                          "[batch_size, num_heads, blocklen_q, blocklen_k]"));
+
+    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[2],
+                      (seqlen_q + 127) / 128,
+                      common::errors::InvalidArgument(
+                          "blockmask only supports blockdim_q = 128 now"));
+
+    PADDLE_ENFORCE_EQ(block_mask_indices.dims()[3],
+                      (seqlen_k + 127) / 128,
+                      common::errors::InvalidArgument(
+                          "blockmask only supports blockdim_k = 128 now"));
+
+    PADDLE_ENFORCE_EQ(
+        block_mask_indices.dims()[1],
+        startend_row_indices.dims()[1],
+        common::errors::InvalidArgument(
+            "blockmask num_heads must match flashmask num_heads"));
+  }
+
+  if (is_blockmask) {
+    // xhy: blockmask currently only supports blockdim_q = blockdim_k = 128
+    dynload::flashmaskv2_fwd_params_set_m_block_dim(params_handle, 128);
+    dynload::flashmaskv2_fwd_params_set_n_block_dim(params_handle, 128);
+    dynload::flashmaskv2_fwd_params_set_block_mask_ptr(
+        params_handle,
+        const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
+  }
+
   if (is_flashmask) {
     if (lt_start_row_indices.initialized())
       dynload::flashmaskv2_fwd_params_set_lt_start_ptr(

@@ -2260,6 +2305,7 @@ void FlashMaskV2Kernel(const Context &dev_ctx,
     const DenseTensor &k,
     const DenseTensor &v,
     const DenseTensor &startend_row_indices,
+    const paddle::optional<DenseTensor> &block_mask_indices,
     const float softmax_scale,
     bool is_causal,
     DenseTensor *out,

@@ -2290,6 +2336,7 @@ void FlashMaskV2Kernel(const Context &dev_ctx,
       paddle::none,  // v_descale_
       paddle::none,  // scheduler_metadata_
       startend_row_indices,
+      block_mask_indices,
      0,  // max_seqlen_q_
      0,  // max_seqlen_k_
      softmax_scale,

@@ -2333,4 +2380,7 @@ PD_REGISTER_KERNEL(flashmask_attention_v2,
                    ALL_LAYOUT,
                    phi::FlashMaskV2Kernel,
                    phi::float16,
-                   phi::bfloat16) {}
+                   phi::bfloat16) {
+  kernel->InputAt(4).SetBackend(
+      phi::Backend::ALL_BACKEND);  // block_mask_indices
+}
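
Note: semantically, each int32 entry of block_mask_indices gates one 128 x 128 tile of the attention score matrix (1 = keep, 0 = mask, per the Python docs below). A minimal NumPy sketch of that expansion, under those assumptions:

    import numpy as np

    def expand_block_mask(block_mask: np.ndarray, seqlen_q: int,
                          seqlen_k: int, block: int = 128) -> np.ndarray:
        # block_mask: int32, shape [b, h, ceil(sq/128), ceil(sk/128)].
        # Repeat each tile flag over a 128 x 128 patch, then crop to the true
        # sequence lengths; True means the score element is kept.
        elem = np.repeat(np.repeat(block_mask, block, axis=2), block, axis=3)
        return elem[:, :, :seqlen_q, :seqlen_k].astype(bool)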

paddle/phi/ops/yaml/backward.yaml

Lines changed: 3 additions & 2 deletions

@@ -1228,8 +1228,9 @@
     data_type: q
 
 - backward_op : flashmask_attention_v2_grad
-  forward : flashmask_attention_v2 (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, float softmax_scale, bool is_causal) -> Tensor(out), Tensor(softmax_lse)
-  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor startend_row_indices, Tensor out_grad, float softmax_scale, bool is_causal)
+  forward : flashmask_attention_v2 (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, Tensor block_mask_indices, float softmax_scale, bool is_causal) -> Tensor(out), Tensor(softmax_lse)
+  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor startend_row_indices, Tensor block_mask_indices, Tensor out_grad, float softmax_scale, bool is_causal)
+  optional : block_mask_indices
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
     func : FlashAttnGradInferMeta

paddle/phi/ops/yaml/ops.yaml

Lines changed: 2 additions & 1 deletion

@@ -2153,8 +2153,9 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : flashmask_attention_v2
-  args : (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, float softmax_scale, bool is_causal)
+  args : (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, Tensor block_mask_indices, float softmax_scale, bool is_causal)
   output : Tensor(out), Tensor(softmax_lse)
+  optional : block_mask_indices
   infer_meta :
     func : FlashMaskV2InferMeta
     param : [q, k, v]

python/paddle/nn/functional/flash_attention.py

Lines changed: 61 additions & 1 deletion

@@ -1575,6 +1575,7 @@ def flashmask_attention(
     training: bool = True,
     name: str | None = None,
     softmax_scale: float | None = None,
+    block_mask_indices: Tensor | None = None,
 ):
     r"""
     FlashMask: Official Implementation

@@ -1635,6 +1636,26 @@ def flashmask_attention(
         training (bool): Whether the module is in training mode. Default is True.
         name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property.
             For more information, refer to :ref:`api_guide_Name` .
+        block_mask_indices (Tensor, optional):
+            A 4-D integer mask tensor indicating whether each block of the attention matrix is kept or masked. Must be used together with flashmask.
+            The shape must be [batch_size, num_heads, blocklen_q, blocklen_k], where:
+
+            blocklen_q = ceil(seqlen_q / 128), i.e. block_mask_indices.shape[2] must be (seqlen_q + 127) // 128.
+            blocklen_k = ceil(seqlen_k / 128), i.e. block_mask_indices.shape[3] must be (seqlen_k + 127) // 128.
+            block_mask_indices.shape[1] (number of heads) must match the num_heads dimension of the flashmask.
+            Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024.
+            The dtype must be int32, and each element must be either 0 or 1.
+            A value of 1 keeps the corresponding block (not masked); 0 masks it.
+
+            Usage notes:
+
+            Only blockdim_q = blockdim_k = 128 is supported for now.
+            Only headdim = 128 is supported for now.
+            This argument must be provided together with flashmask.
+            The mask is applied at the block level: each [i, j] entry of block_mask_indices controls whether the corresponding 128 x 128 block of the attention matrix is masked.
+            Any mismatch in the expected shape or head dimension raises an error.
+
 
     Returns
         Tensor. The computed attention result with the same shape as the input `query`.

@@ -2207,6 +2228,12 @@ def flashmask_attention(
                 startend_row_indices, min=0, max=sq
             ).repeat_interleave(bsz, 0)
 
+    if block_mask_indices is not None:
+        # xhy: could a dense (all-kept) startend_row_indices be generated
+        # automatically when only block_mask_indices is given?
+        assert startend_row_indices is not None, (
+            "must provide startend_row_indices when using block_mask_attn"
+        )
+
     if startend_row_indices is None:
         (
             out,

@@ -2248,6 +2275,33 @@ def flashmask_attention(
                 "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
             )
 
+        if block_mask_indices is not None:
+            assert block_mask_indices.dtype == paddle.int32, (
+                f"block_mask_indices.dtype must be paddle.int32, but got {block_mask_indices.dtype}"
+            )
+
+            assert block_mask_indices.shape[0] == key.shape[0], (
+                f"block_mask_indices.shape[0] must be equal to batch_size, but got {block_mask_indices.shape[0]} and {key.shape[0]}"
+            )
+
+            assert (
+                block_mask_indices.shape[1] == startend_row_indices.shape[1]
+            ), (
+                f"block_mask_indices.shape[1] must be equal to startend_row_indices.shape[1], but got {block_mask_indices.shape[1]} and {startend_row_indices.shape[1]}"
+            )
+
+            assert (
+                block_mask_indices.shape[2] == (query.shape[1] + 127) // 128
+            ), "block_size must be 128 when using block_mask_attn"
+
+            assert block_mask_indices.shape[3] == (key.shape[1] + 127) // 128, (
+                "block_size must be 128 when using block_mask_attn"
+            )
+
+            assert key.shape[3] == 128, (
+                "headdim must be 128 when using block_mask_attn"
+            )
+
     if causal:
         if startend_row_indices.shape[-1] == 1:
             has_end = False

@@ -2329,7 +2383,13 @@ def flashmask_attention(
                 out,
                 result_softmax_lse,
             ) = _C_ops.flashmask_attention_v2(
-                query, key, value, startend_row_indices, softmax_scale, causal
+                query,
+                key,
+                value,
+                startend_row_indices,
+                block_mask_indices,
+                softmax_scale,
+                causal,
             )
         else:
             raise ValueError(f"Invalid flash attention version: {fa_version}")
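
Note: a minimal end-to-end usage sketch of the new argument, assuming a GPU build where the flashmask_attention_v2 code path that accepts block_mask_indices is selected; shapes and the all-kept mask are illustrative only.

    import paddle
    from paddle.nn.functional import flashmask_attention

    bsz, seqlen, num_heads, head_dim = 2, 1024, 8, 128  # headdim must be 128
    q = paddle.randn([bsz, seqlen, num_heads, head_dim], dtype="float16")
    k = paddle.randn([bsz, seqlen, num_heads, head_dim], dtype="float16")
    v = paddle.randn([bsz, seqlen, num_heads, head_dim], dtype="float16")

    # Causal flashmask: LT-start indices set to seqlen mask nothing beyond
    # the causal diagonal.
    startend_row_indices = paddle.full(
        [bsz, num_heads, seqlen, 1], seqlen, dtype="int32")

    # One int32 flag per 128 x 128 tile: 1 = keep, 0 = mask.
    blocklen = (seqlen + 127) // 128
    block_mask_indices = paddle.ones(
        [bsz, num_heads, blocklen, blocklen], dtype="int32")

    out = flashmask_attention(
        q, k, v,
        startend_row_indices=startend_row_indices,
        causal=True,
        block_mask_indices=block_mask_indices,
    )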
