Support XPU for auto-parallel LLaMa #9796

Merged
106 changes: 106 additions & 0 deletions llm/auto_parallel/llama/run_llama2_13b_xpu.sh
@@ -0,0 +1,106 @@
#!/bin/bash

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

task_name_or_path="llama2-13b-auto"

#export XPUAPI_DEBUG=0x1
#export XPURT_DISPATCH_MODE=PROFILING
export XBLAS_FC_HBM_VERSION=40

# PaddlePaddle
export FLAGS_use_stride_kernel="0"
export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
export XPU_CDNN_CLUSTER_PARALLEL=1
export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2

# PDC
unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT
unset PADDLE_TRAINERS_NUM

# BKCL
# export BKCL_DEBUG=1
# Multi-computer RDMA
#export BKCL_ENABLE_XDR=1
#export BKCL_RDMA_FORCE_TREE=1
#export BKCL_TREE_THRESHOLD=0
#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
#export BKCL_SOCKET_IFNAME=xgbe0
#export BKCL_FORCE_L3_RDMA=0
export LD_LIBRARY_PATH=/usr/local/lib:/usr/lib64
echo "bkcl version:"
strings ${bkcl_location}/libbkcl.so | grep COM

export CUDA_DEVICE_MAX_CONNECTIONS=8

#PYTHONPATH
export PYTHONPATH=../../../:$PYTHONPATH

# for debug
#export GLOG_v=10
export FLAGS_call_stack_level=2

rm -rf output/$task_name_or_path
PYTHONPATH=../:$PYTHONPATH \
python -u -m paddle.distributed.launch \
--xpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name_or_path/" \
run_pretrain_auto.py \
--model_name_or_path "meta-llama/Llama-2-13b" \
--tokenizer_name_or_path "meta-llama/Llama-2-13b" \
--input_dir "./data" \
--output_dir "output/$task_name_or_path" \
--split 949,50,1 \
--max_seq_length 4096 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--use_flash_attention 1 \
--use_fused_rope 1 \
--fuse_attention_ffn 1 \
--fuse_attention_qkv 1 \
--use_fused_rms_norm 0 \
--num_hidden_layers 4 \
--bf16 \
--fp16_opt_level "O2" \
--amp_master_grad true \
--scale_loss 1024 \
--learning_rate 0.00003 \
--min_learning_rate 0.000005 \
--lr_scheduler_type "cosine" \
--max_steps 10 \
--save_steps 100000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--sequence_parallel 0 \
--dataloader_num_workers 4 \
--pipeline_parallel_degree 1 \
--tensor_parallel_degree 1 \
--gradient_accumulation_steps 1 \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 0 \
--do_train \
--seed 1026 \
--device "xpu" \
--enable_auto_parallel 1 \
--to_static 1
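A quick way to verify the environment before running the script above is the following Python snippet; this is an illustrative sketch, not part of this PR, and the device index is an assumption:

```python
# Minimal XPU sanity check; illustrative only, not part of this PR.
import paddle

assert paddle.is_compiled_with_xpu(), "this Paddle wheel was built without XPU (Kunlun) support"
paddle.set_device("xpu:0")                              # first Kunlun card (index is an assumption)
x = paddle.randn([2, 3])
print(paddle.device.get_device(), x.matmul(x.T).shape)  # expect "xpu:0" and [2, 2]
```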
17 changes: 17 additions & 0 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -59,6 +59,7 @@
print_rank_0,
)
from paddlenlp.trainer.utils.doc import add_start_docstrings
from paddlenlp.utils.tools import get_env_device


@dataclass
@@ -173,6 +174,11 @@ class ModelArguments:
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)

use_fast_layer_norm: bool = field(
default=False,
metadata={"help": "GPT3 model, use fast layernorm"},
)

config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
@@ -496,6 +502,8 @@ def main():

config = config_class.from_pretrained(model_args.model_name_or_path)

config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
    # Some techniques extend the RotaryEmbedding context, so don't change max_position_embeddings
if not model_args.continue_training:
@@ -544,6 +552,15 @@ def main():
pipeline = training_args.strategy.pipeline
pipeline.vpp_degree = config.virtual_pp_degree
pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method
if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
            # It's OK; just run without the accumulate_steps optimization
pass

print("Final pre-training config:", config)

3 changes: 1 addition & 2 deletions paddlenlp/trainer/training_args.py
@@ -1632,13 +1632,12 @@ def is_segment_parallel_supported():
"enable_mp_async_allreduce", # allreduce_matmul_grad_overlapping in auto_parallel
"enable_delay_scale_loss",
"replace_with_c_embedding",
# "enable_mp_skip_c_identity",
# "enable_mp_fused_linear_param_grad_add",
"replace_with_parallel_cross_entropy",
]:
raise ValueError(
f"Found unknown tensor parallell config {x}, "
f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, and enable_mp_fused_linear_param_grad_add"
)
try:
if "enable_mp_async_allreduce" in mp_config:
50 changes: 44 additions & 6 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -52,7 +52,9 @@
CausalLMOutputWithCrossAttentions,
)
from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model
from paddlenlp.utils.tools import get_env_device

from . import fusion_ops
from .configuration import (
LLAMA_PRETRAINED_INIT_CONFIGURATION,
LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -69,7 +71,6 @@
build_alibi_tensor,
get_triangle_upper_mask,
repeat_kv,
rms_norm_fused,
)

try:
@@ -218,7 +219,9 @@

def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
return fusion_ops.fusion_rms_norm(

hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm
)

with paddle.amp.auto_cast(False):
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
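For context, the non-fused branch shown above implements a standard RMSNorm; a minimal standalone equivalent (illustrative sketch, not part of this diff; the epsilon default is an assumption) looks like this:

```python
# Reference RMSNorm for comparison with fusion_ops.fusion_rms_norm; illustrative only.
import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon=1e-6):
    # Upcast to float32 for a stable variance, then cast back to the input dtype,
    # mirroring the non-fused branch in LlamaRMSNorm.forward.
    x = hidden_states.astype("float32")
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * paddle.rsqrt(variance + variance_epsilon)
    return (x * weight.astype("float32")).astype(hidden_states.dtype)
```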
@@ -308,7 +311,7 @@
self.ipp = ipp

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope:
if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:

Collaborator:

What is the relationship between modeling_network.py and modeling_auto.py in auto parallel? Does modeling_network.py need to be modified in sync?

Collaborator (Author):

modeling_network.py is the intermediate-level API implementation and modeling_auto.py is the basic API implementation; ideally the two should be merged into one. This PR is only the first step of the Kunlun (XPU) adaptation and currently supports only the dynamic semi-auto pure-DP scenario. Further testing and iteration are still needed; once this part is solid, a dedicated workstream is expected to merge modeling_network.py and modeling_auto.py.

if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -935,7 +938,22 @@
else:
expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
elif get_env_device() == "gcu":
min_val = paddle.finfo(dtype).min
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(min_val, dtype=dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)

Comment on lines +941 to +953
Collaborator:

How does the mask generation differ between devices?

Collaborator (Author) @From00, Feb 5, 2025:
The mask generation logic is the same as here: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1606.

XPU needs a different mask for the following two reasons:

  1. The flash_attention kernel implemented on XPU differs from the GPU one, which may lead to numeric overflow when the mask value is too small. Therefore, a specific mask value of -1.7005809656952787e38 is needed. @runzhech is fixing this issue, and once it is fixed we can use paddle.finfo(dtype).min as on GPU.
  2. The flash_attention kernel on XPU requires the mask input to be float32, so the astype(dtype) cast cannot be applied in XPU mask generation.

See these two PRs for more details: #9495, #9652
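To make the convention discussed here concrete, the sketch below (illustrative, not from the PR; shapes are assumptions) builds an additive mask the way the XPU branch does: attended positions get 0.0, masked positions get the large negative fill value, and the mask stays float32 so the XPU flash_attention kernel accepts it:

```python
# Illustrative additive attention mask; not part of the PR.
import paddle

keep = paddle.to_tensor([[True, True, False]])          # True = position may be attended
xpu_fill = -1.7005809656952787e38                       # XPU-specific fill value from this PR
# (GPU branches use paddle.finfo(dtype).min instead)

mask = paddle.where(
    keep,
    paddle.zeros(keep.shape, dtype="float32"),
    paddle.full(keep.shape, xpu_fill, dtype="float32"),
)                                                        # kept in float32 for the XPU kernel
scores = paddle.rand([1, 3]) + mask                      # added to raw attention scores
probs = paddle.nn.functional.softmax(scores, axis=-1)    # masked position gets ~0 probability
```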

else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min)
expanded_attn_mask = expanded_attn_mask.astype(dtype)

return expanded_attn_mask

def forward(
@@ -1166,15 +1184,35 @@
masked_lm_labels.unsqueeze(2),
)

masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss)
# XPU does not support allgather on a mask with bool dtype, so we use LocalLayer here.
if get_env_device() == "xpu":


class LocalLossLayer(paddle.distributed.LocalLayer):
def __init__(self, out_dist_attrs):
super().__init__(out_dist_attrs)


def forward(self, x, mask):
masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
loss = paddle.mean(masked_lm_loss)
return loss


out_dist_attrs = [

(masked_lm_loss.process_mesh, [dist.Partial(dist.ReduceType.kRedSum), dist.Replicate()]),
]
loss_func = LocalLossLayer(out_dist_attrs)
loss = loss_func(masked_lm_loss, masked_lm_loss > 0)

else:
masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss)


return loss
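For reference, the reduction that LocalLossLayer wraps is the same masked mean used on other devices; a minimal standalone sketch (illustrative values, not from the PR):

```python
# Masked-loss reduction equivalent to the non-XPU branch above; illustrative only.
import paddle

per_token_loss = paddle.to_tensor([[0.0, 1.2, 0.8], [0.0, 0.0, 2.0]])
mask = per_token_loss > 0                                # keep only tokens with a positive loss
valid = paddle.masked_select(per_token_loss, mask).astype("float32")
loss = paddle.mean(valid)                                # (1.2 + 0.8 + 2.0) / 3
print(float(loss))
```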


class LlamaLMHeadAuto(nn.Layer):
def __init__(self, config: LlamaConfig):
super(LlamaLMHeadAuto, self).__init__()
self.config = config

vocab_size = config.vocab_size
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],