Commit aefeb41

committed
basically same converter but with different config
1 parent 8e3d195 commit aefeb41

File tree

1 file changed: +197 -0 lines


Diff for: examples/xglm/convert_dense2langmoe.py

@@ -0,0 +1,197 @@
"""
Converts a dense nanotron GPT3 checkpoint into a nanotron language-MoE (GPT3LangMoE) checkpoint.
Command:
    torchrun --nproc-per-node=1 convert_dense2langmoe.py --checkpoint-path=nanotron_weights --save-path=nanotron_moe_weights
"""

import dataclasses
import json
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional

from torch import nn
import torch
import nanotron
from nanotron.config.models_config import GPT3Config, GPT3LangMoEConfig
from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock
from nanotron.models.gpt3_langmoe import GPT3LangMoEForTraining, GPT3LangMoEBlock
from nanotron.trainer import mark_tied_parameters

from convert_utils import convert_generic, create_nt_model

def convert_config(config: GPT3Config, num_experts=8, num_languages=32, language_embedding_size=128) -> GPT3LangMoEConfig:
    return GPT3LangMoEConfig(
        **config.__dict__,
        is_moe=True,
        moe_num_experts=num_experts,
        num_experts_per_tok=min(2, num_experts),  # arbitrarily chosen
        moe_loss_weight=0.01,  # arbitrarily chosen
        moe_z_loss_weight=0.001,  # arbitrarily chosen
        moe_glu=False,
        num_languages=num_languages,
        language_embedding_size=language_embedding_size,
    )

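The **config.__dict__ expansion above assumes GPT3LangMoEConfig accepts every field of GPT3Config plus the added MoE/language fields. As a standalone toy sketch of that dataclass pattern (DenseCfg and MoECfg are made-up stand-ins, not the real nanotron configs):

# Toy illustration, not nanotron code: extend one dataclass into another via **__dict__.
import dataclasses

@dataclasses.dataclass
class DenseCfg:            # made-up stand-in for GPT3Config
    hidden_size: int
    num_layers: int

@dataclasses.dataclass
class MoECfg(DenseCfg):    # made-up stand-in for GPT3LangMoEConfig
    is_moe: bool = False
    moe_num_experts: int = 1

dense = DenseCfg(hidden_size=512, num_layers=4)
moe = MoECfg(**dense.__dict__, is_moe=True, moe_num_experts=8)
print(moe)  # MoECfg(hidden_size=512, num_layers=4, is_moe=True, moe_num_experts=8)
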
def convert_dense_to_moe(ff_moe: nn.Module, dense_ff: nn.Module, num_experts: int):
    with torch.no_grad():
        # only copy the weight matrix and repeat it n_expert times
        weight_1 = dense_ff.c_fc.weight.clone()
        if num_experts == 1:
            ff_moe.experts.mlp.w1.module.weight.data = weight_1.contiguous()
        else:
            # [intermediate_size, hidden_size] -> [hidden_size, intermediate_size * n_experts]
            weight_1 = weight_1.T
            ff_moe.experts.mlp.w1.module.weight.data = weight_1.repeat(1, num_experts)

        weight_2 = dense_ff.c_proj.weight.clone()
        if num_experts == 1:  # just a specific case for 1 expert
            ff_moe.experts.mlp.w2.module.weight.data = weight_2.contiguous()
        else:
            # [hidden_size, intermediate_size] -> [intermediate_size * n_experts, hidden_size]
            weight_2 = weight_2.T
            ff_moe.experts.mlp.w2.module.weight.data = weight_2.repeat(num_experts, 1)

        # # -- could add bias only for 2nd layer, because that works with the MegaBlocks MoE implementation
        # # -- but won't make a big difference?
        # ff_moe.experts.bias.copy_(dense_ff.c_proj.bias)

        # init gating randomly
        nn.init.normal_(ff_moe.gate.layer.weight, mean=0.0, std=0.02)

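For intuition, here is a standalone sketch of why the transpose-and-repeat above makes every expert start out computing the same function as the dense feed-forward (up to the dense biases, which are not copied). It uses plain PyTorch tensors and a placeholder ReLU activation rather than the nanotron/MegaBlocks modules, so names and shapes are assumptions drawn from this converter only:

# Standalone check of the weight layout used by convert_dense_to_moe (toy sizes).
import torch

hidden, intermediate, n_experts = 4, 6, 3
c_fc = torch.randn(intermediate, hidden)    # dense first-layer weight, [intermediate, hidden]
c_proj = torch.randn(hidden, intermediate)  # dense second-layer weight, [hidden, intermediate]

# Same transform as above: transpose, then tile once per expert.
w1 = c_fc.T.repeat(1, n_experts)    # [hidden, intermediate * n_experts]
w2 = c_proj.T.repeat(n_experts, 1)  # [intermediate * n_experts, hidden]

x = torch.randn(2, hidden)
dense_out = torch.relu(x @ c_fc.T) @ c_proj.T  # dense MLP without biases
for e in range(n_experts):
    w1_e = w1[:, e * intermediate:(e + 1) * intermediate]
    w2_e = w2[e * intermediate:(e + 1) * intermediate, :]
    expert_out = torch.relu(x @ w1_e) @ w2_e
    assert torch.allclose(dense_out, expert_out, atol=1e-6)
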
def convert_decoder(block_moe: GPT3LangMoEBlock, block_nt: GPTBlock, num_experts: int):
    convert_generic(block_moe.ln_1, block_nt.ln_1)
    convert_generic(block_moe.attn, block_nt.attn)
    convert_generic(block_moe.ln_2, block_nt.ln_2)
    convert_dense_to_moe(block_moe.ff, block_nt.ff, num_experts)

def convert(
    model_moe: GPT3LangMoEForTraining, model_dense: GPT3ForTraining, num_experts: int
):
    convert_generic(
        model_moe.model.token_embeddings.pp_block.token_embedding,
        model_dense.model.token_embeddings.pp_block.token_embedding,
    )
    # init language embedding randomly
    nn.init.normal_(model_moe.model.language_embeddings.pp_block.language_embedding.weight, mean=0.0, std=0.02)
    for layer_moe, layer_nt in zip(model_moe.model.decoder, model_dense.model.decoder):
        convert_decoder(layer_moe.pp_block, layer_nt.pp_block, num_experts)
    convert_generic(
        model_moe.model.final_layer_norm.pp_block,
        model_dense.model.final_layer_norm.pp_block,
    )
    convert_generic(
        model_moe.model.lm_head.pp_block, model_dense.model.lm_head.pp_block
    )

def create_nt_moe_model(
    model_config: Optional[GPT3LangMoEConfig] = None,
    device: torch.device = torch.device("cuda"),
    dtype: torch.dtype = torch.bfloat16,
    checkpoint_path: Optional[Path] = None,
):
    if model_config is None:
        assert checkpoint_path is not None
        with open(checkpoint_path / "model_config.json") as f:
            model_config = GPT3LangMoEConfig(**json.load(f))

    parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1)
    parallel_context = nanotron.parallel.ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )
    model_nt = nanotron.models.build_model(
        model_builder=lambda: GPT3LangMoEForTraining(
            config=model_config,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        parallel_context=parallel_context,
        dtype=dtype,
        device=device,
    )
    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)

    if checkpoint_path is not None:
        nanotron.serialize.load_weights(
            model=model_nt,
            parallel_context=parallel_context,
            root_folder=checkpoint_path,
        )

    return model_nt

def main(
    checkpoint_path: Path,
    save_path: Path,
    num_experts: int,
    num_languages: int,
    language_embedding_size: int,
):
    # Load nanotron model.
    model_dense = create_nt_model(checkpoint_path=checkpoint_path)

    # Init moe model.
    model_config_moe = convert_config(model_dense.config, num_experts, num_languages, language_embedding_size)
    model_moe = create_nt_moe_model(model_config=model_config_moe)

    convert(model_moe, model_dense, num_experts)
    nanotron.serialize.save_weights(
        model=model_moe,
        parallel_context=model_moe.parallel_context,
        root_folder=save_path,
    )
    with open(save_path / "model_config.json", "w+") as f:
        json.dump(dataclasses.asdict(model_config_moe), f)
    print(f"Model saved to {save_path}")

if __name__ == "__main__":
    # fix all random seeds
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    parser = ArgumentParser(description="Convert dense weights to moe format")
    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        default="checkpoints/xglm-7.5B",
        help="Path to the nanotron dense checkpoint",
    )
    parser.add_argument(
        "--save-path",
        type=Path,
        default="checkpoints/xglm-moe-7.5B",
        help="Path to save the nanotron moe model",
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=8,
        help="Number of experts in the MoE model (duplicates of MLP layer)",
    )
    parser.add_argument(
        "--num-languages",
        type=int,
        default=32,
        help="Number of languages for the language embedding",
    )
    parser.add_argument(
        "--language-embedding-size",
        type=int,
        default=128,
        help="Size of the language embedding",
    )
    args = parser.parse_args()
    main(args.checkpoint_path, args.save_path, args.num_experts, args.num_languages, args.language_embedding_size)
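With the argparse defaults above, a typical run (launched from examples/xglm, mirroring the docstring at the top of the file) would look like: torchrun --nproc-per-node=1 convert_dense2langmoe.py --checkpoint-path=checkpoints/xglm-7.5B --save-path=checkpoints/xglm-moe-7.5B --num-experts=8 --num-languages=32 --language-embedding-size=128, assuming the dense checkpoint already exists at the given path.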

0 commit comments
