support deepseek-v3 #9878


Closed
wants to merge 1 commit into from
15 changes: 10 additions & 5 deletions README.md
@@ -7,7 +7,7 @@
------------------------------------------------------------------------------------------

<p align="center">
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
<a href="https://paddlenlp.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/paddlenlp/badge/?version=latest"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleNLP?color=ffa"></a>
<a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
@@ -16,6 +16,7 @@
<a href="https://pypi.org/project/paddlenlp/"><img src="https://img.shields.io/pypi/dm/paddlenlp?color=9cf"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/PaddleNLP?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleNLP?color=ccf"></a>
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
</p>

<h4 align="center">
@@ -69,6 +70,9 @@

The high-performance inference module of the large model toolkit incorporates dynamic insertion and operator fusion strategies throughout the entire process, greatly accelerating parallel inference speed. The underlying implementation details are encapsulated, enabling out-of-the-box high-performance parallel inference capabilities.

## Documentation
For detailed documentation, visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).

------------------------------------------------------------------------------------------

## Model Support
@@ -91,6 +95,7 @@
| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
@@ -161,7 +166,7 @@
### Environment Dependencies

* python >= 3.8
* paddlepaddle >= 3.0.0b0
* paddlepaddle >= 3.0.0rc0

If you have not yet installed PaddlePaddle, please refer to the [PaddlePaddle website](https://www.paddlepaddle.org.cn/) for installation instructions.

@@ -206,7 +211,7 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwe
wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx
cd .. # change folder to PaddleNLP/llm
# To use use_fused_rms_norm=true, first install fused_ln under slm/model_zoo/gpt-3/external_ops
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py ./config/llama/pretrain_argument.json --use_fused_rms_norm false
python -u run_pretrain.py ./config/qwen/pretrain_argument_0p5b.json
```

### Large Model SFT Fine-Tuning
@@ -216,7 +221,7 @@ git clone https://github.com/PaddlePaddle/PaddleNLP.git && cd PaddleNLP # 如已
mkdir -p llm/data && cd llm/data
wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz && tar -zxvf AdvertiseGen.tar.gz
cd .. # change folder to PaddleNLP/llm
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/sft_argument.json
python -u run_finetune.py ./config/qwen/sft_argument_0p5b.json
```

For more end-to-end large model workflow steps, see the [PaddlePaddle LLM toolkit](./llm) guide.
@@ -231,7 +236,7 @@ dataset = load_dataset("ZHUI/alpaca_demo", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
8 changes: 6 additions & 2 deletions README_en.md
@@ -7,7 +7,7 @@
------------------------------------------------------------------------------------------

<p align="center">
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
<a href="https://paddlenlp.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/paddlenlp/badge/?version=latest"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleNLP?color=ffa"></a>
<a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
@@ -16,6 +16,7 @@
<a href="https://pypi.org/project/paddlenlp/"><img src="https://img.shields.io/pypi/dm/paddlenlp?color=9cf"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/PaddleNLP?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/PaddleNLP/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleNLP?color=ccf"></a>
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
</p>

<h4 align="center">
@@ -52,6 +53,9 @@ The fine-tuning algorithms are deeply integrated with zero-padding data streams

The high-performance inference module of the large model toolkit incorporates dynamic insertion and operator fusion strategies throughout the entire process, greatly accelerating parallel inference speed. The underlying implementation details are encapsulated, enabling out-of-the-box high-performance parallel inference capabilities.

## Documentation
For detailed documentation, visit the [PaddleNLP Documentation](https://paddlenlp.readthedocs.io/).

------------------------------------------------------------------------------------------

## Supported Models
@@ -68,7 +72,7 @@ Detailed list 👉 [Supported Model List](https://github.com/PaddlePaddle/Paddle
### Pip Installation

```shell
pip install --upgrade paddlenlp==3.0.0b2
pip install --upgrade paddlenlp==3.0.0b3
```

or you can install the latest develop branch code with the following command:
7 changes: 5 additions & 2 deletions csrc/README.md
@@ -1,6 +1,9 @@
# PaddleNLP Custom Ops
# PaddleNLP High-Performance Custom Inference Operators for Large Models

This document describes how to compile and install the PaddleNLP custom ops.
This document describes how to compile and install PaddleNLP's high-performance custom inference operators for large models.

Using these high-performance operators can significantly speed up large model inference.
For tutorials on large model inference, see [here](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/README.md#6-%E6%8E%A8%E7%90%86).

## Installing C++ Dependencies

37 changes: 27 additions & 10 deletions csrc/gpu/append_attention.cu
@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::INT8,
qkv.place());
}
else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
D,
qkv.place());
}
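The hunk above makes two changes: the output buffer's row width now uses `head_dims_v` instead of `head_dims` (K and V head dims differ in DeepSeek's MLA layout), and the output dtype is dispatched on `quant_max_bound`. A minimal Python sketch of that dtype dispatch (`fmha_out_dtype` is a hypothetical helper name, not part of the PaddleNLP API):

```python
def fmha_out_dtype(out_linear_in_scale, quant_max_bound, native_dtype="bfloat16"):
    """Sketch of the output-dtype selection in the C++ branch above."""
    if out_linear_in_scale > 0.0:
        if abs(quant_max_bound - 127.0) < 1e-6:
            return "int8"            # INT8 range is [-127, 127]
        if abs(quant_max_bound - 448.0) < 1e-6:
            return "float8_e4m3fn"   # 448 is the max normal value of FP8 E4M3
        raise ValueError("Only quant_max_bound in [127.0, 448.0] is supported.")
    return native_dtype              # unquantized path keeps the compute dtype
```

This mirrors why the kernel compares `quant_max_bound` against exactly 127.0 and 448.0: they are the saturation bounds of the two supported quantized formats.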
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
encoder_block_shape_q,
max_input_length,
max_enc_len_this_time_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
decoder_block_shape_q,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -578,9 +586,10 @@
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
const int total_num_head =
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
meta_data.head_dims_v = value_cache.dims()[3];
const int q_hidden_size =
qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
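The replaced lines change how the query head count is derived from the fused QKV width. With equal K/V head dims the old formula `total_num_head - 2 * kv_num_heads` worked; with MLA-style models such as DeepSeek-V2/V3, where `head_dims` (QK) and `head_dims_v` differ, the K and V contributions must be subtracted explicitly. A Python sketch of the new math (function name and the example dims are hypothetical):

```python
def derive_q_num_heads(qkv_hidden, kv_num_heads, head_dim_qk, head_dim_v):
    """Mirror of the revised C++ head-count derivation above."""
    q_hidden = qkv_hidden - kv_num_heads * (head_dim_qk + head_dim_v)
    assert q_hidden % head_dim_qk == 0, "qkv hidden size inconsistent with head dims"
    return q_hidden // head_dim_qk

# Hypothetical MLA-style shapes: 16 query heads of dim 192,
# 8 KV heads with K dim 192 and V dim 128:
# qkv_hidden = 16*192 + 8*(192 + 128) = 5632
```

When `head_dim_qk == head_dim_v`, this reduces to the old `total_num_head - 2 * kv_num_heads` formula, so non-MLA models are unaffected.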

meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
softmax_scale,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
const int token_num = qkv_shape[0];
const int kv_num_heads = key_cache_shape[1];
const int head_dim = key_cache_shape[3];
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
const int num_heads = total_num_head - 2 * kv_num_heads;
return {{token_num, num_heads * head_dim}, qkv_shape};
const int head_dim_qk = key_cache_shape[3];
const int head_dim_v = value_cache_shape[3];
const int q_hidden_size =
qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
const int num_heads = q_hidden_size / head_dim_qk;
return {{token_num, num_heads * head_dim_v}, qkv_shape};
}
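The revised `AppendAttentionInferShape` returns an output of width `num_heads * head_dim_v`, and the PR also threads a `softmax_scale` attribute through every kernel call. Both follow from the same attention math: Q and K share `head_dim_qk` (so the scale can no longer be assumed to be `1/sqrt(head_dim)` of a single shared dim), while the output inherits V's head dim. A minimal single-head NumPy sketch of this shape logic (not the CUDA kernel; all shapes are illustrative):

```python
import numpy as np

def attention_head(q, k, v, softmax_scale):
    """Single-head attention with an explicit softmax scale."""
    scores = (q @ k.T) * softmax_scale            # [seq_q, seq_k]
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                              # [seq_q, head_dim_v]

q = np.random.rand(4, 192)   # head_dim_qk = 192
k = np.random.rand(6, 192)
v = np.random.rand(6, 128)   # head_dim_v = 128
out = attention_head(q, k, v, softmax_scale=1.0 / np.sqrt(192))
```

Here `out` has shape `(4, 128)`: the sequence length from Q, but the feature width from V, which is exactly why the fused output shape uses `head_dim_v`.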

std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -865,6 +880,7 @@
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"max_input_length: int",
"softmax_scale: float",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",