swift/megatron/convert.py (9 additions, 1 deletion)

@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import math
+import os
+import shutil
 from contextlib import contextmanager
 from dataclasses import fields
 from typing import Any, Dict
@@ -15,7 +17,7 @@
 from megatron.training.initialize import initialize_megatron

 from swift.llm import ExportArguments, HfConfigFactory, prepare_model_template, to_device, to_float_dtype
-from swift.utils import get_logger, get_n_params_grads
+from swift.utils import get_logger, get_n_params_grads, is_master
 from .argument import MegatronArguments
 from .model import get_megatron_model_meta
 from .utils import (convert_hf_config, forward_step_helper, get_padding_to, patch_load_base_checkpoint,
@@ -332,6 +334,12 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     bridge = megatron_model_meta.bridge_cls()
     logger.info('Converting weights and saving the model...')
     bridge.save_weights([mg_model], args.output_dir)
+    if is_master():
+        args_path = os.path.join(megatron_args.adapter_load or megatron_args.load or args.model, 'args.json')
+        if os.path.exists(args_path):
+            shutil.copy(args_path, os.path.join(args.output_dir, 'args.json'))
+        else:
+            args.save_args(args.output_dir)
     logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
     if args.test_convert_precision:
         hf_model, template = prepare_model_template(args, model=args.output_dir)
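Editor's note: the new block above reuses an existing args.json from the first configured checkpoint directory and otherwise falls back to serializing the current arguments. A minimal standalone sketch of that fallback pattern; the function name and boolean return convention here are illustrative, not from the repo:

import os
import shutil
from typing import Optional

def copy_args_json(output_dir: str, *candidates: Optional[str]) -> bool:
    """Mirror `adapter_load or load or model`: take the first configured
    source directory and copy its args.json into output_dir if present.
    Returns False when the caller should serialize its own args instead."""
    source = next((d for d in candidates if d), None)
    if source is None:
        return False
    args_path = os.path.join(source, 'args.json')
    if not os.path.exists(args_path):
        return False
    shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
    return True

In a multi-process run this would sit behind the same is_master() guard as in the diff, so only one rank writes the file.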
swift/megatron/export/export.py (12 additions, 6 deletions)

@@ -67,6 +67,7 @@ def convert_mcore2hf(self) -> None:
                 shutil.copy(args_path, os.path.join(args.save, 'args.json'))
             else:
                 args.save_args(args.save)
+        logger.info(f'Successfully saved HF model weights in `{args.save}`.')
Review comment (Contributor, severity: medium):
This log message will be printed by all ranks in a distributed environment, which can lead to cluttered logs. It's better to guard it with a rank check, similar to the log message on line 135 (logger.info_if(..., cond=is_last_rank())), to ensure the message is printed only once.

Suggested change:
-        logger.info(f'Successfully saved HF model weights in `{args.save}`.')
+        logger.info_if(f'Successfully saved HF model weights in `{args.save}`.', cond=is_last_rank())
         if args.test_convert_precision:
             with disable_safe_ddp_context_use_barrier():
                 if save_peft_format:
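Editor's note: for readers outside this repo, a rank guard of the kind suggested in the review comment above usually looks like the following. This is a hedged sketch of the idea, not the actual swift.utils implementation of is_last_rank or logger.info_if:

import logging

import torch.distributed as dist

def is_last_rank() -> bool:
    """True in single-process runs, or on the highest-numbered rank."""
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == dist.get_world_size() - 1

def info_if(logger: logging.Logger, msg: str, cond: bool) -> None:
    """Emit the message only when `cond` holds, so exactly one rank logs."""
    if cond:
        logger.info(msg)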
@@ -114,13 +115,18 @@ def convert_hf2mcore(self) -> None:
             logger.info('Merge LoRA...')
             mg_model = peft_model.merge_and_unload()
         logger.info('Successfully transferred HF model weights to MG model.')
+        # hf_model does not support loading args.adapter_load, so test_convert_precision cannot be performed
+        support_convert_precision = args.adapter_load is None
         if args.test_convert_precision:
-            with disable_safe_ddp_context_use_barrier():
-                device_map = args.device_map or 'auto'
-                hf_model, template = prepare_model_template(
-                    args, device_map=device_map) if is_last_rank() else (None, template)
-                test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
-                dist.barrier()
+            if support_convert_precision:
+                with disable_safe_ddp_context_use_barrier():
+                    device_map = args.device_map or 'auto'
+                    hf_model, template = prepare_model_template(
+                        args, device_map=device_map) if is_last_rank() else (None, template)
+                    test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+                    dist.barrier()
+            else:
+                logger.warning('Skip test_convert_precision because `--adapter_load` is specified.')
         args.save_args(args.save)
         logger.info('Saving the model...')
         save_peft_format = args.train_type == 'lora' and not args.merge_lora
swift/megatron/model/gpt_bridge.py (2 additions, 0 deletions)

@@ -1432,6 +1432,8 @@ def save_weights(self, mg_models, output_dir: str, is_peft_format: bool = False)
         if is_peft_format:
             from swift.llm import get_multimodal_target_regex
             peft_config = copy(mg_models[0].peft_config[self._adapter_name])
+            if args.task_type == 'seq_cls':
+                peft_config.task_type = 'SEQ_CLS'
Review comment (Contributor, severity: medium):
For better maintainability and to avoid potential issues with string literals, it's recommended to use the TaskType enum from the peft library directly. This makes the code more robust to potential changes in the library.

To implement this, add `from peft import TaskType` with the other local imports and then change this line to:

    peft_config.task_type = TaskType.SEQ_CLS
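Editor's note: because peft's TaskType is a str-backed enum, the enum member compares equal to the raw literal the diff assigns today, so the suggestion is purely about robustness. A small sketch; the LoraConfig values are illustrative:

from peft import LoraConfig, TaskType

# TaskType subclasses str, so the enum member equals the raw literal;
# using the enum just protects against typos and upstream renames.
assert TaskType.SEQ_CLS == 'SEQ_CLS'

config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16)
print(config.task_type)  # TaskType.SEQ_CLS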
             if args.is_multimodal and 'all-linear' in args.target_modules:
                 peft_config.target_modules = get_multimodal_target_regex(
                     self.hf_model,