Commit 3bce1f4

Moe converters (#5)
* Converters ready
* Added xglm transformers implementation

Co-authored-by: Negar Foroutan <[email protected]>
1 parent 328b8c2 commit 3bce1f4

File tree

6 files changed: +1596 -1 lines changed


examples/xglm/README.md (+5)

@@ -25,3 +25,8 @@ cd examples/xglm
torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8
```

Note that this upcycling _drops_ the MLP bias parameters, because the MegaBlocks implementation does not support them. While this is a limitation of the current implementation, performance is quickly recovered after a few training steps.

To save the model back to the HuggingFace format, use:

```bash
torchrun examples/xglm/convert_ntmoe2hf.py --checkpoint-path=$SCRATCH/checkpoints/xglm-8x564M --save-path=$SCRATCH/checkpoints/huggingface/xglm-8x564M
```
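
Once converted, the checkpoint can be loaded back through the XGLM MoE implementation that ships with this example. A minimal sketch, assuming the converted model was saved to a local `checkpoints/huggingface/xglm-8x564M` directory, that a tokenizer was saved alongside it (`--tokenizer-name` was set), that `examples.xglm.transformers_impl` is importable from the repository root, and that the custom class exposes the standard `from_pretrained`/`generate` interface:

```python
from transformers import AutoTokenizer

# Custom XGLM MoE implementation from this example; the converted checkpoint likely
# cannot be loaded with the stock transformers XGLM classes, since it adds MoE blocks.
from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM

save_path = "checkpoints/huggingface/xglm-8x564M"  # hypothetical local path

model = XGLMForCausalLM.from_pretrained(save_path)
tokenizer = AutoTokenizer.from_pretrained(save_path)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```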

examples/xglm/convert_ntmoe2hf.py (+140)

@@ -0,0 +1,140 @@
"""
Converts a nanotron MoE model to the HF format.
Command:
    torchrun --nproc-per-node=1 convert_ntmoe2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights
"""

import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoTokenizer
from tqdm import tqdm

from nanotron.config.models_config import GPT3MoEConfig
from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
from nanotron.models.moe import dMoE, SparseMLP, LearnedRouter

from examples.xglm.convert_dense2moe import create_nt_moe_model
from examples.xglm.convert_nt2hf import convert_attention
from examples.xglm.convert_utils import convert_generic
from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP
from examples.xglm.transformers_impl.gating import BasicGate


def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:
    if config.embd_pdrop != config.resid_pdrop:
        warnings.warn(
            f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with "
            f"nanotron.resid_pdrop = {config.resid_pdrop}. "
            "XGLM implementation needs these two values to be equal "
            "for correct conversion."
        )
    if config.layer_norm_epsilon != 1e-5:
        warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}")
    if config.moe_z_loss_weight != 0:
        warnings.warn("the transformers implementation does not support z loss")
    assert not config.moe_glu, "Transformers implementation does not support GLU MLP layers"

    return XGLMmoeConfig(
        # Regular xglm config.
        activation_function=config.activation_function,
        attention_dropout=config.attn_pdrop,
        dropout=config.embd_pdrop,
        eos_token_id=config.eos_token_id,
        d_model=config.hidden_size,
        ffn_dim=config.intermediate_size,
        max_position_embeddings=config.max_position_embeddings,
        attention_heads=config.num_attention_heads,
        num_layers=config.num_hidden_layers,
        vocab_size=config.vocab_size,
        decoder_start_token_id=config.position_embedding_offset,
        activation_dropout=config.act_pdrop,
        scale_embedding=config.scale_embedding,
        # MoE specifics.
        num_local_experts=config.moe_num_experts,
        num_experts_per_tok=config.num_experts_per_tok,
        gate_type="linear",
        gate_depth=1,
        router_aux_loss_coef=config.moe_loss_weight,
    )


def convert_mlp(mlp_hf: XGLMMLP, mlp_nt: SparseMLP):
    convert_generic(mlp_hf.fc1, mlp_nt.w1.module)
    convert_generic(mlp_hf.fc2, mlp_nt.w2.module)


def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter):
    convert_generic(gate_hf.gate, gate_nt.layer)


def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE):
    convert_gate(ff_hf.gate, ff_nt.gate)

    # The nanotron dMoE fuses all expert MLPs into single weight matrices. With a single
    # expert the layout matches the HF linear layers directly; with several experts the
    # fused weights are laid out transposed relative to the HF layout, hence the extra .T
    # in the slicing below.
    int_size = ff_nt.config.intermediate_size
    if len(ff_hf.experts) == 1:
        assert ff_nt.experts.mlp.w1.module.weight.shape == (int_size * len(ff_hf.experts), ff_nt.config.hidden_size)
        assert ff_nt.experts.mlp.w2.module.weight.shape == (ff_nt.config.hidden_size, int_size * len(ff_hf.experts))
    else:
        assert ff_nt.experts.mlp.w1.module.weight.T.shape == (int_size * len(ff_hf.experts), ff_nt.config.hidden_size)
        assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size * len(ff_hf.experts), ff_nt.config.hidden_size)

    for i, expert_hf in enumerate(ff_hf.experts):
        i0 = i * int_size
        i1 = (i + 1) * int_size
        with torch.no_grad():
            if len(ff_hf.experts) == 1:
                expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight[i0:i1, :].clone())
                expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[:, i0:i1].clone())
            else:
                expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone())
                expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone())


def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock):
    convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1)
    convert_attention(block_hf.self_attn, block_nt.attn)
    convert_generic(block_hf.final_layer_norm, block_nt.ln_2)
    convert_ff(block_hf.block_sparse_moe, block_nt.ff)


def convert(model_hf: XGLMForCausalLM, model_nt: GPT3MoEForTraining):
    convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding)
    for layer_hf, layer_nt in tqdm(zip(model_hf.model.layers, model_nt.model.decoder), desc="Converting layers",
                                   total=model_nt.config.num_hidden_layers):
        convert_decoder(layer_hf, layer_nt.pp_block)
    convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block)
    convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block)


def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]):
    # Load nanotron model.
    model_nt = create_nt_moe_model(checkpoint_path=checkpoint_path)

    # Init huggingface model.
    model_config_hf = convert_config(model_nt.config)
    model_hf = XGLMForCausalLM._from_config(model_config_hf)

    # Copy weights, initialize tokenizer and save model.
    if tokenizer_name is not None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tokenizer.save_pretrained(save_path)
    convert(model_hf, model_nt)
    print("Saving...")
    model_hf.save_pretrained(save_path)
    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert nanotron MoE weights to the HF format")
    parser.add_argument(
        "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint"
    )
    parser.add_argument(
        "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model"
    )
    parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B")
    args = parser.parse_args()
    main(args.checkpoint_path, args.save_path, args.tokenizer_name)
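
The slicing logic in `convert_ff` above can be illustrated in isolation. A minimal sketch with plain tensors, assuming the fused layout checked by the multi-expert branch (`w1` stored as `(hidden, num_experts*ffn)`, `w2` as `(num_experts*ffn, hidden)`); all names and sizes here are illustrative only:

```python
import torch

# Illustrative sizes only.
hidden, ffn, num_experts = 8, 16, 4

# Fused expert weights in the layout checked by the multi-expert branch of convert_ff:
# w1 is (hidden, num_experts*ffn), w2 is (num_experts*ffn, hidden).
fused_w1 = torch.randn(hidden, num_experts * ffn)
fused_w2 = torch.randn(num_experts * ffn, hidden)

# Per-expert HF linear weights: fc1 maps hidden -> ffn, fc2 maps ffn -> hidden.
experts = [{"fc1": torch.empty(ffn, hidden), "fc2": torch.empty(hidden, ffn)} for _ in range(num_experts)]

for i, expert in enumerate(experts):
    i0, i1 = i * ffn, (i + 1) * ffn
    expert["fc1"].copy_(fused_w1.T[i0:i1, :])  # rows i0:i1 of w1^T become expert i's fc1
    expert["fc2"].copy_(fused_w2[i0:i1, :].T)  # matching rows of w2, transposed, become fc2

# Sanity check: expert 0's fc1 projection matches the first ffn columns of the fused projection.
x = torch.randn(2, hidden)
assert torch.allclose(x @ experts[0]["fc1"].T, (x @ fused_w1)[:, :ffn])
```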

examples/xglm/tests/test_moe.py (+182)

@@ -0,0 +1,182 @@
import torch
import pytest

import nanotron
from nanotron.config.parallelism_config import ParallelismArgs
from nanotron.config.models_config import GPT3MoEConfig
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.trainer import mark_tied_parameters
from nanotron.models.gpt3_moe import GPT3MoEBlock, GPT3MoEForTraining
from nanotron.models.moe import LearnedRouter, dMoE

from tests.helpers.utils import init_distributed

from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert
from examples.xglm.tests.test_implementation import almost_close
from examples.xglm.transformers_impl.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM
from examples.xglm.transformers_impl.gating import BasicGate


MAX_SEQUENCE_LENGTH = 2048
TEST_SEQUENCE_LENGTH = 128  # With a very long test sequence, precision errors grow regardless of whether the implementation is correct.
# TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH
BATCH_SIZE = 4
HIDDEN_SIZE = 1024
# DTYPE = torch.bfloat16
DTYPE = torch.float32
TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:"

CONFIG = GPT3MoEConfig(
    attn_pdrop=0.0,
    embd_pdrop=0.0,
    resid_pdrop=0.0,
    act_pdrop=0.0,
    eos_token_id=2,
    hidden_size=HIDDEN_SIZE,
    intermediate_size=4096,
    layer_norm_epsilon=1e-05,
    max_position_embeddings=MAX_SEQUENCE_LENGTH,
    num_attention_heads=16,
    num_hidden_layers=24,
    scale_attn_weights=True,
    vocab_size=256008,
    sinusoidal_position_embedding=True,
    position_embedding_offset=2,
    use_spda=DTYPE is not torch.bfloat16,
    # vvv moe vvv
    is_moe=True,
    moe_num_experts=8,
    num_experts_per_tok=2,
    moe_loss_weight=0.01,
    moe_z_loss_weight=0.0,
    moe_glu=False,
)
PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1)  # CONFIG.moe_num_experts)


@pytest.fixture
def hidden_states() -> torch.Tensor:
    return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE)


@pytest.fixture
def input_mask() -> torch.Tensor:
    return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool)


@pytest.fixture
def input_ids() -> torch.Tensor:
    return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH))


def _test_nt2hf_gate(parallel_context: ParallelContext, hidden_states: torch.Tensor):
    hidden_states = hidden_states.cuda()

    config_hf = convert_config(CONFIG)
    gate_nt = LearnedRouter(CONFIG).cuda().to(DTYPE)
    gate_hf = BasicGate(config_hf).cuda().to(DTYPE)
    convert_gate(gate_hf, gate_nt)

    router_logits_nt, _, _ = gate_nt(hidden_states.view(-1, HIDDEN_SIZE))
    router_logits_hf = gate_hf(hidden_states.permute(1, 0, 2).reshape(-1, HIDDEN_SIZE), "")

    router_logits_nt = router_logits_nt.view(TEST_SEQUENCE_LENGTH, BATCH_SIZE, -1)
    router_logits_hf = router_logits_hf.view(BATCH_SIZE, TEST_SEQUENCE_LENGTH, -1).permute(1, 0, 2)

    assert router_logits_nt.size() == router_logits_hf.size()
    torch.testing.assert_close(router_logits_nt, router_logits_hf)


def test_nt2hf_gate(hidden_states: torch.Tensor):
    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states)


def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor,
                   num_experts: int, num_experts_per_tok: int):
    hidden_states = hidden_states.cuda()

    config = {**vars(CONFIG)}
    config.update({"moe_num_experts": num_experts, "num_experts_per_tok": num_experts_per_tok})
    config = GPT3MoEConfig(**config)
    config_hf = convert_config(config)
    ff_nt = dMoE(config, parallel_context, PARALLEL_CONFIG).cuda().to(DTYPE)
    ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE)
    convert_ff(ff_hf, ff_nt)

    out_nt = ff_nt(hidden_states)["hidden_states"]
    out_hf, _ = ff_hf(hidden_states.permute(1, 0, 2).contiguous(), "")
    out_hf = out_hf.permute(1, 0, 2)

    assert out_nt.size() == out_hf.size()
    almost_close(out_nt, out_hf, max_far=0.05, far_atol=0.003)


@pytest.mark.parametrize("num_experts,num_experts_per_tok", [(1, 1), (2, 1), (4, 1), (4, 2), (8, 1), (8, 2), (8, 4)])
def test_nt2hf_ff(hidden_states: torch.Tensor, num_experts: int, num_experts_per_tok: int):
    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)


def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
    random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()})
    input_ids = input_ids.cuda()
    input_mask = input_mask.cuda()

    # Unfortunately, we can't use float64 with huggingface xglm.
    new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE

    # Get nanotron model.
    config_nt = GPT3MoEConfig(**vars(CONFIG))
    if new_dtype not in {torch.bfloat16, torch.float16}:
        config_nt.use_spda = True
    model_nt = nanotron.models.build_model(
        model_builder=lambda: GPT3MoEForTraining(
            config=config_nt,
            parallel_context=parallel_context,
            parallel_config=None,
            random_states=random_states,
        ),
        parallel_context=parallel_context,
        dtype=new_dtype,
        device="cuda",
    ).eval()
    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)

    # Create empty model_hf and make conversion.
    model_hf = XGLMForCausalLM(convert_config(config_nt)).cuda().to(new_dtype).eval()
    convert(model_hf, model_nt)

    # Dummy aux losses expected by the nanotron forward pass. The test runs with pp=1,
    # so input_ids is always a real tensor here and no TensorPointer is needed.
    assert not isinstance(input_ids, TensorPointer)
    aux_losses = {
        "load_balancing_loss": torch.zeros(1, device=input_ids.device),
        "z_loss": torch.zeros(1, device=input_ids.device),
    }

    # Get outputs and assert.
    with torch.no_grad():
        out_nt = model_nt.model(input_ids, input_mask, aux_losses)["sharded_logits"].to(new_dtype)
        del model_nt
        torch.cuda.empty_cache()
        out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask, output_router_logits=False).logits.permute(1, 0, 2)
        del model_hf
        torch.cuda.empty_cache()
    assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}"
    return out_nt.cpu(), out_hf.cpu()


def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
    out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask)
    almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0)  # We allow fewer than 1% of elements to differ, but some of those can be very far off.
    # torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16())


def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor):
    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask)
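
The tolerance helper `almost_close` is imported from `examples.xglm.tests.test_implementation` and is not part of this diff. A plausible minimal sketch of its contract, assuming `max_far` bounds the fraction of elements allowed to differ by more than the absolute tolerance `far_atol` (consistent with the "fewer than 1% of elements" comment above); the actual helper may differ:

```python
import torch


def almost_close(actual: torch.Tensor, expected: torch.Tensor, max_far: float, far_atol: float) -> None:
    """Assert that at most a `max_far` fraction of elements differs from `expected` by more than `far_atol`."""
    far = (actual - expected).abs() > far_atol
    far_fraction = far.float().mean().item()
    assert far_fraction <= max_far, f"{far_fraction:.4%} of elements differ by more than {far_atol}"
```

Under this reading, `almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0)` lets up to 1% of the logits deviate by more than 2.0, while the feed-forward test is stricter (5% beyond 0.003).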
