
Commit 436ed07

support megatron llama (#3532)

1 parent: 7b309c9

File tree: 14 files changed, +1534 -1429 lines

docs/source/Instruction/Megatron-SWIFT训练.md (+4 -3)

@@ -1,6 +1,8 @@
 
 # Megatron-SWIFT Training
 
+SWIFT incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, and context parallelism. For the models that support Megatron training, see the [Supported Models and Datasets documentation](./支持的模型和数据集.md).
+
 ## Environment Setup
 To use Megatron-SWIFT, in addition to installing the swift dependencies, you also need to install the following:
 
@@ -15,7 +17,7 @@ cd apex
 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
 ```
 
-The dependency library Megatron-LM will be git cloned and installed by swift; no manual installation by the user is required. You can also use the environment variable `MEGATRON_LM_PATH` to point to an already downloaded repo path (for offline environments).
+The dependency library Megatron-LM will be git cloned and installed by swift; no manual installation by the user is required. You can also use the environment variable `MEGATRON_LM_PATH` to point to an already downloaded repo path (for offline environments, use the [core_r0.11.0 branch](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0)).
 
 
 ## Quick Start Example
@@ -93,7 +95,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 ```
 
 - More examples can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron).
-
+- For pretraining, you can use `megatron pt` instead of `megatron sft`, which will use a generative template for training.
 
 ## Command Line Arguments
 
@@ -202,7 +204,6 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - position_embedding_type: Type of positional embedding; options are 'learned_absolute', 'rope', 'relative', and 'none'. Default is 'rope'.
 - rotary_base: Default is 10000.
 - rotary_percent: Default is 1.
-- rotary_seq_len_interpolation_factor: Sequence length interpolation factor. Default is None.
 - normalization: Options are 'LayerNorm', 'RMSNorm'. Default is RMSNorm.
 - norm_epsilon: Default is 1e-5.
 - swiglu: Use swiglu instead of the default gelu. Default is True.

docs/source/Instruction/支持的模型和数据集.md (+699 -699)

Large diff, not rendered by default.

docs/source_en/Instruction/Megatron-SWIFT-Training.md (+4 -3)

@@ -1,6 +1,8 @@
 
 # Megatron-SWIFT Training
 
+SWIFT incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, and context parallelism. For models that support Megatron training, please refer to the [Supported Models and Datasets documentation](./Supported-models-and-datasets.md).
+
 ## Environment Setup
 
 To use Megatron-SWIFT, in addition to installing the `swift` dependencies, you also need to install the following:
@@ -16,7 +18,7 @@ cd apex
 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
 ```
 
-The dependency library Megatron-LM will be git cloned and installed by swift, no manual installation by the user is required. You can also use the environment variable `MEGATRON_LM_PATH` to point to the already downloaded repo path (for offline environments).
+The dependency library Megatron-LM will be git cloned and installed by swift, no manual installation by the user is required. You can also use the environment variable `MEGATRON_LM_PATH` to point to the already downloaded repo path (for offline environments, use the [core_r0.11.0 branch](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0)).
 
 
 ## Quick Start Example
@@ -99,7 +101,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 ```
 
 - More cases can be viewed [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron).
-
+- For pretraining, you can use `megatron pt` instead of `megatron sft`, which will use a generative template for training.
 
 ## Command Line Arguments
 
@@ -215,7 +217,6 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - position_embedding_type: Type of positional embedding, options are 'learned_absolute', 'rope', 'relative', and 'none'. Default is 'rope'.
 - rotary_base: Default is 10000.
 - rotary_percent: Default is 1.
-- rotary_seq_len_interpolation_factor: Sequence length interpolation factor, default is None.
 - normalization: Options are 'LayerNorm', 'RMSNorm'. Default is RMSNorm.
 - norm_epsilon: Default is 1e-5.
 - swiglu: Uses swiglu instead of the default gelu. Default is True.
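The offline note above can be exercised without network access. A minimal sketch, assuming a hypothetical local clone path; `--help` is used as a placeholder because the concrete training flags depend on your setup:

```python
import os
import subprocess

# Point MEGATRON_LM_PATH at a pre-downloaded clone of the core_r0.11.0 branch
# so that swift skips its automatic `git clone` (hypothetical path).
os.environ['MEGATRON_LM_PATH'] = '/mnt/repos/Megatron-LM'

# `megatron sft` / `megatron pt` are the CLI entry points referenced above;
# a real run would pass model/dataset arguments instead of --help.
subprocess.run(['megatron', 'sft', '--help'], check=True)
```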

docs/source_en/Instruction/Supported-models-and-datasets.md (+699 -699)

Large diff, not rendered by default.

scripts/utils/run_model_info.py (+34 -4)

@@ -1,6 +1,7 @@
 from typing import Any, List
 
 from swift.llm import MODEL_MAPPING, TEMPLATE_MAPPING, ModelType, TemplateType
+from swift.utils import is_megatron_available
 
 
 def get_url_suffix(model_id):
@@ -9,17 +10,36 @@ def get_url_suffix(model_id):
     return model_id
 
 
+def get_cache_mapping(fpath):
+    with open(fpath, 'r', encoding='utf-8') as f:
+        text = f.read()
+    idx = text.find('| Model ID |')
+    text = text[idx:]
+    text_list = text.split('\n')[2:]
+    cache_mapping = {}
+    for text in text_list:
+        if not text:
+            continue
+        items = text.split('|')
+        if len(items) < 6:
+            break
+        cache_mapping[items[1]] = items[5]
+    return cache_mapping
+
+
 def get_model_info_table():
     fpaths = ['docs/source/Instruction/支持的模型和数据集.md', 'docs/source_en/Instruction/Supported-models-and-datasets.md']
+    cache_mapping = get_cache_mapping(fpaths[0])
     end_words = [['### 多模态大模型', '## 数据集'], ['### Multimodal large models', '## Datasets']]
     result = [
         '| Model ID | Model Type | Default Template | '
-        'Requires | Tags | HF Model ID |\n'
+        'Requires | Support Megatron | Tags | HF Model ID |\n'
         '| -------- | -----------| ---------------- | '
-        '-------- | ---- | ----------- |\n'
+        '-------- | ---------------- | ---- | ----------- |\n'
     ] * 2
     res_llm: List[Any] = []
     res_mllm: List[Any] = []
+    mg_count = 0
     for template in TemplateType.get_template_name_list():
         assert template in TEMPLATE_MAPPING
 
@@ -40,12 +60,22 @@ def get_model_info_table():
             hf_model_id = '-'
         tags = ', '.join(group.tags or model_meta.tags) or '-'
         requires = ', '.join(group.requires or model_meta.requires) or '-'
-        r = (f'|{ms_model_id}|{model_type}|{template}|{requires}|{tags}|{hf_model_id}|\n')
+        if is_megatron_available():
+            from swift.megatron import model
+            support_megatron = getattr(model_meta, 'support_megatron', False)
+            if 'gptq' in ms_model_id.lower() or 'awq' in ms_model_id.lower() or 'int' in ms_model_id.lower():
+                support_megatron = False
+            support_megatron = '&#x2714;' if support_megatron else '&#x2718;'
+        else:
+            support_megatron = cache_mapping.get(ms_model_id, '&#x2718;')
+        if support_megatron == '&#x2714;':
+            mg_count += 1
+        r = (f'|{ms_model_id}|{model_type}|{template}|{requires}|{support_megatron}|{tags}|{hf_model_id}|\n')
         if model_meta.is_multimodal:
            res_mllm.append(r)
         else:
            res_llm.append(r)
-    print(f'LLM总数: {len(res_llm)}, MLLM总数: {len(res_mllm)}')
+    print(f'LLM总数: {len(res_llm)}, MLLM总数: {len(res_mllm)}, Megatron支持模型: {mg_count}')
     text = ['', '']  # llm, mllm
     for i, res in enumerate([res_llm, res_mllm]):
         for r in res:
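A quick sketch of what `get_cache_mapping` extracts when Megatron is not installed locally: it reuses the previously generated docs table, keyed by column 1 (Model ID) with column 5 (the Support Megatron mark). The table fragment below is illustrative, not copied from the real docs:

```python
# Illustrative two-row fragment in the same shape as the generated markdown table.
sample = (
    '| Model ID | Model Type | Default Template | Requires | Support Megatron | Tags | HF Model ID |\n'
    '| -------- | -----------| ---------------- | -------- | ---------------- | ---- | ----------- |\n'
    '|Qwen/Qwen2.5-7B-Instruct|qwen2_5|qwen2_5|-|&#x2714;|-|Qwen/Qwen2.5-7B-Instruct|\n')

cache_mapping = {}
for row in sample.split('\n')[2:]:   # skip header and separator, as get_cache_mapping does
    if not row:
        continue
    items = row.split('|')
    if len(items) < 6:
        break
    cache_mapping[items[1]] = items[5]   # Model ID -> '&#x2714;' or '&#x2718;'

print(cache_mapping)  # {'Qwen/Qwen2.5-7B-Instruct': '&#x2714;'}
```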

swift/megatron/argument/megatron_args.py (-1)

@@ -90,7 +90,6 @@ class MegatronArguments(ExtraMegatronArguments):
     position_embedding_type: Literal['learned_absolute', 'rope', 'relative', 'none'] = 'rope'
     rotary_base: int = 10000
     rotary_percent: float = 1.
-    rotary_seq_len_interpolation_factor: Optional[int] = None
     normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm'
     norm_epsilon: float = 1e-5
     swiglu: bool = True

swift/megatron/model/config.py (+4 -2)

@@ -16,7 +16,9 @@
     'padded_vocab_size': ['vocab_size'],
     'attention_dropout': ['attention_dropout'],
     'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
-    'swiglu': ['hidden_act']
+    'swiglu': ['hidden_act'],
+    'add_qkv_bias': ['attention_bias'],
+    'disable_bias_linear': ['mlp_bias']
 }
 
 
@@ -28,7 +30,7 @@ def convert_hf_config(config) -> Dict[str, Any]:
             hf_v = getattr(config, hf_k)
             if k == 'rotary_base':
                 megatron_config[k] = int(hf_v)
-            elif k == 'untie_embeddings_and_output_weights':
+            elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear'}:
                 megatron_config[k] = not hf_v
             elif k == 'swiglu':
                 if hf_v == 'silu':
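The two new mapping entries follow the existing pattern: `add_qkv_bias` appears to be copied from the HF config directly, while `disable_bias_linear`, like `untie_embeddings_and_output_weights`, is the logical inverse of its HF counterpart. A minimal sketch with illustrative llama-style values (the config object is a stand-in, not a real HF config):

```python
from types import SimpleNamespace

# Stand-in for a HF llama config; only the fields used by the new entries are present.
hf_config = SimpleNamespace(tie_word_embeddings=False, attention_bias=False, mlp_bias=False)

megatron_config = {
    'untie_embeddings_and_output_weights': not hf_config.tie_word_embeddings,  # inverted
    'add_qkv_bias': hf_config.attention_bias,                                  # copied as-is
    'disable_bias_linear': not hf_config.mlp_bias,                             # inverted
}
print(megatron_config)
# {'untie_embeddings_and_output_weights': True, 'add_qkv_bias': False, 'disable_bias_linear': True}
```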

swift/megatron/model/gpt/__init__.py (+18 -1)

@@ -9,5 +9,22 @@
 
 register_megatron_model(
     MegatronModelMeta(MegatronModelType.gpt, [
-        ModelType.qwen, ModelType.qwen2, ModelType.qwen2_5, ModelType.qwq, ModelType.qwq_preview, ModelType.qwen2_5_math
+        ModelType.qwen,
+        ModelType.qwen2,
+        ModelType.qwen2_5,
+        ModelType.qwq,
+        ModelType.qwq_preview,
+        ModelType.qwen2_5_math,
+        ModelType.llama3,
+        ModelType.llama,
+        ModelType.marco_o1,
+        ModelType.deepseek_r1_distill,
+        ModelType.yi,
+        ModelType.yi_coder,
+        ModelType.sus,
+        ModelType.skywork_o1,
+        ModelType.openbuddy_llama,
+        ModelType.megrez,
+        ModelType.numina,
+        ModelType.mengzi3,
     ], model_provider, convert_hf_config, convert_mcore2hf, convert_hf2mcore))
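With the llama-family entries registered, a GPT-architecture Megatron meta should now be found for these checkpoints. A rough sketch (the import path is an assumption; `get_megatron_model_meta` is the lookup used in swift/megatron/utils/convert.py):

```python
from swift.megatron.model import get_megatron_model_meta  # import path assumed

meta = get_megatron_model_meta('LLM-Research/Meta-Llama-3-8B-Instruct')
print(meta is not None)                        # expected: True once llama3 is in the list above
print(meta.megatron_model_type if meta else None)
```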

swift/megatron/model/gpt/hf2mcore.py (+8 -7)

@@ -20,13 +20,14 @@ def set_attn_state(args, mg_layer, hf_layer):
     mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight)
 
     # Copy bias
-    mg_attn.linear_qkv.bias.data.copy_(
-        torch.cat([
-            hf_attn.q_proj.bias.reshape((num_query_groups, -1)),
-            hf_attn.k_proj.bias.reshape((num_query_groups, -1)),
-            hf_attn.v_proj.bias.reshape((num_query_groups, -1)),
-        ],
-                  dim=1).reshape(-1))
+    if args.add_qkv_bias:
+        mg_attn.linear_qkv.bias.data.copy_(
+            torch.cat([
+                hf_attn.q_proj.bias.reshape((num_query_groups, -1)),
+                hf_attn.k_proj.bias.reshape((num_query_groups, -1)),
+                hf_attn.v_proj.bias.reshape((num_query_groups, -1)),
+            ],
+                      dim=1).reshape(-1))
 
 
 def set_mlp_state(args, mg_layer, hf_layer):
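The gating on `args.add_qkv_bias` matters for llama checkpoints, which typically have no attention bias. When the bias does exist, the packing concatenates each KV group's query, key, and value slices before flattening, which matches Megatron's fused `linear_qkv` layout. A toy sketch of that packing (shapes are illustrative):

```python
import torch

# Illustrative GQA shapes: 4 query heads, 2 KV groups, head_dim 3.
num_heads, num_query_groups, head_dim = 4, 2, 3
q = torch.arange(num_heads * head_dim, dtype=torch.float32)                    # q_proj.bias
k = torch.arange(100, 100 + num_query_groups * head_dim, dtype=torch.float32)  # k_proj.bias
v = torch.arange(200, 200 + num_query_groups * head_dim, dtype=torch.float32)  # v_proj.bias

# Same packing as set_attn_state: per-group concat, then flatten, giving
# [q of group 0, k of group 0, v of group 0, q of group 1, ...].
qkv_bias = torch.cat([
    q.reshape(num_query_groups, -1),
    k.reshape(num_query_groups, -1),
    v.reshape(num_query_groups, -1),
], dim=1).reshape(-1)
print(qkv_bias.shape)  # torch.Size([24]) == (num_heads + 2 * num_query_groups) * head_dim
```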

swift/megatron/model/gpt/mcore2hf.py (+5 -4)

@@ -17,10 +17,11 @@ def set_attn_state(args, mg_layer, hf_layer):
     hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight)
 
     # Copy bias
-    mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1))
-    hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1))
-    hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
-    hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))
+    if args.add_qkv_bias:
+        mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1))
+        hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1))
+        hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
+        hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))
 
 
 def set_mlp_state(args, mg_layer, hf_layer):
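The reverse direction slices the fused bias back apart. `q_dim` and `kv_dim` are defined elsewhere in `set_attn_state`, so the values below are reconstructed from the slicing pattern; a small round-trip check under those assumptions:

```python
import torch

num_heads, num_query_groups, head_dim = 4, 2, 3
q_dim = num_heads // num_query_groups * head_dim   # per-group query width (assumed)
kv_dim = head_dim                                   # per-group key/value width (assumed)

q = torch.arange(num_heads * head_dim, dtype=torch.float32)
k = torch.arange(100, 100 + num_query_groups * head_dim, dtype=torch.float32)
v = torch.arange(200, 200 + num_query_groups * head_dim, dtype=torch.float32)

# Pack as in hf2mcore, slice as in mcore2hf, and confirm the originals come back.
mg_attn_bias = torch.cat([q.reshape(num_query_groups, -1),
                          k.reshape(num_query_groups, -1),
                          v.reshape(num_query_groups, -1)], dim=1)
assert torch.equal(mg_attn_bias[:, :q_dim].reshape(-1), q)
assert torch.equal(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1), k)
assert torch.equal(mg_attn_bias[:, -kv_dim:].reshape(-1), v)
```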

swift/megatron/model/gpt/model.py (+3 -1)

@@ -24,5 +24,7 @@ def model_provider(pre_process=True, post_process=True):
         position_embedding_type=args.position_embedding_type,
         rotary_percent=args.rotary_percent,
         rotary_base=args.rotary_base,
-        rope_scaling=args.use_rope_scaling)
+        rope_scaling=args.use_rope_scaling,
+        rope_scaling_factor=args.rope_scaling_factor,
+        seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
     return model

swift/megatron/model/register.py (+7 -5)

@@ -24,14 +24,16 @@ class MegatronModelMeta:
     model_groups: List[ModelGroup] = field(default_factory=list)
 
 
-def register_megatron_model(model_meta: MegatronModelMeta, *, exist_ok: bool = False):
-    megatron_model_type = model_meta.megatron_model_type
-    for model_type in model_meta.model_types:
-        model_meta.model_groups += MODEL_MAPPING[model_type].model_groups
+def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
+    megatron_model_type = megatron_model_meta.megatron_model_type
+    for model_type in megatron_model_meta.model_types:
+        model_meta = MODEL_MAPPING[model_type]
+        model_meta.support_megatron = True
+        megatron_model_meta.model_groups += model_meta.model_groups
     if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
         raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.')
 
-    MEGATRON_MODEL_MAPPING[megatron_model_type] = model_meta
+    MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta
 
 
 def get_megatron_model_meta(model_id_or_path: str) -> Optional[MegatronModelMeta]:
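Registration now also stamps `support_megatron = True` onto the underlying HF model metas, which is what run_model_info.py reads when building the table. A rough sketch of that side effect, assuming (as run_model_info.py does) that importing `swift.megatron.model` triggers the registrations and that `ModelType.llama3` is the string `'llama3'`:

```python
from swift.llm import MODEL_MAPPING
from swift.megatron import model  # noqa: F401  # importing runs register_megatron_model

print(getattr(MODEL_MAPPING['llama3'], 'support_megatron', False))  # expected: True
```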

swift/megatron/utils/convert.py (+2)

@@ -62,6 +62,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     kwargs = args.get_model_kwargs()
     hf_model, processor = get_model_tokenizer(**kwargs)
     megatron_model_meta = get_megatron_model_meta(args.model)
+    assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
     kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
     megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir)
     patch_megatron_tokenizer(processor)
@@ -83,6 +84,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     kwargs = args.get_model_kwargs()
     hf_model, processor = get_model_tokenizer(**kwargs)
     megatron_model_meta = get_megatron_model_meta(args.model)
+    assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
     kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
     megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model)
     patch_megatron_tokenizer(processor)
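The new asserts replace an opaque AttributeError with a clear message when a model has no Megatron meta. The supported path is exactly what the new alignment test drives; a usage sketch mirroring that test (the flag behavior noted in the comments is inferred from the flag names):

```python
from swift.llm import export_main, ExportArguments

export_main(ExportArguments(
    model='LLM-Research/Meta-Llama-3-8B-Instruct',  # any model type registered in gpt/__init__.py
    to_mcore=True,                                  # convert the HF checkpoint to Megatron-Core format
    exist_ok=True,
    test_convert_precision=True))                   # also check the converted weights against the HF model
```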

tests/megatron/test_align/test_llm.py (+47, new file)

@@ -0,0 +1,47 @@
+import os
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
+def _test_model(model_id):
+    from swift.llm import export_main, ExportArguments
+    export_main(ExportArguments(model=model_id, to_mcore=True, exist_ok=True, test_convert_precision=True))
+
+
+def test_llama2():
+    _test_model('modelscope/Llama-2-7b-chat-ms')
+
+
+def test_llama3():
+    _test_model('LLM-Research/Meta-Llama-3-8B-Instruct')
+
+
+def test_marco_o1():
+    _test_model('AIDC-AI/Marco-o1')
+
+
+def test_deepseek_r1_llama():
+    # TODO: FIX rope
+    _test_model('deepseek-ai/DeepSeek-R1-Distill-Llama-8B')
+
+
+def test_deepseek_r1_qwen():
+    _test_model('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')
+
+
+def test_yi():
+    _test_model('01ai/Yi-1.5-6B-Chat')
+
+
+def test_megrez():
+    _test_model('InfiniAI/Megrez-3b-Instruct')
+
+
+if __name__ == '__main__':
+    # test_llama2()
+    # test_llama3()
+    # test_marco_o1()
+    # test_deepseek_r1_llama()
+    # test_deepseek_r1_qwen()
+    # test_yi()
+    test_megrez()
