Commit d982741

yuanlehome and ckl117 authored and committed
support deepseek-v3
update 0113
support head_dim=192,256 for append_attn c16
attention run
refine code
add softmax_scale
support weight_only_int8
refine code
support tp
delete test_append_attn
add splited fused_moe from ziyuan
add deepseek-v3 class
fix repe for deepseek-v3
fix wint8 precision and refine code
fix wint4, big diff
add e_score_correction_bias
fix head_dim
fix v3 verify

[AutoParallel] open tensor_fusion for benchmark (PaddlePaddle#9749)
* open tensor_fusion for benchmark

fix loraga merge (PaddlePaddle#9765)
* fix loraga merge
* change sign

Fix ernie ci auto trainer error (PaddlePaddle#9758)
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* [AutoParallel]:fix ernine auto_trainer error
* Update run_pretrain_auto.py

Update README.md (PaddlePaddle#9766)
* Update README.md

[BugFix] Fix matryoshka norm loss (PaddlePaddle#9774)
* fix matryoshka norm

[Distributed] support fuse optimizer (PaddlePaddle#9519) (PaddlePaddle#9777)

Update register_sequence_parallel_allreduce_hooks (PaddlePaddle#9782)
* fix sequence parallel
* update register_sequence_parallel_allreduce_hooks
* update fuse_sequence_parallel_allreduce

Fix ce error (PaddlePaddle#9783)
* [AutoParallel]:fix ci error
* [AutoParallel]:fix ci error

fix (PaddlePaddle#9779)

[MoE] fix expert parallel (PaddlePaddle#9760)
* fix moe uc

fix dpo pp criterion (PaddlePaddle#9786)

[Infer] Add pir_model path for server infer. (PaddlePaddle#9790)

fix d2s
fix v3 verify
support qk_head_dim != v_head_dim
support fp8 batch gemm on cutlass3.x
upgrade cutlass version for block_wise fp8 gemm
change cutlass commit to ckl117 group_wise branch
support fp8 block gemm, but private cutlass commit, and TODO: update fp8 dual gemm api on cutlass3.x
support auto tune fp8 block gemm code
update cutlass to v3.7.0, todo: support block gemm based on v3.7.0
support block gemm on cutlass v3.7.0
commit code
check code
check
check dynamic_quant
ad block builder dir
rename group_quant
fix wint8 v_head_dim
fix rope
fix qwen2
mla use position_ids only
remove control flow
remove gpu concat
fix norm weight dtype
remove all_reduce in fused_moe
part support fp8
check group_quant and fake fp8 check
support block gemm

[LLM] support flash device on static model (PaddlePaddle#9619) (PaddlePaddle#9787)
* [LLM] support flash device on static model
* [LLM] adapt pdc sdk

[LLM Benchmark] update scripts (PaddlePaddle#9722)
* add no_proxy & del paddlenlp_ops
* update timeout for dpo
* fix sequence_parallel
* add timeout
* add Total_Tokens_per_second_per_gpu
* fix Tokens_per_second_per_gpu
* update Total_Tokens_per_second_per_gpu

mergekit gpu 1226 (PaddlePaddle#9702)
* mergekit gpu 1226
* merge model gpu
* merge gpu
* add lora model
* change valueerror
* add lora
* gpu test

[LLM] merge code from fastdeploy (PaddlePaddle#9791)
* [LLM] update llm server dockerfiles
* merge code from fastdeploy

[Inference] Support eagle for llama (PaddlePaddle#9812)

[CI] Fix ci of small models (PaddlePaddle#9633)

[Trainer] Wrap model when lora is ON and only do evaluation. (PaddlePaddle#9803)

[README] Update README.md for documention (PaddlePaddle#9785)
* Update README.md
* Update README.md
* Update README_en.md

fix static run
wint8 and fake-fp8, todo: support data type does not match
support fp8, but ffn1 and moe in wint8
support ffn1 fp8 block gemm
done ffn1 fp8 block gemm
block gemm done
block gemm support batch
refine rope code
compute position_ids use custom op

fix split_param (PaddlePaddle#9817)

[LLM] Update model convert and fix TP for deepseekv3 (PaddlePaddle#9797)
* fix model convert and tp in MoEMLP
* fix tp_action filter
* update convert accoding to num_nextn_predict_layers
* add deepseek-R1

fuse rope
fix macro
fix mixtral
set_state_dict block_wise weight
support fp8 per tensor network, no support scale Tensor for tensor gemm
deepseek-v3 fp8 tensor gemm network, but precision fault
add triton fp8 fused_moe kernel
fix moe triton kernel
add moe triton kernel
fix
fix fp8 block gemm precision
moe triton fp8 network
support moe triton and precision correct, but shared ffn1 ffn2 incorrect
fp8 block network, no check shared ffn1-ffn2 in v2-lite
delete wint8 in fake
delete some useless code and verify per tensor net with in qkv outlinear ffn1 ffn2, but triton moe don't match api
fp8 block quant when load model, and code check
fix tokenizer and qwen

[AutoParallel] add sharding tensor_fusion save load switch (PaddlePaddle#9810)
* support tensor_fusion save load
* apply suggestions from code review

Fix handling of abnormal exits in multi-node benchmark jobs (PaddlePaddle#9651)
* fix handling of abnormal exits in multi-node benchmark jobs
* fix bug
* update

Fix LLAMA arg parsing bug in pp (PaddlePaddle#9806)

[Readme] Update mixtral.md (PaddlePaddle#9829)

[XPU] Support empty_cache on XPUs (PaddlePaddle#9789)
* [XPU] Support empty_cache on XPUs
* warn if current device doesn't support

[Inference] Fix multibatch inference (PaddlePaddle#9831)
* fix batch infra
* fix deepseekv2 infra

Fix position_ids for infra (PaddlePaddle#9841)

fix moe diff due to e_score_correction_bias
fix fast tokenizer

[LLM] Add pipeline and flashmask for Qwen2Moe and Deepseek (PaddlePaddle#9827)
* add modleing_pp
* add modleing_pp for qwen2moe
* add flashmask and pp for Qwen2MoE and Deepseek
* remove
* fix fast_tokenizer save
* update for topk_weight of noaux_tc
* fix for flashmask
* add use_expert_parallel for pretrain
* fix tokenizer test

[Mergekit] update & add LoRA merge (PaddlePaddle#9811)
* add
* fix bug
* fix
* add
* add lora merge
* add
* add
* add
* add
* add
* add

[Unified Checkpoint] Fix expert parallel (PaddlePaddle#9821)
* fix expert parallel
* fix split_param for expert parallel
* add filter_sync_parameters

fix import

[Inference] Flask server compatible with OpenAI api. (PaddlePaddle#9828)
* flask server compatible with OpenAI api.
* fix max_length to max_tokens.
* fix with think model.

[LLM] fix checkpoint save for non flash mode (PaddlePaddle#9830)

support mla for speculate

[DSK] support deepseek-v3/r1 (mha/fp16/bf16/wint8/wint4) (PaddlePaddle#9769)
* support deepseek-v3
* support head_dim=192,256 for append_attn c16
* update 0113
* attention run
* refine code
* add softmax_scale
* support weight_only_int8
* refine code
* support tp
* delete test_append_attn
* add splited fused_moe from ziyuan
* fix repe for deepseek-v3
* add deepseek-v3 class
* fix wint8 precision and refine code
* fix wint4, big diff
* add e_score_correction_bias
* fix head_dim
* fix v3 verify
* fix d2s
* fix v3 verify
* support qk_head_dim != v_head_dim
* fix wint8 v_head_dim
* fix rope
* fix qwen2
* mla use position_ids only
* remove control flow
* remove gpu concat
* fix norm weight dtype
* remove all_reduce in fused_moe
* fix static run
* refine rope code
* compute position_ids use custom op
* fuse rope
* fix macro
* fix mixtral
* support mla for speculate
* fix tokenizer and qwen
* fix moe diff due to e_score_correction_bias
* fix fast tokenizer
* fix import
---------
Co-authored-by: lizhenyun01 <[email protected]>
Co-authored-by: lizhenyun <[email protected]>

Solve the compatibility problem of type annotation Python version (PaddlePaddle#9853)

mix fp8 and wint8

save extra special tokens (PaddlePaddle#9837)

[Bugfix] Fix dsk rope diff (PaddlePaddle#9859)
* fix dsk diff
* fix
* update

merge develop to check fp8 moe-wint8
fix deepseek v3 fp8 precision
fix deepseek weight quant

[Optimization] Support lower memory cards. (PaddlePaddle#9804)
* support lower memory cards.
* add doc for v100 16G such devices.
* remove debug info.
* add pre divided factor to overcome overfit problem for fp16 attention.

Support XPU for auto-paralllel LLaMa (PaddlePaddle#9796)
* Support XPU for auto-paralllel LLaMa
* Update
* Update
* Update
* Update
* Fix CI errors
* Update

[XPU] Add xpu fused op for deepseek (PaddlePaddle#9854)

[Inference] Update deepseek (PaddlePaddle#9864)
* fix
* fix infra

[PreTrain] Support deepseek mfu for pretraining and fix tflops for pretrain pipe model (PaddlePaddle#9855)
* git flops with pp model.
* Support hareware tflops for deepseek.

[Inference] Support mtp with deepseek-v3 (PaddlePaddle#9856)
* support mtp with deepseek_v3 both in static and dygraph mode
* fix speculate tokenizer in unittest
* delete useless code

check code
1 parent 2c556e7 commit d982741
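
A rough orientation for the MLA-related items in the message above ("support head_dim=192,256 for append_attn", "support qk_head_dim != v_head_dim", "add softmax_scale"): DeepSeek-V3's attention uses different query/key and value head sizes, so the kernels can no longer assume a single head_dim or derive the scale from it. The sketch below is illustrative only and is not code from this commit; the head sizes are the published DeepSeek-V3 config values, and the scale shown is the conventional 1/sqrt(d) default, which the model may further adjust.

```python
# Illustrative only -- not code from this commit.
import math

# Published DeepSeek-V3 attention head sizes.
qk_nope_head_dim = 128   # non-rotary part of each query/key head
qk_rope_head_dim = 64    # rotary part of each query/key head
v_head_dim = 128         # value head size

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192, cf. "support head_dim=192,256"
assert qk_head_dim != v_head_dim                   # hence "support qk_head_dim != v_head_dim"

# Conventional scaling; the new explicit softmax_scale attribute lets the caller
# pass this in instead of the kernel deriving it from a single head_dim.
softmax_scale = 1.0 / math.sqrt(qk_head_dim)
print(qk_head_dim, v_head_dim, round(softmax_scale, 4))
```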

218 files changed (+15311, -2556 lines)

README.md  (+10 -5)

@@ -7,7 +7,7 @@
 ------------------------------------------------------------------------------------------
 
 <p align="center">
-    <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
+    <a href="https://paddlenlp.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/paddlenlp/badge/?version=latest">
     <a href="https://github.com/PaddlePaddle/PaddleNLP/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleNLP?color=ffa"></a>
     <a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
     <a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
@@ -16,6 +16,7 @@
     <a href="https://pypi.org/project/paddlenlp/"><img src="https://img.shields.io/pypi/dm/paddlenlp?color=9cf"></a>
     <a href="https://github.com/PaddlePaddle/PaddleNLP/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/PaddleNLP?color=9cc"></a>
     <a href="https://github.com/PaddlePaddle/PaddleNLP/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleNLP?color=ccf"></a>
+    <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
 </p>
 
 <h4 align="center">
@@ -69,6 +70,9 @@
 
 大模型套件高性能推理模块内置动态插入和全环节算子融合策略,极大加快并行推理速度。底层实现细节封装化,实现开箱即用的高性能并行推理能力。
 
+## 文档
+更多详细文档, 请访问 [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).
+
 ------------------------------------------------------------------------------------------
 
 ## 模型支持
@@ -91,6 +95,7 @@
 | [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
 | [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
 | [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
+| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
 | [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
 | [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
 | [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
@@ -161,7 +166,7 @@
 ### 环境依赖
 
 * python >= 3.8
-* paddlepaddle >= 3.0.0b0
+* paddlepaddle >= 3.0.0rc0
 
 如果您尚未安装 PaddlePaddle,请参考 [飞桨官网](https://www.paddlepaddle.org.cn/) 进行安装。
 
@@ -206,7 +211,7 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwe
 wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx
 cd .. # change folder to PaddleNLP/llm
 # 如需使用use_fused_rms_norm=true,需要前往slm/model_zoo/gpt-3/external_ops安装fused_ln
-python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./config/llama/pretrain_argument.json --use_fused_rms_norm false
+python -u run_pretrain.py ./config/qwen/pretrain_argument_0p5b.json
 ```
 
 ### 大模型 SFT 精调
@@ -216,7 +221,7 @@ git clone https://github.com/PaddlePaddle/PaddleNLP.git && cd PaddleNLP # 如已
 mkdir -p llm/data && cd llm/data
 wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz && tar -zxvf AdvertiseGen.tar.gz
 cd .. # change folder to PaddleNLP/llm
-python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/sft_argument.json
+python -u run_finetune.py ./config/qwen/sft_argument_0p5b.json
 ```
 
 更多大模型全流程步骤,请参考[飞桨大模型套件](./llm)介绍。
@@ -231,7 +236,7 @@ dataset = load_dataset("ZHUI/alpaca_demo", split="train")
 training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
 trainer = SFTTrainer(
     args=training_args,
-    model="Qwen/Qwen2.5-0.5B",
+    model="Qwen/Qwen2.5-0.5B-Instruct",
     train_dataset=dataset,
 )
 trainer.train()
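
The SFT quick-start change above is only a diff fragment. Assembled into a runnable form, and assuming the SFTConfig/SFTTrainer imports come from paddlenlp.trl and the dataset is loaded with the Hugging Face datasets package (neither import is visible in this hunk), the example would look roughly like this:

```python
# Minimal sketch of the quick-start SFT example touched by this diff.
# Assumption: SFTConfig/SFTTrainer are imported from paddlenlp.trl and the demo
# dataset is loaded via the Hugging Face `datasets` package, as in the full README.
from datasets import load_dataset
from paddlenlp.trl import SFTConfig, SFTTrainer

dataset = load_dataset("ZHUI/alpaca_demo", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B-Instruct",  # model name updated by this diff
    train_dataset=dataset,
)
trainer.train()
```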

README_en.md  (+6 -2)

@@ -7,7 +7,7 @@
 ------------------------------------------------------------------------------------------
 
 <p align="center">
-    <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
+    <a href="https://paddlenlp.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/paddlenlp/badge/?version=latest">
     <a href="https://github.com/PaddlePaddle/PaddleNLP/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleNLP?color=ffa"></a>
     <a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
     <a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
@@ -16,6 +16,7 @@
     <a href="https://pypi.org/project/paddlenlp/"><img src="https://img.shields.io/pypi/dm/paddlenlp?color=9cf"></a>
     <a href="https://github.com/PaddlePaddle/PaddleNLP/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/PaddleNLP?color=9cc"></a>
     <a href="https://github.com/PaddlePaddle/PaddleNLP/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleNLP?color=ccf"></a>
+    <a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
 </p>
 
 <h4 align="center">
@@ -52,6 +53,9 @@ The fine-tuning algorithms are deeply integrated with zero-padding data streams
 
 The high-performance inference module of the large model toolkit incorporates dynamic insertion and operator fusion strategies throughout the entire process, greatly accelerating parallel inference speed. The underlying implementation details are encapsulated, enabling out-of-the-box high-performance parallel inference capabilities.
 
+## Documentation
+For detailed documentation, visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).
+
 ------------------------------------------------------------------------------------------
 
 ## Support Models
@@ -68,7 +72,7 @@ Detailed list 👉 [Supported Model List](https://github.com/PaddlePaddle/Paddle
 ### Pip Installation
 
 ```shell
-pip install --upgrade paddlenlp==3.0.0b2
+pip install --upgrade paddlenlp==3.0.0b3
 ```
 
 or you can install the latest develop branch code with the following command:

csrc/README.md  (+5 -2)

@@ -1,6 +1,9 @@
-# PaddleNLP 自定义 OP
+# PaddleNLP 大模型高性能自定义推理算子
 
-此文档介绍如何编译安装 PaddleNLP 自定义 OP。
+此文档介绍如何编译安装 PaddleNLP 大模型高性能自定义推理算子的安装教程。
+
+使用这些高性能算子,可以大幅提升大模型推理速度。
+大模型推理相关教程详见[此处](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/README.md#6-%E6%8E%A8%E7%90%86)
 
 ## 安装 C++ 依赖
 
csrc/gpu/append_attention.cu  (+27 -10)

@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::INT8,
           qkv.place());
     }
     else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
     }else{
       PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
     }
   } else {
     fmha_out = GetEmptyTensor(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
         D,
         qkv.place());
   }
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           encoder_block_shape_q,
           max_input_length,
           max_enc_len_this_time_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           encoder_block_shape_q,
           max_input_length,
           max_enc_len_this_time_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           encoder_block_shape_q,
           max_input_length,
           max_enc_len_this_time_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           decoder_block_shape_q,
           max_input_length,
           max_len_kv_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           decoder_block_shape_q,
           max_input_length,
           max_len_kv_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           decoder_block_shape_q,
           max_input_length,
           max_len_kv_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.token_nums = qkv_dims[0];
   meta_data.kv_num_heads = key_cache_dims[1];
   meta_data.head_dims = key_cache_dims[3];
-  const int total_num_head =
-      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
-  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+  meta_data.head_dims_v = value_cache.dims()[3];
+  const int q_hidden_size =
+      qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
 
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = key_cache.dims()[2];
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
       cache_quant_type_str,
       use_neox_rotary_style,
       max_input_length,
+      softmax_scale,
       quant_max_bound,
       quant_min_bound,
       out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
       cache_quant_type_str,
      use_neox_rotary_style,
       max_input_length,
+      softmax_scale,
       quant_max_bound,
       quant_min_bound,
       out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
       cache_quant_type_str,
       use_neox_rotary_style,
       max_input_length,
+      softmax_scale,
       quant_max_bound,
       quant_min_bound,
       out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
       cache_quant_type_str,
       use_neox_rotary_style,
       max_input_length,
+      softmax_scale,
       quant_max_bound,
       quant_min_bound,
       out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
   const int token_num = qkv_shape[0];
   const int kv_num_heads = key_cache_shape[1];
-  const int head_dim = key_cache_shape[3];
-  const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
-  const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  const int head_dim_qk = key_cache_shape[3];
+  const int head_dim_v = value_cache_shape[3];
+  const int q_hidden_size =
+      qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+  const int num_heads = q_hidden_size / head_dim_qk;
+  return {{token_num, num_heads * head_dim_v}, qkv_shape};
 }
 
 std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -865,6 +880,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
             "cache_quant_type: std::string",
             "use_neox_rotary_style: bool",
             "max_input_length: int",
+            "softmax_scale: float",
             "quant_max_bound: float",
             "quant_min_bound: float",
             "out_linear_in_scale: float",
