
Commit a05f7cd

[bugfix] fix megatron seq_cls lora bridge (#7054)
1 parent 851ff69 commit a05f7cd

3 files changed: +23 -7 lines changed

swift/megatron/convert.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import math
+import os
+import shutil
 from contextlib import contextmanager
 from dataclasses import fields
 from typing import Any, Dict
@@ -15,7 +17,7 @@
 from megatron.training.initialize import initialize_megatron
 
 from swift.llm import ExportArguments, HfConfigFactory, prepare_model_template, to_device, to_float_dtype
-from swift.utils import get_logger, get_n_params_grads
+from swift.utils import get_logger, get_n_params_grads, is_master
 from .argument import MegatronArguments
 from .model import get_megatron_model_meta
 from .utils import (convert_hf_config, forward_step_helper, get_padding_to, patch_load_base_checkpoint,
@@ -332,6 +334,12 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     bridge = megatron_model_meta.bridge_cls()
     logger.info('Converting weights and saving the model...')
     bridge.save_weights([mg_model], args.output_dir)
+    if is_master():
+        args_path = os.path.join(megatron_args.adapter_load or megatron_args.load or args.model, 'args.json')
+        if os.path.exists(args_path):
+            shutil.copy(args_path, os.path.join(args.output_dir, 'args.json'))
+        else:
+            args.save_args(args.output_dir)
     logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
     if args.test_convert_precision:
         hf_model, template = prepare_model_template(args, model=args.output_dir)
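
The new args.json handling is gated on the master rank, presumably so that multi-process conversions do not all write the same file. A minimal sketch of what such a gate looks like with torch.distributed (swift's actual is_master() helper lives in swift.utils and its real implementation may differ):

import torch.distributed as dist

def is_master() -> bool:
    # Single-process runs have no process group; treat them as master.
    return not dist.is_initialized() or dist.get_rank() == 0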

swift/megatron/export/export.py

Lines changed: 12 additions & 6 deletions
@@ -67,6 +67,7 @@ def convert_mcore2hf(self) -> None:
                 shutil.copy(args_path, os.path.join(args.save, 'args.json'))
             else:
                 args.save_args(args.save)
+        logger.info(f'Successfully saved HF model weights in `{args.save}`.')
         if args.test_convert_precision:
             with disable_safe_ddp_context_use_barrier():
                 if save_peft_format:
@@ -114,13 +115,18 @@ def convert_hf2mcore(self) -> None:
             logger.info('Merge LoRA...')
             mg_model = peft_model.merge_and_unload()
         logger.info('Successfully transferred HF model weights to MG model.')
+        # hf_model does not support loading args.adapter_load, so test_convert_precision cannot be performed
+        support_convert_precision = args.adapter_load is None
         if args.test_convert_precision:
-            with disable_safe_ddp_context_use_barrier():
-                device_map = args.device_map or 'auto'
-                hf_model, template = prepare_model_template(
-                    args, device_map=device_map) if is_last_rank() else (None, template)
-                test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
-            dist.barrier()
+            if support_convert_precision:
+                with disable_safe_ddp_context_use_barrier():
+                    device_map = args.device_map or 'auto'
+                    hf_model, template = prepare_model_template(
+                        args, device_map=device_map) if is_last_rank() else (None, template)
+                    test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+                dist.barrier()
+            else:
+                logger.warning('Skip test_convert_precision because `--adapter_load` is specified.')
         args.save_args(args.save)
         logger.info('Saving the model...')
         save_peft_format = args.train_type == 'lora' and not args.merge_lora
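
Note that dist.barrier() moves inside the new branch: only the last rank materializes the HF model for the precision check, but every rank must reach the barrier to keep the group in sync. A minimal sketch of that pattern (maybe_check_precision, build_hf_model, and run_check are hypothetical names, not swift APIs):

import torch.distributed as dist

def is_last_rank() -> bool:
    return dist.get_rank() == dist.get_world_size() - 1

def maybe_check_precision(build_hf_model, run_check) -> None:
    # Only one rank pays the cost of materializing the HF model.
    hf_model = build_hf_model() if is_last_rank() else None
    if hf_model is not None:
        run_check(hf_model)
    # All ranks synchronize here, mirroring the dist.barrier() in the
    # diff above; skipping it on other ranks would deadlock the group.
    dist.barrier()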

swift/megatron/model/gpt_bridge.py

Lines changed: 2 additions & 0 deletions
@@ -1432,6 +1432,8 @@ def save_weights(self, mg_models, output_dir: str, is_peft_format: bool = False)
         if is_peft_format:
             from swift.llm import get_multimodal_target_regex
             peft_config = copy(mg_models[0].peft_config[self._adapter_name])
+            if args.task_type == 'seq_cls':
+                peft_config.task_type = 'SEQ_CLS'
             if args.is_multimodal and 'all-linear' in args.target_modules:
                 peft_config.target_modules = get_multimodal_target_regex(
                     self.hf_model,
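
This two-line change is the seq_cls fix named in the commit title: PEFT records task_type in adapter_config.json and consults it when the adapter is reloaded, so a sequence-classification LoRA saved with a causal-LM task type would not be rebuilt with its classification head. A minimal sketch of the config PEFT would read back after the fix (r and target_modules are illustrative values, not taken from this repo):

from peft import LoraConfig, TaskType

# 'SEQ_CLS' in adapter_config.json maps to TaskType.SEQ_CLS, which tells
# PEFT to wrap the base model for sequence classification on load.
config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, target_modules=['q_proj', 'v_proj'])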
