
Conversation

@starcrown001
Contributor

PR Category

Operator Mechanism

PR Types

New features

Description

Add Blockmask Support to Flashmask

  • This update enables simultaneous use of the blockmask and flashmask masking methods; the two masks are combined with a logical OR, so any block masked by blockmask, or fully masked by flashmask, is excluded from computation (a sketch follows this list).
  • Blockmask and flashmask are not yet fully decoupled: behavior is unchanged when flashmask is used alone, but when blockmask is used alone, a flashmask tensor (which does not affect the results) must still be provided.
  • At present, only headdim=128 with blocksize=128 is supported.
  • Compared with the original blockmask implementation (https://github.com/mit-han-lab/Block-Sparse-Attention), this version achieves a 75%–150% improvement in forward performance and a 50%–90% improvement in backward performance on H800.
  • Comprehensive regression testing against the original flashmask operator covered both accuracy and performance; the impact on the original flashmask is negligible.
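
A rough NumPy sketch of the block-level combination rule above (the helper name, shapes, and the dense flashmask representation are hypothetical and only mirror the description; this is not the kernel code):

import numpy as np

BLOCK = 128  # the only supported block size at present

def block_skip_map(block_mask, dense_flash_mask):
    # block_mask:       (num_q_blocks, num_k_blocks) bool, True = block excluded by blockmask
    # dense_flash_mask: (seqlen_q, seqlen_k) bool, True = element masked by flashmask
    nq, nk = block_mask.shape
    # A block is fully masked by flashmask only if every element inside it is masked.
    flash_blocks = dense_flash_mask.reshape(nq, BLOCK, nk, BLOCK).all(axis=(1, 3))
    # Logical OR: skip a block if blockmask masks it, or flashmask fully masks it.
    return block_mask | flash_blocks

# Example with 2x2 blocks: flashmask fully masks block (0, 1), blockmask masks block (1, 0).
dense = np.zeros((2 * BLOCK, 2 * BLOCK), dtype=bool)
dense[:BLOCK, BLOCK:] = True
blocks = np.zeros((2, 2), dtype=bool)
blocks[1, 0] = True
print(block_skip_map(blocks, dense))  # blocks (0, 1) and (1, 0) are skipped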

@paddle-bot

paddle-bot bot commented Nov 13, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

paddle-bot added the contributor (External developers) label on Nov 13, 2025
@CLAassistant

CLAassistant commented Nov 13, 2025

CLA assistant check
All committers have signed the CLA.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


xiehaoyang does not seem to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@starcrown001
Contributor Author

/re-run all-failed

@codecov-commenter

codecov-commenter commented Nov 13, 2025

Codecov Report

❌ Patch coverage is 0% with 11 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@2d9355a). Learn more about missing BASE report.

Files with missing lines                          Patch %   Lines
python/paddle/nn/functional/flash_attention.py    0.00%     11 Missing ⚠️

❌ Your patch status has failed because the patch coverage (0.00%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #76407   +/-   ##
==========================================
  Coverage           ?    0.00%           
==========================================
  Files              ?        1           
  Lines              ?       11           
  Branches           ?        0           
==========================================
  Hits               ?        0           
  Misses             ?       11           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

if (softmax_lse_log2) {
dev_ctx.template Alloc<float>(softmax_lse_log2);
}
// std::cout << "dq_accum:" << dq_accum->dims() << std::endl;
Member

These unneeded comments can be removed.

Contributor Author

OK, I'll remove them and then push again.

dq_accum->Resize(common::make_ddim(
{num_heads, total_q_padded_rounded * head_size_rounded}));
}
// std::cout << "enter:" << dq_accum->dims() << std::endl;
Member

These unneeded comments can be removed.

dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
params_handle,
const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
Member

Is this const_cast necessary? Looking at the flash-attention repo, block_mask_ptr isn't actually a const pointer either?

Contributor Author

@starcrown001 commented Nov 14, 2025

This part was written following the earlier start_row_indices code. The pointer obtained here via .data<int32_t>() is not a const int32_t * pointer, so the const_cast to strip const shouldn't really be needed. Should the earlier start_row_indices pointers be handled the same way?

Member

I took a look at the code; it seems returning a non-const pointer is already supported:
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/core/dense_tensor.h#L215

const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
// phi::funcs::SetConstant<Context, T> set_dq_zero;
// // dev_ctx.template Alloc<T>(dq);
// set_dq_zero(dev_ctx, dq, T{0});
Member

These unneeded comments can be removed.

dynload::flashmaskv2_fwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_fwd_params_set_block_mask_ptr(
params_handle,
const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
Member

Can this const_cast be removed?


assert key.shape[3] == 128, (
"headdim must be 128 when using block_mask_attn"
)
Member

The block_mask feature currently only supports FA3 and does not support deterministic mode. Should asserts for the FA version and deterministic be added as well?

Contributor Author

Yes, those asserts are indeed needed; I'll add them.
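
For reference, a rough sketch of the extra checks being discussed, in the same style as the existing headdim assert; fa_version and deterministic are placeholder names, not the actual arguments in flash_attention.py:

def _check_block_mask_args(key, fa_version, deterministic):
    # Placeholder guard for the block_mask path; names are illustrative only.
    assert key.shape[3] == 128, "headdim must be 128 when using block_mask_attn"
    assert fa_version == 3, "block_mask_attn currently only supports FlashAttention-3 (fa3)"
    assert not deterministic, "block_mask_attn does not support deterministic mode"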

@starcrown001
Contributor Author

/re-run all-failed

5 similar comments
@starcrown001
Contributor Author

/re-run all-failed

@umiswing
Member

/re-run all-failed

@umiswing
Member

/re-run all-failed

@umiswing
Member

/re-run all-failed

@starcrown001
Contributor Author

/re-run all-failed

yuanlehome previously approved these changes Nov 18, 2025
Contributor

@yuanlehome left a comment

for op yaml

training (bool): Whether the module is in training mode. Default is True.
name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property.
For more information, refer to :ref:`api_guide_Name` .
block_mask_indices (tensor, optional):
Member

Is there an extra line here?

@starcrown001
Contributor Author

/re-run all-failed

add test

fix description

del test code
dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
params_handle, (block_mask.data<int32_t>()));
auto ptr = block_mask.data<int32_t>();
Member

Where is this ptr used?

phi::float16,
phi::bfloat16) {}
phi::bfloat16) {
kernel->InputAt(4).SetBackend(phi::Backend::ALL_BACKEND); // block_mask
Member

What is the purpose of this setting?

Contributor

@yuanlehome left a comment

for op yaml

@GuoxiaWang GuoxiaWang merged commit 43c370f into PaddlePaddle:develop Nov 20, 2025
116 of 129 checks passed
fsylmxx pushed a commit to fsylmxx/Paddle that referenced this pull request Nov 20, 2025
* add blockmask

add test

fix description

del test code

* delete unused ptr