Support XPU for auto-parallel LLaMa #9796
@@ -0,0 +1,106 @@
#!/bin/bash

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

task_name_or_path="llama2-13b-auto"

#export XPUAPI_DEBUG=0x1
#export XPURT_DISPATCH_MODE=PROFILING
export XBLAS_FC_HBM_VERSION=40

# PaddlePaddle
export FLAGS_use_stride_kernel="0"
export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
export XPU_CDNN_CLUSTER_PARALLEL=1
export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2

# PDC
unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT
unset PADDLE_TRAINERS_NUM

# BKCL
# export BKCL_DEBUG=1
# Multi-machine RDMA
#export BKCL_ENABLE_XDR=1
#export BKCL_RDMA_FORCE_TREE=1
#export BKCL_TREE_THRESHOLD=0
#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
#export BKCL_SOCKET_IFNAME=xgbe0
#export BKCL_FORCE_L3_RDMA=0
export LD_LIBRARY_PATH=/usr/local/lib:/usr/lib64
# NOTE: bkcl_location must point to the directory containing libbkcl.so
# for this version check to work.
echo "bkcl version:"
strings ${bkcl_location}/libbkcl.so | grep COM

export CUDA_DEVICE_MAX_CONNECTIONS=8

# PYTHONPATH
export PYTHONPATH=../../../:$PYTHONPATH

# for debug
#export GLOG_v=10
export FLAGS_call_stack_level=2

rm -rf output/$task_name_or_path
PYTHONPATH=../:$PYTHONPATH \
python -u -m paddle.distributed.launch \
    --xpus "0,1,2,3,4,5,6,7" \
    --log_dir "output/$task_name_or_path/" \
    run_pretrain_auto.py \
    --model_name_or_path "meta-llama/Llama-2-13b" \
    --tokenizer_name_or_path "meta-llama/Llama-2-13b" \
    --input_dir "./data" \
    --output_dir "output/$task_name_or_path" \
    --split 949,50,1 \
    --max_seq_length 4096 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --use_flash_attention 1 \
    --use_fused_rope 1 \
    --fuse_attention_ffn 1 \
    --fuse_attention_qkv 1 \
    --use_fused_rms_norm 0 \
    --num_hidden_layers 4 \
    --bf16 \
    --fp16_opt_level "O2" \
    --amp_master_grad true \
    --scale_loss 1024 \
    --learning_rate 0.00003 \
    --min_learning_rate 0.000005 \
    --lr_scheduler_type "cosine" \
    --max_steps 10 \
    --save_steps 100000 \
    --weight_decay 0.01 \
    --warmup_ratio 0.01 \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --sequence_parallel 0 \
    --dataloader_num_workers 4 \
    --pipeline_parallel_degree 1 \
    --tensor_parallel_degree 1 \
    --gradient_accumulation_steps 1 \
    --eval_steps 1000 \
    --report_to "visualdl" \
    --disable_tqdm true \
    --continue_training 0 \
    --recompute 0 \
    --do_train \
    --seed 1026 \
    --device "xpu" \
    --enable_auto_parallel 1 \
    --to_static 1
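Before launching the 8-card job above, it can help to confirm that the installed PaddlePaddle build actually has XPU support. A minimal sanity check (not part of this PR, just a suggested pre-flight step):

```python
# Hedged sanity check: verify the Paddle build sees an XPU device
# before running the launch script above.
import paddle

assert paddle.is_compiled_with_xpu(), "this PaddlePaddle build has no XPU support"
paddle.set_device("xpu")
print(paddle.device.get_device())  # expect something like "xpu:0"
```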
@@ -52,7 +52,9 @@
     CausalLMOutputWithCrossAttentions,
 )
 from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model
+from paddlenlp.utils.tools import get_env_device

+from . import fusion_ops
 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
     LLAMA_PRETRAINED_RESOURCE_FILES_MAP,

@@ -69,7 +71,6 @@
     build_alibi_tensor,
     get_triangle_upper_mask,
     repeat_kv,
-    rms_norm_fused,
 )

 try:

@@ -218,7 +219,9 @@

     def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
-            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
+            return fusion_ops.fusion_rms_norm(
+                hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm
+            )

         with paddle.amp.auto_cast(False):
             variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
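For reference, the math that `fusion_ops.fusion_rms_norm` is expected to compute can be reconstructed from the non-fused fallback visible in the context lines above. A minimal sketch (the fused kernel's exact numerics are an assumption; this mirrors the eager-mode path):

```python
import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # Non-fused RMSNorm, mirroring the fallback path in the diff above:
    # compute the mean of squares in float32, normalize, then scale by weight.
    with paddle.amp.auto_cast(False):
        variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
        hidden_states = paddle.rsqrt(variance + variance_epsilon) * hidden_states.astype("float32")
    return hidden_states.astype(weight.dtype) * weight
```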
@@ -308,7 +311,7 @@
         self.ipp = ipp

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
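Paraphrasing the gate above as a standalone predicate may make the intent clearer. A sketch (the device list comes from the diff; the `paddle.incubate` import and the claim that these devices take the `fusion_ops` path are assumptions based on this PR's imports):

```python
import paddle
from paddlenlp.utils.tools import get_env_device

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

def should_warn_and_disable_fused_rope():
    # Devices in this list are skipped by the GPU-availability check,
    # presumably because they provide their own fused rope via fusion_ops.
    if get_env_device() in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
        return False
    return "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None
```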
@@ -935,7 +938,22 @@
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+        elif get_env_device() == "gcu":
+            min_val = paddle.finfo(dtype).min
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(min_val, dtype=dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)

Comment on lines +941 to +953

Reviewer: What differs in mask generation between the different devices?

Author: The mask generation logic is the same as here: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1606. XPU needs a different mask for the following two reasons:
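A minimal sketch of the XPU branch discussed above, restated as a standalone function (the fill value is taken verbatim from the diff; the claim about why XPU needs it is the author's):

```python
import paddle

def make_additive_mask_xpu(bool_mask):
    # On XPU the additive attention mask is built in float32 with a
    # large-but-finite negative fill value instead of paddle.finfo(dtype).min,
    # and is deliberately not cast back to the model dtype here (per the diff).
    zero = paddle.to_tensor(0.0, dtype="float32")
    neg = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
    return paddle.where(bool_mask.cast("bool"), zero, neg)
```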
+        else:
+            expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
         return expanded_attn_mask

     def forward(
@@ -1166,15 +1184,35 @@
                     masked_lm_labels.unsqueeze(2),
                 )

-            masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
-            loss = paddle.mean(masked_lm_loss)
+            # XPU does not support allgather on a mask with bool dtype, so we use LocalLayer here.
+            if get_env_device() == "xpu":
+
+                class LocalLossLayer(paddle.distributed.LocalLayer):
+                    def __init__(self, out_dist_attrs):
+                        super().__init__(out_dist_attrs)
+
+                    def forward(self, x, mask):
+                        masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
+                        loss = paddle.mean(masked_lm_loss)
+                        return loss
+
+                out_dist_attrs = [
+                    (masked_lm_loss.process_mesh, [dist.Partial(dist.ReduceType.kRedSum), dist.Replicate()]),
+                ]
+                loss_func = LocalLossLayer(out_dist_attrs)
+                loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
+            else:
+                masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+                loss = paddle.mean(masked_lm_loss)

             return loss

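What the `LocalLossLayer` above computes on each rank, written as a plain function (a sketch; the cross-rank reduction itself is delegated to the `dist.Partial(dist.ReduceType.kRedSum)` output attribute declared in the diff):

```python
import paddle

def local_masked_mean(masked_lm_loss):
    # The bool mask and masked_select stay rank-local, so no bool-dtype
    # allgather is needed on XPU; auto-parallel later sum-reduces the
    # per-rank results according to the Partial(kRedSum) dist attr.
    mask = masked_lm_loss > 0
    picked = paddle.masked_select(masked_lm_loss, mask).astype("float32")
    return paddle.mean(picked)
```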

 class LlamaLMHeadAuto(nn.Layer):
     def __init__(self, config: LlamaConfig):
         super(LlamaLMHeadAuto, self).__init__()
         self.config = config

         vocab_size = config.vocab_size
         self.weight = self.create_parameter(
             shape=[config.hidden_size, vocab_size],
Reviewer: In auto-parallel, what is the relationship between modeling_network.py and modeling_auto.py? Does modeling_network.py need to be modified in sync?

Author: modeling_network.py is the intermediate-API implementation and modeling_auto.py is the base-API implementation; ideally the two would be merged into one. This PR is only the first step of the Kunlun (XPU) adaptation and currently supports only the dynamic-graph semi-auto, pure data-parallel scenario. More testing and iteration are still needed; once that work is solid, there should be a dedicated effort to merge modeling_network.py and modeling_auto.py.