add blockmask to flashmask #76407
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
xiehaoyang seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it.
Force-pushed from d1cda79 to 674eab0
/re-run all-failed
Codecov Report

❌ Patch coverage is 0.00%. Your patch status has failed because the patch coverage (0.00%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

@@           Coverage Diff            @@
##           develop   #76407   +/-   ##
==========================================
  Coverage         ?    0.00%
==========================================
  Files            ?        1
  Lines            ?       11
  Branches         ?        0
==========================================
  Hits             ?        0
  Misses           ?       11
  Partials         ?        0
if (softmax_lse_log2) {
  dev_ctx.template Alloc<float>(softmax_lse_log2);
}
// std::cout << "dq_accum:" << dq_accum->dims() << std::endl;
These unneeded comments can be removed.
OK, I'll remove them and push again.
dq_accum->Resize(common::make_ddim(
    {num_heads, total_q_padded_rounded * head_size_rounded}));
}
// std::cout << "enter:" << dq_accum->dims() << std::endl;
These unneeded comments can be removed.
dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
    params_handle,
    const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
Is this const_cast necessary? Looking at the flash-attention repository, block_mask_ptr is not actually a const pointer there either, is it?
This part was modeled on the earlier start_row_indices code. The pointer obtained here via .data<int32_t>() is not of type const int32_t *, so the const_cast to strip the const qualifier should not really be needed. Should the earlier start_row_indices-related pointers be handled the same way?
I looked at the code; it seems returning a non-const pointer is already supported now:
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/core/dense_tensor.h#L215
    const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
// phi::funcs::SetConstant<Context, T> set_dq_zero;
// // dev_ctx.template Alloc<T>(dq);
// set_dq_zero(dev_ctx, dq, T{0});
These unneeded comments can be removed.
dynload::flashmaskv2_fwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_fwd_params_set_block_mask_ptr(
    params_handle,
    const_cast<int32_t *>(block_mask_indices.data<int32_t>()));
Can this const_cast be removed as well?
assert key.shape[3] == 128, (
    "headdim must be 128 when using block_mask_attn"
)
The block_mask feature currently only supports FA3 and does not support deterministic mode. Should we also add asserts for the FA version and for deterministic?
Yes, the asserts do seem necessary; I'll add them.
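For reference, a minimal sketch of what those extra checks could look like alongside the existing head-dim assert; the helper name and the flash_attn_version / deterministic parameters are illustrative assumptions, not the merged code:

def check_block_mask_preconditions(key, flash_attn_version, deterministic):
    # Illustrative sketch only: the helper name and the flash_attn_version /
    # deterministic parameters are assumptions, not the PR's final code.
    assert flash_attn_version == 3, (
        "block_mask_attn is only supported with FlashAttention v3"
    )
    assert not deterministic, (
        "block_mask_attn does not support deterministic mode"
    )
    assert key.shape[3] == 128, (
        "headdim must be 128 when using block_mask_attn"
    )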
Force-pushed from 674eab0 to cced066
/re-run all-failed
5 similar comments: /re-run all-failed
yuanlehome left a comment:
for op yaml
training (bool): Whether the module is in training mode. Default is True.
name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property.
    For more information, refer to :ref:`api_guide_Name` .
block_mask_indices (tensor, optional):
Is there an extra line here?
Force-pushed from 34ec177 to 6ab581d
/re-run all-failed
add test fix description del test code
Force-pushed from 6ab581d to 4c72c9d
dynload::flashmaskv2_bwd_params_set_n_block_dim(params_handle, 128);
dynload::flashmaskv2_bwd_params_set_block_mask_ptr(
    params_handle, (block_mask.data<int32_t>()));
auto ptr = block_mask.data<int32_t>();
Where is this ptr actually used?
    phi::float16,
    phi::bfloat16) {}
    phi::bfloat16) {
  kernel->InputAt(4).SetBackend(phi::Backend::ALL_BACKEND);  // block_mask
What is the purpose of this setting?
yuanlehome left a comment:
for op yaml
* add blockmask add test fix description del test code
* delete unused ptr
PR Category
Operator Mechanism
PR Types
New features
Description
Add Blockmask Support to Flashmask
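Since the description above is a one-liner, a conceptual sketch may help explain what a block mask means here: a plain NumPy reference of block-masked attention over 128x128 score tiles. The 0/1 block_mask layout below is an assumption for illustration only; it is not the PR's block_mask_indices encoding and not Paddle's flashmask API.

import numpy as np

def block_masked_attention_reference(q, k, v, block_mask, block_size=128):
    # Conceptual reference only, not the PR's kernel: q, k, v are [seqlen, head_dim]
    # for a single head, and block_mask is a [seqlen // block_size, seqlen // block_size]
    # array of 0/1 flags marking which block_size x block_size score tiles are kept.
    # The real block_mask_indices encoding may differ; this layout is an assumption.
    # Assumes every query row keeps at least one block (otherwise softmax is undefined).
    scores = q @ k.T / np.sqrt(q.shape[-1])
    keep = np.kron(block_mask, np.ones((block_size, block_size)))
    scores = np.where(keep > 0, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v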