209 changes: 183 additions & 26 deletions paddlenlp/trainer/trainer.py
@@ -31,6 +31,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Mapping
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -54,6 +55,9 @@
except:
core = None
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizerV2,
)
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
@@ -102,6 +106,8 @@
except:
pass

from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ShardedWeight

from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
from ..transformers.model_utils import (
PretrainedModel,
@@ -226,6 +232,11 @@ def in_auto_parallel_align_mode():
return False


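# Subdirectory names for the three components of a flex checkpoint on disk.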
MODEL_STATE_DIC = "model_state"
OPTIMIZER_STATE_DIC = "optimizer_state"
MASTER_WEIGHT_DIC = "master_weight"


__all__ = ["Trainer"]


@@ -842,6 +853,127 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):

logger.info("Create zero cost checkpoint manager done.")

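# Restores a flex checkpoint: model weights, optimizer states, and master weights
# are each loaded from their own subdirectory (see the *_DIC constants above).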
def _load_flex_checkpoint(self, resume_from_checkpoint):
def get_metadata_file_name(path):
files = os.listdir(path)
metadata_files = [f for f in files if f.endswith(".metadata")]
assert len(metadata_files) > 0, f"Found no metadata files in {path}"
assert len(metadata_files) == 1, f"Found multiple metadata files in {path}"
return metadata_files[0]

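# Sharded view of the model parameters and the on-disk paths of each checkpoint component.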
model_sharded_state_dict = self.model.sharded_state_dict()
master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC)
opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)
if not self.args.ignore_load_lr_and_optim:
state_dict_metadata = {}
metadata_paths = [
os.path.join(model_states_path, get_metadata_file_name(model_states_path)),
os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)),
os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)),
]

for metadata_file in metadata_paths:
if not os.path.exists(metadata_file):
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
metadata = paddle.load(metadata_file)
state_dict_metadata.update(metadata.state_dict_metadata)

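# Initialize the optimizer states against the merged checkpoint metadata.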
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)

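# Drop the optimizer's local tensor allocations; the actual values are loaded from disk
# and copied back further down, presumably to avoid holding two full copies in device memory.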
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
for k, v in optimizer_sharded_state_dict.items():
v.local_tensor._clear_to_zero_allocation()

if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2):
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
for color, _comm_buffer_list in color_to_comm_buffer_list.items():
for comm_buffer in _comm_buffer_list:
comm_buffer._clear_param_storage()
else:
state_dict = self.model.state_dict()
for k, v in state_dict.items():
v._clear_to_zero_allocation()

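# Split the optimizer state into master weights (keys ending in ".w_0") and the remaining states.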
opt_states = {}
master_weights = {}
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
opt_states[k] = v

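# Re-wrap each optimizer state around a freshly allocated zero tensor that will receive the loaded values.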
for k, v in opt_states.items():
new_v = ShardedWeight(
key=v.key,
local_tensor=paddle.zeros_like(v.local_tensor),
local_shape=deepcopy(v.local_shape),
global_shape=deepcopy(v.global_shape),
global_offset=deepcopy(v.global_offset),
is_flattened=v.is_flattened,
flattened_range=deepcopy(v.flattened_range),
)
opt_states[k] = new_v

dist.load_state_dict(
opt_states,
opt_states_path,
aoa_config=self.args.aoa_config,
)

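# Stage the loaded optimizer states in pinned host memory and release the device copies.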
optimizer_state_pin = {}
for k, v in opt_states.items():
optimizer_state_pin[k] = v.local_tensor.pin_memory()
del opt_states
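# Master weights follow the same pattern: allocate, load, then stage in pinned host memory.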
for k, v in master_weights.items():
new_v = ShardedWeight(
key=v.key,
local_tensor=paddle.zeros_like(v.local_tensor),
local_shape=deepcopy(v.local_shape),
global_shape=deepcopy(v.global_shape),
global_offset=deepcopy(v.global_offset),
is_flattened=v.is_flattened,
flattened_range=deepcopy(v.flattened_range),
)
master_weights[k] = new_v

dist.load_state_dict(
master_weights,
master_weights_path,
aoa_config=self.args.aoa_config,
)

master_weights_pin = {}
for k, v in master_weights.items():
master_weights_pin[k] = v.local_tensor.pin_memory()
del master_weights

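# Rebuild the optimizer's sharded state dict and copy the staged values back into its local tensors.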
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
optimizer_sharded_state_dict_pin = {**master_weights_pin, **optimizer_state_pin}

for k, v in optimizer_sharded_state_dict.items():
source_tensor = optimizer_sharded_state_dict_pin[k]
v.local_tensor.set_value(source_tensor)

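# Re-establish the parameter storage released above: reset the sharding comm buffers,
# or re-allocate the raw model parameters otherwise.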
if isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2):
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
for color, _comm_buffer_list in color_to_comm_buffer_list.items():
for comm_buffer in _comm_buffer_list:
comm_buffer._reset_param_storage()
else:
state_dict = self.model.state_dict()
for k, v in state_dict.items():
new_v = paddle.zeros_like(v)
v.set_value(new_v)

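# Finally, restore the LR scheduler and load the model weights themselves.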
self._load_scheduler(resume_from_checkpoint)

dist.load_state_dict(
model_sharded_state_dict,
model_states_path,
aoa_config=self.args.aoa_config,
)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -975,28 +1107,8 @@ def train(
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
if not self.args.ignore_load_lr_and_optim:
model_sharded_state_dict = self.model.sharded_state_dict()
accessible_files = os.listdir(resume_from_checkpoint)
metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
assert len(metadata_files) == 1, "Only support one metadata file now."
metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
state_dict_metadata = metadata.state_dict_metadata
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_scheduler(resume_from_checkpoint)
else:
model_sharded_state_dict = self.model.sharded_state_dict()
sharded_state_dict = model_sharded_state_dict
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_flex_checkpoint(resume_from_checkpoint)
else:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -2794,7 +2906,12 @@ def _save_checkpoint(self, model, metrics=None):

if self.args.save_checkpoint_format == "flex_checkpoint":
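# Model weights go into their own subdirectory of the checkpoint.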
model_sharded_state_dict = self.model.sharded_state_dict()
os.makedirs(output_dir, exist_ok=True)
model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
os.makedirs(model_state_dict_path, exist_ok=True)
dist.save_state_dict(
model_sharded_state_dict,
model_state_dict_path,
)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2858,10 +2975,26 @@ def _save_checkpoint(self, model, metrics=None):
)
else:
if self.args.save_checkpoint_format == "flex_checkpoint":
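# Optimizer states and master weights are written to separate subdirectories.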
optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
optimizer_states = {}
master_weights = {}

model_sharded_state_dict = self.model.sharded_state_dict()
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
optimizer_states[k] = v

dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
optimizer_states,
optimizer_state_dict_path,
)
master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
dist.save_state_dict(
master_weights,
master_weights_path,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
@@ -2919,10 +3052,34 @@ def _save_checkpoint(self, model, metrics=None):
)
elif self.args.save_checkpoint_format == "flex_checkpoint":
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
model_sharded_state_dict = self.model.sharded_state_dict()
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
os.makedirs(model_state_dict_path, exist_ok=True)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
model_sharded_state_dict,
model_state_dict_path,
)
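# Optimizer states and master weights are written to separate subdirectories unless skipped.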
if not self.args.ignore_save_lr_and_optim:
optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
optimizer_states = {}
master_weights = {}
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
optimizer_states[k] = v

dist.save_state_dict(
optimizer_states,
optimizer_state_dict_path,
)

master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
dist.save_state_dict(
master_weights,
master_weights_path,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)