4 changes: 4 additions & 0 deletions paddle/phi/backends/dynload/flashmaskv2.h
@@ -225,6 +225,10 @@ FLASHMASK_V2_HANDLE_ROUTINE(ut_start_ptr)
FLASHMASK_V2_HANDLE_ROUTINE(ut_end_ptr)
FLASHMASK_V2_HANDLE_ROUTINE(flashmask_maxmin_ptr)

FLASHMASK_V2_HANDLE_ROUTINE(m_block_dim)
FLASHMASK_V2_HANDLE_ROUTINE(n_block_dim)
FLASHMASK_V2_HANDLE_ROUTINE(block_mask_ptr)

#define FLASHMASK_V2_BWD_HANDLE_ROUTINE(type, member) \
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_get_##member); \
DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP(flashmaskv2_bwd_params_set_##member);
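For reference, a hypothetical sketch of what the three new entries are assumed to expose (an illustration inferred from the call sites in flash_attn_v3_kernel.cu and flash_attn_v3_grad_kernel.cu, not the actual generated declarations; the parameter types are guesses):

// Sketch only -- the real wrappers are generated by
// DECLARE_DYNAMIC_LOAD_FLASHMASK_V2_WRAP / FLASHMASK_V2_HANDLE_ROUTINE.
#include <cstdint>
void flashmaskv2_fwd_params_set_m_block_dim(void *params_handle, int m_block_dim);
void flashmaskv2_fwd_params_set_n_block_dim(void *params_handle, int n_block_dim);
void flashmaskv2_fwd_params_set_block_mask_ptr(void *params_handle, const int32_t *block_mask);
void flashmaskv2_bwd_params_set_m_block_dim(void *params_handle, int m_block_dim);
void flashmaskv2_bwd_params_set_n_block_dim(void *params_handle, int n_block_dim);
void flashmaskv2_bwd_params_set_block_mask_ptr(void *params_handle, const int32_t *block_mask);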
71 changes: 60 additions & 11 deletions paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu
@@ -850,6 +850,7 @@ void FlashMaskV2GradBaseKernel(
&seqused_k_, // b. If given, only this many elements of each batch
// element's keys are used.
const paddle::optional<DenseTensor> &startend_row_indices_,
const paddle::optional<DenseTensor> &block_mask_, // (b, h, s//128, s//128)
int max_seqlen_q_,
int max_seqlen_k_,
float const softmax_scale,
@@ -1080,6 +1081,50 @@ void FlashMaskV2GradBaseKernel(
}
}

bool const is_blockmask = block_mask_.is_initialized();
DenseTensor block_mask;
if (is_blockmask) block_mask = block_mask_.get();

if (is_blockmask) {
PADDLE_ENFORCE_EQ(
is_flashmask,
true,
common::errors::InvalidArgument(
"block_mask must be used together with flashmask "
"(startend_row_indices)."));

PADDLE_ENFORCE_EQ(block_mask.dims().size(),
4,
common::errors::InvalidArgument(
"block_mask must be a 4-D tensor with shape "
"[batch_size, num_heads, num_blocks_q, num_blocks_k]."));

PADDLE_ENFORCE_EQ(block_mask.dims()[2],
(seqlen_q + 127) / 128,
common::errors::InvalidArgument(
"block_mask currently only supports blockdim_q = 128."));

PADDLE_ENFORCE_EQ(block_mask.dims()[3],
(seqlen_k + 127) / 128,
common::errors::InvalidArgument(
"block_mask currently only supports blockdim_k = 128."));

PADDLE_ENFORCE_EQ(
block_mask.dims()[1],
startend_row_indices.dims()[1],
common::errors::InvalidArgument(
"block_mask must have the same num_heads as startend_row_indices."));

PADDLE_ENFORCE_LE(seqlen_k,
1024 * 128,
common::errors::InvalidArgument(
"block_mask only supports seqlen_k <= 128K in the backward pass."));

PADDLE_ENFORCE_LE(seqlen_q,
1024 * 128,
common::errors::InvalidArgument(
"block_mask only supports seqlen_q <= 128K in the backward pass."));
}

const bool has_lt_start = lt_start_row_indices.initialized();
const bool has_lt_end = lt_end_row_indices.initialized();
const bool has_ut_start = ut_start_row_indices.initialized();
@@ -1361,7 +1406,7 @@ void FlashMaskV2GradBaseKernel(
num_heads_k != num_heads && dk_accum ? dk_accum->data() : nullptr,
num_heads_k != num_heads && dv_accum ? dv_accum->data() : nullptr,
const_cast<void *>(softmax_lse.data()),
softmax_d ? const_cast<void *>(softmax_d->data()) : nullptr,
softmax_d ? (softmax_d->data()) : nullptr,
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
@@ -1408,36 +1453,31 @@ void FlashMaskV2GradBaseKernel(
if (is_flashmask) {
if (lt_start_row_indices.initialized())
dynload::flashmaskv2_bwd_params_set_lt_start_ptr(
params_handle,
const_cast<int32_t *>(lt_start_row_indices.data<int32_t>()));
params_handle, (lt_start_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_bwd_params_set_lt_start_ptr(params_handle, nullptr);

if (lt_end_row_indices.initialized())
dynload::flashmaskv2_bwd_params_set_lt_end_ptr(
params_handle,
const_cast<int32_t *>(lt_end_row_indices.data<int32_t>()));
params_handle, (lt_end_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_bwd_params_set_lt_end_ptr(params_handle, nullptr);

if (ut_start_row_indices.initialized())
dynload::flashmaskv2_bwd_params_set_ut_start_ptr(
params_handle,
const_cast<int32_t *>(ut_start_row_indices.data<int32_t>()));
params_handle, (ut_start_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_bwd_params_set_ut_start_ptr(params_handle, nullptr);

if (ut_end_row_indices.initialized())
dynload::flashmaskv2_bwd_params_set_ut_end_ptr(
params_handle,
const_cast<int32_t *>(ut_end_row_indices.data<int32_t>()));
params_handle, (ut_end_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_bwd_params_set_ut_end_ptr(params_handle, nullptr);

if (flashmask_maxmin.initialized())
dynload::flashmaskv2_bwd_params_set_flashmask_maxmin_ptr(
params_handle,
const_cast<int32_t *>(flashmask_maxmin.data<int32_t>()));
params_handle, (flashmask_maxmin.data<int32_t>()));
else
dynload::flashmaskv2_bwd_params_set_flashmask_maxmin_ptr(params_handle,
nullptr);
@@ -1457,6 +1497,13 @@ void FlashMaskV2GradBaseKernel(
dynload::flashmaskv2_bwd_params_set_h_h_flashmask_ratio(params_handle, 0);
}

if (is_blockmask) {
// Note(xhy): block_mask currently only supports blockdim_q = blockdim_k = 128.
dynload::flashmaskv2_bwd_params_set_m_block_dim(params_handle, 128);
dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
params_handle, (block_mask.data<int32_t>()));
}
#ifdef FLASHATTENTION_DISABLE_LOCAL
PADDLE_ENFORCE_EQ(
!dynload::flashmaskv2_bwd_params_get_is_local(params_handle),
@@ -1504,6 +1551,7 @@ void FlashMaskV2GradKernel(
const DenseTensor &out,
const DenseTensor &softmax_lse,
const DenseTensor &startend_row_indices, // TODO(xiehaoyang): remove this
const paddle::optional<DenseTensor> &block_mask,
const DenseTensor &out_grad,
float const softmax_scale,
bool is_causal,
@@ -1540,6 +1588,7 @@ void FlashMaskV2GradKernel(
paddle::none,
paddle::none,
startend_row_indices,
block_mask,
0, // max_seqlen_q,
0, // max_seqlen_k,
softmax_scale,
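To make the new shape checks in FlashMaskV2GradBaseKernel concrete, here is a minimal standalone sketch (an illustration, not Paddle code; the example sizes are arbitrary) of the block_mask the backward kernel now expects: a 4-D int32 tensor whose last two dims are ceil(seqlen/128), whose num_heads matches startend_row_indices, and whose sequence lengths stay at or below 128K in the backward pass.

#include <cstdint>
#include <cstdio>

int main() {
  // Arbitrary example sizes; num_heads must equal startend_row_indices.dims()[1].
  const int64_t batch_size = 2, num_heads = 16;
  const int64_t seqlen_q = 4096, seqlen_k = 8192;  // bwd requires <= 128 * 1024
  const int64_t block_dim = 128;                   // only 128 is supported here
  const int64_t blocks_q = (seqlen_q + block_dim - 1) / block_dim;  // ceil division
  const int64_t blocks_k = (seqlen_k + block_dim - 1) / block_dim;
  // Expected block_mask shape checked above: [batch_size, num_heads, blocks_q, blocks_k]
  std::printf("block_mask shape: [%lld, %lld, %lld, %lld]\n",
              static_cast<long long>(batch_size),
              static_cast<long long>(num_heads),
              static_cast<long long>(blocks_q),
              static_cast<long long>(blocks_k));
  return 0;
}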
104 changes: 72 additions & 32 deletions paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu
@@ -1238,7 +1238,10 @@ void FlashMaskV2BaseKernel(
const paddle::optional<DenseTensor> &k_descale_, // (b, h_k)
const paddle::optional<DenseTensor> &v_descale_, // (b, h_k)
const paddle::optional<DenseTensor> &scheduler_metadata_, // (b + 1)
const paddle::optional<DenseTensor> &startend_row_indices_,
const paddle::optional<DenseTensor>
&startend_row_indices_, // (b,h,s_1,[1,2,4])
const paddle::optional<DenseTensor>
&block_mask_, // (b, h, s//128, s//128)
const int
max_seqlen_q_, // if max_seqlen_q_ is set to 0, it indicates that it is
// uninitialized and should not be referenced
@@ -1438,6 +1441,7 @@ void FlashMaskV2BaseKernel(
}

bool const is_flashmask = startend_row_indices_.is_initialized();
bool const is_blockmask = block_mask_.is_initialized();

// This needs to go before kBlockM & kBlockN since we rely on the correct
// window_size and is_causal to set kBlockM
@@ -1705,10 +1709,10 @@ void FlashMaskV2BaseKernel(
seqlen_k_new);
phi::dynload::flashmaskv2_fwd_params_set_total_knew(params_handle,
total_k_new);
phi::dynload::flashmaskv2_fwd_params_set_knew_ptr(
params_handle, const_cast<void *>(k_new.data()));
phi::dynload::flashmaskv2_fwd_params_set_vnew_ptr(
params_handle, const_cast<void *>(v_new.data()));
phi::dynload::flashmaskv2_fwd_params_set_knew_ptr(params_handle,
(k_new.data()));
phi::dynload::flashmaskv2_fwd_params_set_vnew_ptr(params_handle,
(v_new.data()));
// All stride are in elements, not bytes.
phi::dynload::flashmaskv2_fwd_params_set_knew_row_stride(
params_handle, k_new.strides()[k_new.strides().size() - 3]);
@@ -1800,14 +1804,11 @@ void FlashMaskV2BaseKernel(
}
phi::dynload::flashmaskv2_fwd_params_set_tile_count_semaphore(
params_handle,
scheduler_needs_semaphore
? const_cast<int *>(tile_count_semaphore.data<int>())
: nullptr);
scheduler_needs_semaphore ? (tile_count_semaphore.data<int>())
: nullptr);
phi::dynload::flashmaskv2_fwd_params_set_num_splits_dynamic_ptr(
params_handle,
use_dynamic_split
? const_cast<int *>(tile_count_semaphore.data<int>()) + 1
: nullptr);
use_dynamic_split ? (tile_count_semaphore.data<int>()) + 1 : nullptr);
}

if (q_v_.is_initialized()) {
@@ -1839,8 +1840,8 @@ void FlashMaskV2BaseKernel(
} else {
CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
}
phi::dynload::flashmaskv2_fwd_params_set_qv_ptr(
params_handle, const_cast<void *>(q_v.data()));
phi::dynload::flashmaskv2_fwd_params_set_qv_ptr(params_handle,
(q_v.data()));
// All stride are in elements, not bytes.
phi::dynload::flashmaskv2_fwd_params_set_qv_row_stride(
params_handle, q_v.strides()[q_v.strides().size() - 3]);
@@ -1903,9 +1904,9 @@ void FlashMaskV2BaseKernel(
"rotary_cos must have the same dtype as query"));

phi::dynload::flashmaskv2_fwd_params_set_rotary_cos_ptr(
params_handle, const_cast<void *>(rotary_cos.data()));
params_handle, (rotary_cos.data()));
phi::dynload::flashmaskv2_fwd_params_set_rotary_sin_ptr(
params_handle, const_cast<void *>(rotary_sin.data()));
params_handle, (rotary_sin.data()));
dynload::flashmaskv2_fwd_params_set_is_rotary_interleaved(
params_handle, is_rotary_interleaved);
} else {
@@ -1961,10 +1962,10 @@ void FlashMaskV2BaseKernel(
dev_ctx.template Alloc<float>(softmax_lse_accum);
}
phi::dynload::flashmaskv2_fwd_params_set_is_fp32(params_handle, false);
phi::dynload::flashmaskv2_fwd_params_set_oaccum_ptr(
params_handle, const_cast<void *>(out_accum->data()));
phi::dynload::flashmaskv2_fwd_params_set_oaccum_ptr(params_handle,
(out_accum->data()));
phi::dynload::flashmaskv2_fwd_params_set_softmax_lseaccum_ptr(
params_handle, const_cast<void *>(softmax_lse_accum->data()));
params_handle, (softmax_lse_accum->data()));
phi::dynload::flashmaskv2_fwd_params_set_oaccum_split_stride(
params_handle, out_accum->strides()[0]);
phi::dynload::flashmaskv2_fwd_params_set_oaccum_row_stride(
@@ -1984,7 +1985,7 @@
CHECK_DEVICE(q_descale);
CHECK_SHAPE(q_descale, batch_size, num_heads_k);
phi::dynload::flashmaskv2_fwd_params_set_q_descale_ptr(
params_handle, const_cast<float *>(q_descale.data<float>()));
params_handle, (q_descale.data<float>()));
phi::dynload::flashmaskv2_fwd_params_set_q_descale_batch_stride(
params_handle, q_descale.strides()[0]);
phi::dynload::flashmaskv2_fwd_params_set_q_descale_head_stride(
@@ -1998,7 +1999,7 @@
CHECK_DEVICE(k_descale);
CHECK_SHAPE(k_descale, batch_size, num_heads_k);
phi::dynload::flashmaskv2_fwd_params_set_k_descale_ptr(
params_handle, const_cast<float *>(k_descale.data<float>()));
params_handle, (k_descale.data<float>()));
phi::dynload::flashmaskv2_fwd_params_set_k_descale_batch_stride(
params_handle, k_descale.strides()[0]);
phi::dynload::flashmaskv2_fwd_params_set_k_descale_head_stride(
@@ -2012,7 +2013,7 @@
CHECK_DEVICE(v_descale);
CHECK_SHAPE(v_descale, batch_size, num_heads_k);
phi::dynload::flashmaskv2_fwd_params_set_v_descale_ptr(
params_handle, const_cast<float *>(v_descale.data<float>()));
params_handle, (v_descale.data<float>()));
phi::dynload::flashmaskv2_fwd_params_set_v_descale_batch_stride(
params_handle, v_descale.strides()[0]);
phi::dynload::flashmaskv2_fwd_params_set_v_descale_head_stride(
@@ -2074,6 +2075,8 @@ void FlashMaskV2BaseKernel(
// flashmask
DenseTensor startend_row_indices;
if (is_flashmask) startend_row_indices = startend_row_indices_.get();
DenseTensor block_mask;
if (is_blockmask) block_mask = block_mask_.get();
DenseTensor flashmask_maxmin, lt_start_row_indices, lt_end_row_indices,
ut_start_row_indices, ut_end_row_indices;
if (is_flashmask) {
@@ -2148,39 +2151,72 @@ void FlashMaskV2BaseKernel(
}
}

if (is_blockmask) {
PADDLE_ENFORCE_EQ(
is_flashmask,
true,
common::errors::InvalidArgument(
"block_mask must be used together with flashmask "
"(startend_row_indices)."));

PADDLE_ENFORCE_EQ(block_mask.dims().size(),
4,
common::errors::InvalidArgument(
"block_mask must be a 4-D tensor with shape "
"[batch_size, num_heads, num_blocks_q, num_blocks_k]."));

PADDLE_ENFORCE_EQ(block_mask.dims()[2],
(seqlen_q + 127) / 128,
common::errors::InvalidArgument(
"block_mask currently only supports blockdim_q = 128."));

PADDLE_ENFORCE_EQ(block_mask.dims()[3],
(seqlen_k + 127) / 128,
common::errors::InvalidArgument(
"block_mask currently only supports blockdim_k = 128."));

PADDLE_ENFORCE_EQ(
block_mask.dims()[1],
startend_row_indices.dims()[1],
common::errors::InvalidArgument(
"block_mask must have the same num_heads as startend_row_indices."));
}

if (is_blockmask) {
// Note(xhy): block_mask currently only supports blockdim_q = blockdim_k = 128.
dynload::flashmaskv2_fwd_params_set_m_block_dim(params_handle, 128);
dynload::flashmaskv2_fwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_fwd_params_set_block_mask_ptr(
params_handle, (block_mask.data<int32_t>()));
}

if (is_flashmask) {
if (lt_start_row_indices.initialized())
dynload::flashmaskv2_fwd_params_set_lt_start_ptr(
params_handle,
const_cast<int32_t *>(lt_start_row_indices.data<int32_t>()));
params_handle, (lt_start_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_fwd_params_set_lt_start_ptr(params_handle, nullptr);

if (lt_end_row_indices.initialized())
dynload::flashmaskv2_fwd_params_set_lt_end_ptr(
params_handle,
const_cast<int32_t *>(lt_end_row_indices.data<int32_t>()));
params_handle, (lt_end_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_fwd_params_set_lt_end_ptr(params_handle, nullptr);

if (ut_start_row_indices.initialized())
dynload::flashmaskv2_fwd_params_set_ut_start_ptr(
params_handle,
const_cast<int32_t *>(ut_start_row_indices.data<int32_t>()));
params_handle, (ut_start_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_fwd_params_set_ut_start_ptr(params_handle, nullptr);

if (ut_end_row_indices.initialized())
dynload::flashmaskv2_fwd_params_set_ut_end_ptr(
params_handle,
const_cast<int32_t *>(ut_end_row_indices.data<int32_t>()));
params_handle, (ut_end_row_indices.data<int32_t>()));
else
dynload::flashmaskv2_fwd_params_set_ut_end_ptr(params_handle, nullptr);

if (flashmask_maxmin.initialized())
dynload::flashmaskv2_fwd_params_set_flashmask_maxmin_ptr(
params_handle,
const_cast<int32_t *>(flashmask_maxmin.data<int32_t>()));
params_handle, (flashmask_maxmin.data<int32_t>()));
else
dynload::flashmaskv2_fwd_params_set_flashmask_maxmin_ptr(params_handle,
nullptr);
@@ -2266,6 +2302,7 @@ void FlashMaskV2Kernel(const Context &dev_ctx,
const DenseTensor &k,
const DenseTensor &v,
const DenseTensor &startend_row_indices,
const paddle::optional<DenseTensor> &block_mask,
const float softmax_scale,
bool is_causal,
DenseTensor *out,
@@ -2296,6 +2333,7 @@ void FlashMaskV2Kernel(const Context &dev_ctx,
paddle::none, // v_descale_
paddle::none, // scheduler_metadata_
startend_row_indices,
block_mask,
0, // max_seqlen_q_
0, // max_seqlen_k_
softmax_scale,
@@ -2339,4 +2377,6 @@ PD_REGISTER_KERNEL(flashmask_attention_v2,
ALL_LAYOUT,
phi::FlashMaskV2Kernel,
phi::float16,
phi::bfloat16) {}
phi::bfloat16) {
kernel->InputAt(4).SetBackend(phi::Backend::ALL_BACKEND); // block_mask
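// Assumption: SetBackend(ALL_BACKEND) on an input is the usual phi registration
// pattern for letting that tensor stay on whatever backend it already lives on
// (no forced transform to the kernel's GPU backend); it presumably applies here
// to the optional block_mask, which is input index 4 in this op's input order
// (q, k, v, startend_row_indices, block_mask).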
Review comment (Member): What is the purpose of this setting?
}
5 changes: 3 additions & 2 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1228,8 +1228,9 @@
data_type: q

- backward_op : flashmask_attention_v2_grad
forward : flashmask_attention_v2 (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, float softmax_scale, bool is_causal) -> Tensor(out), Tensor(softmax_lse)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor startend_row_indices, Tensor out_grad, float softmax_scale, bool is_causal)
forward : flashmask_attention_v2 (Tensor q, Tensor k, Tensor v, Tensor startend_row_indices, Tensor block_mask, float softmax_scale, bool is_causal) -> Tensor(out), Tensor(softmax_lse)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor startend_row_indices, Tensor block_mask, Tensor out_grad, float softmax_scale, bool is_causal)
optional : block_mask
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta