Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] Support empty_cache on XPUs #9789

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions llm/alignment/ppo/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

# ##### training statistic logging #####
# Number of trainable parameters only account for policy_model
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
Expand All @@ -1224,7 +1225,7 @@ def train(
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/quantization/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle.nn.quant import weight_quantize

from ..utils.log import logger
from ..utils.memory_utils import empty_device_cache
from .quantization_linear import (
ColumnParallelQuantizationLinear,
QuantizationLinear,
Expand Down Expand Up @@ -150,7 +151,7 @@ def convert_to_quantize_state_dict_without_check(state_dict, quantization_linear
state_dict.update(qlora_state_dict)
del target_weight
gc.collect()
paddle.device.cuda.empty_cache()
empty_device_cache()
return state_dict


Expand Down
20 changes: 10 additions & 10 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils import infohub
from paddlenlp.utils import empty_device_cache, infohub
from paddlenlp.utils.env import (
LORA_WEIGHTS_NAME,
MAX_QUANTIZATION_TIMES,
Expand Down Expand Up @@ -158,7 +158,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None)
if self.args.should_save:
save_model_config(model_to_save, save_directory)

paddle.device.cuda.empty_cache()
empty_device_cache()

if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save:
world_size = paddle.distributed.get_world_size()
Expand Down Expand Up @@ -195,7 +195,7 @@ def load_unified_checkpoint(self, model, resume_from_checkpoint: str):
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
paddle.device.cuda.empty_cache()
empty_device_cache()

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
Expand Down Expand Up @@ -373,7 +373,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
optim_state_dict, shard_optim_file, sharded_optim_index = results[0]
master_weight_state_dict, shard_master_weight_file, sharded_master_weight_index = results[1]

paddle.device.cuda.empty_cache()
empty_device_cache()
save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
Expand Down Expand Up @@ -506,7 +506,7 @@ def unified_checkpoint_into_shards(
Returns:
tuple: state_dict, config, shard_file: file name, sharded_index: map for weight to file name.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()
assert hasattr(model_to_save, "config")

state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True)
Expand Down Expand Up @@ -558,7 +558,7 @@ def unified_checkpoint_into_shards(
elif isinstance(model_to_save, PrefixModelForCausalLM):
sharded_index["type"] = "ptuning"

paddle.device.cuda.empty_cache()
empty_device_cache()

return state_dict, shard_file, sharded_index

Expand All @@ -576,7 +576,7 @@ def unified_optimizer_into_shards(
optimizer (Optimizer): optimizer to save.
safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
"""
paddle.device.cuda.empty_cache()
empty_device_cache()

# gather global master_weights status.
global_master_weights = reduce_master_weights_status(master_weights is not None)
Expand Down Expand Up @@ -643,7 +643,7 @@ def unified_optimizer_into_shards(
filter_optim_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")
Expand All @@ -653,7 +653,7 @@ def unified_optimizer_into_shards(
filter_master_keys,
state_dict if args.use_expert_parallel else None,
)
paddle.device.cuda.empty_cache()
empty_device_cache()

# build index json file
index_optimizer_file, index_master_weight_file = {}, {}
Expand Down Expand Up @@ -704,7 +704,7 @@ def unified_optimizer_into_shards(
else:
sharded_optim_index["master_weights"] = False

paddle.device.cuda.empty_cache()
empty_device_cache()
if master_weights is None:
return [(optim_state_dict, shard_optimizer_file, sharded_optim_index)]
else:
Expand Down
3 changes: 2 additions & 1 deletion paddlenlp/trl/embedding_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SimpleInfclLoss,
)
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient
from paddlenlp.utils import empty_device_cache

__all__ = ["EmbeddingTrainer"]

Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self, model_args, **kwargs):
def clear_memory(self):
self.accum_q_features.clear()
self.accum_p_features.clear()
paddle.device.cuda.empty_cache()
empty_device_cache()

def clear_state(self):
self.accum_data.clear()
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .import_utils import *
from .infohub import infohub
from .initializer import to
from .memory_utils import empty_device_cache
from .optimizer import *
from .serialization import load_torch

Expand Down
39 changes: 39 additions & 0 deletions paddlenlp/utils/memory_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# coding:utf-8
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from .log import logger
from .tools import get_env_device

__all__ = [
"empty_device_cache",
]


def empty_device_cache():
    """Release the cached memory pool of the current runtime device.

    Dispatches to the matching ``paddle`` empty-cache API for GPU and XPU
    devices. For any other device the call is a no-op; a warning is logged
    on the first such call only (subsequent calls stay silent).
    """
    current_device = get_env_device()
    if current_device == "gpu":
        paddle.device.cuda.empty_cache()
    elif current_device == "xpu":
        paddle.device.xpu.empty_cache()
    elif not getattr(empty_device_cache, "has_warned", False):
        # Warn exactly once per process for unsupported devices; the flag is
        # stored on the function object itself so no module globals are needed.
        logger.warning(
            "The current device ({}) does not support empty cache, calling empty_device_cache() will have no effect.".format(
                current_device
            )
        )
        empty_device_cache.has_warned = True
9 changes: 5 additions & 4 deletions slm/examples/RLHF/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
speed_metrics,
)
from paddlenlp.transformers import PretrainedModel, PretrainedTokenizer
from paddlenlp.utils import empty_device_cache


class StepTrainer(Trainer):
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def gen_epoch_data():
ptx_batches = [None for _ in range(len(rl_batches))]
self.timers and self.timers("ptx-batch").stop()

paddle.device.cuda.empty_cache()
empty_device_cache()

self.set_train()
for _ in range(self.args.update_iters):
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def train(

# ##### model and optimizer related setting #####
policy_model, value_model = self.init_train_model_opt(max_steps, resume_from_checkpoint)
paddle.device.cuda.empty_cache()
empty_device_cache()

# ##### training statistic logging #####
# Number of trainable parameters only account for policy_model
Expand Down Expand Up @@ -1208,7 +1209,7 @@ def train(
# with self.enable(self.value_trainer.optimizer):
with self.enable(): # put value optimizer guard in rl_step
rl_info = self.rl_step(rl_batch)
paddle.device.cuda.empty_cache()
empty_device_cache()
self.timers and self.timers("rl_step").stop()

if self.use_ptx:
Expand All @@ -1224,7 +1225,7 @@ def train(
ptx_info = self.ptx_step(ptx_batch)
rl_info.update(ptx_info)
self.timers and self.timers("ptx_step").stop()
paddle.device.cuda.empty_cache()
empty_device_cache()

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
Expand Down