Commit 930fe81

committed
very close
1 parent 38364d5 commit 930fe81

4 files changed: +234 -22 lines changed

examples/xglm/README.md

+1 -1

@@ -31,4 +31,4 @@ To save back to huggingface format use
 torchrun examples/xglm/convert_ntmoe2hf.py --checkpoint-path=$SCRATCH/checkpoints/xglm-8x564M --save-path=$SCRATCH/checkpoints/huggingface/xglm-8x56fM
 ```
 
-Make sure to have the [XGLM MOE implementation](https://github.com/negar-foroutan/Multilingual_MoE) installed (e.g. using `PYTHONPATH=/path/to/Multilingual_MoE/models`).
+Make sure to have the [XGLM MOE implementation](https://github.com/negar-foroutan/Multilingual_MoE) installed (e.g. using `PYTHONPATH=/path/to/Multilingual_MoE`).
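For reference, a minimal Python-side sketch of what that `PYTHONPATH` setting achieves (the clone path below is a placeholder, not part of the commit):

```python
import sys

# Hypothetical path to a local clone of Multilingual_MoE; adjust to your setup.
sys.path.insert(0, "/path/to/Multilingual_MoE")

# With the repository root on the path, the `models` package used by the
# conversion script resolves, e.g.:
from models.xglm_model import XGLMForCausalLM  # noqa: F401
```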

examples/xglm/convert_ntmoe2hf.py

+34 -14

@@ -9,22 +9,26 @@
 from pathlib import Path
 from typing import Optional
 
+import torch
 from transformers import AutoTokenizer
+from tqdm import tqdm
 
 from nanotron.config.models_config import GPT3MoEConfig
 from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
-from nanotron.models.moe import dMoE, SparseMLP
+from nanotron.models.moe import dMoE, SparseMLP, LearnedRouter
 
-from examples.xglm.convert_dense2moe import create_nt_moe_model, convert_attention
+from examples.xglm.convert_dense2moe import create_nt_moe_model
+from examples.xglm.convert_nt2hf import convert_attention
 from examples.xglm.convert_utils import convert_generic
 
 from models.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP
+from models.gating import BasicGate
 
 # TODO: nanotron moe scales down the moe weights but hf doesn't
 # TODO: nanotron does not use pdrop in moe.
 
 
-def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig
+def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:
     assert config.moe_num_experts > 1, f"Why are you using a 1-expert moe? lol"
     if config.embd_pdrop != config.resid_pdrop:
         warnings.warn(
@@ -59,7 +63,7 @@ def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig
         num_experts_per_tok=config.num_experts_per_tok,
         gate_type="linear",
         gate_depth=1,
-        router_aux_loss_coef=config.moe_looss_weight,
+        router_aux_loss_coef=config.moe_loss_weight,
     )
 
 
@@ -69,25 +73,38 @@ def convert_mlp(mlp_hf: XGLMMLP, mlp_nt: SparseMLP):
     convert_generic(mlp_hf.fc2, mlp_nt.w2.module)
 
 
-def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE):
-    convert_generic(ff_hf.gate.gate, ff_nt.router.layer)
-    for expert_hf, expert_nt in zip(ff_hf.experts, ff_nt.experts):
-        convert_mlp(expert_hf, expert_nt.mlp)
+def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter):
+    convert_generic(gate_hf.gate, gate_nt.layer)
+
 
+def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE):
+    convert_gate(ff_hf.gate, ff_nt.gate)
+    int_size = ff_nt.config.intermediate_size
+    for i, expert_hf in enumerate(ff_hf.experts):
+        # TODO: fc1, fc2 has bias
+        i0 = i*int_size
+        i1 = (i + 1)*int_size
+        with torch.no_grad():
+            expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone())
+            expert_hf.fc1.bias.data.zero_()
+            expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone())
+            expert_hf.fc2.bias.data.zero_()
 
 def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock):
     convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1)
     convert_attention(block_hf.self_attn, block_nt.attn)
     convert_generic(block_hf.final_layer_norm, block_nt.ln_2)
     # TODO: hf has fc1, fc2 attributes but they are not used, probably should be removed.
-    convert_generic(block_hf.fc1, block_nt.ff.c_fc)
-    convert_generic(block_hf.fc2, block_nt.ff.c_proj)
+    #return block_nt.ff
+    convert_ff(block_hf.block_sparse_moe, block_nt.ff) # REMOVE
 
 
 def convert(model_hf: XGLMForCausalLM, model_nt: GPT3MoEForTraining):
     convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding)
-    for layer_hf, layer_nt in zip(model_hf.model.layers, model_nt.model.decoder):
-        convert_decoder(layer_hf, layer_nt.pp_block)
+    for layer_hf, layer_nt in tqdm(zip(model_hf.model.layers, model_nt.model.decoder), desc="Converting layers",
+                                   total=model_nt.config.num_hidden_layers):
+        #return convert_decoder(layer_hf, layer_nt.pp_block)
+        convert_decoder(layer_hf, layer_nt.pp_block) # REMOVE
     convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block)
     convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block)
 
@@ -104,7 +121,10 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]):
     if tokenizer_name is not None:
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         tokenizer.save_pretrained(save_path)
-    convert(model_hf, model_nt)
+    states = torch.randn(4, 1, 1024)
+    #return convert(model_hf, model_nt), states.cuda().bfloat16()
+    convert(model_hf, model_nt), states.cuda().bfloat16() # REMOVE
+    print("Saving...")
     model_hf.save_pretrained(save_path)
     print(f"Model saved to {save_path}")
 
@@ -119,4 +139,4 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]):
     )
     parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B")
     args = parser.parse_args()
-    main(args.checkpoint_path, args.save_path, args.tokenizer_name)
+    ret = main(args.checkpoint_path, args.save_path, args.tokenizer_name)
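The new `convert_ff` relies on nanotron's `dMoE` keeping all experts' MLP weights fused along the intermediate dimension, so each HF expert receives one slice. A self-contained sketch of that slicing with plain tensors (the toy sizes and the fused layouts are assumptions read off the copies above, not the real modules):

```python
import torch

# Toy sizes standing in for the values from the nanotron config.
hidden, inter, n_experts = 8, 16, 4

# Fused layouts implied by the copies in convert_ff:
#   w1: [hidden, n_experts * inter], w2: [n_experts * inter, hidden]
w1 = torch.randn(hidden, n_experts * inter)
w2 = torch.randn(n_experts * inter, hidden)

for i in range(n_experts):
    i0, i1 = i * inter, (i + 1) * inter
    fc1_weight = w1.T[i0:i1, :]  # [inter, hidden], the shape of nn.Linear(hidden, inter).weight
    fc2_weight = w2[i0:i1, :].T  # [hidden, inter], the shape of nn.Linear(inter, hidden).weight
    assert fc1_weight.shape == (inter, hidden)
    assert fc2_weight.shape == (hidden, inter)
```

The HF biases are zeroed in `convert_ff`, presumably because the nanotron experts carry no equivalent bias, which is also what the `# TODO: fc1, fc2 has bias` note points at.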

examples/xglm/tests/test_moe.py

+179

@@ -0,0 +1,179 @@
+import torch
+import pytest
+
+import nanotron
+from nanotron.config.parallelism_config import ParallelismArgs
+from nanotron.config.models_config import GPT3MoEConfig
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
+from nanotron.trainer import mark_tied_parameters
+from nanotron.models.gpt3_moe import GPT3MoEBlock, GPT3MoEForTraining
+from nanotron.models.moe import LearnedRouter, dMoE
+
+from tests.helpers.utils import init_distributed
+
+from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert
+from examples.xglm.tests.test_implementation import almost_close
+
+from models.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM
+from models.gating import BasicGate
+
+
+MAX_SEQUENCE_LENGTH = 2048
+TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation.
+#TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH
+BATCH_SIZE = 4
+HIDDEN_SIZE = 1024
+DTYPE = torch.bfloat16
+#DTYPE = torch.float32
+TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:"
+
+CONFIG = GPT3MoEConfig(
+    attn_pdrop=0.0,
+    embd_pdrop=0.0,
+    resid_pdrop=0.0,
+    act_pdrop=0.0,
+    eos_token_id=2,
+    hidden_size=HIDDEN_SIZE,
+    intermediate_size=4096,
+    layer_norm_epsilon=1e-05,
+    max_position_embeddings=MAX_SEQUENCE_LENGTH,
+    num_attention_heads=16,
+    num_hidden_layers=24,
+    scale_attn_weights=True,
+    vocab_size=256008,
+    sinusoidal_position_embedding=True,
+    position_embedding_offset=2,
+    use_spda=DTYPE is not torch.bfloat16,
+    # vvv moe vvv
+    is_moe=True,
+    moe_num_experts=4,
+    num_experts_per_tok=4,
+    moe_loss_weight=0.01,
+    moe_z_loss_weight=0.0,
+    moe_glu=False,
+)
+#PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts)
+
+
+@pytest.fixture
+def hidden_states() -> torch.Tensor:
+    return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE)
+
+
+@pytest.fixture
+def input_mask() -> torch.Tensor:
+    return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool)
+
+
+@pytest.fixture
+def input_ids() -> torch.Tensor:
+    return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH))
+
+
+def _test_nt2hf_gate(parallel_context: ParallelContext, hidden_states: torch.Tensor):
+    hidden_states = hidden_states.cuda()
+
+    config_hf = convert_config(CONFIG)
+    gate_nt = LearnedRouter(CONFIG).cuda().to(DTYPE)
+    gate_hf = BasicGate(config_hf).cuda().to(DTYPE)
+    convert_gate(gate_hf, gate_nt)
+
+    router_logits_nt, _, _ = gate_nt(hidden_states.view(-1, HIDDEN_SIZE))
+    router_logits_hf = gate_hf(hidden_states.permute(1, 0, 2).reshape(-1, HIDDEN_SIZE), "")
+
+    router_logits_nt = router_logits_nt.view(TEST_SEQUENCE_LENGTH, BATCH_SIZE, -1)
+    router_logits_hf = router_logits_hf.view(BATCH_SIZE, TEST_SEQUENCE_LENGTH, -1).permute(1, 0, 2)
+
+    assert router_logits_nt.size() == router_logits_hf.size()
+    torch.testing.assert_close(router_logits_nt, router_logits_hf)
+
+
+def test_nt2hf_gate(hidden_states: torch.Tensor):
+    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states)
+
+
+def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor):
+    hidden_states = hidden_states.cuda()
+
+    config_hf = convert_config(CONFIG)
+    ff_nt = dMoE(CONFIG, parallel_context, None).cuda().to(DTYPE)
+    ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE)
+    convert_ff(ff_hf, ff_nt)
+
+    out_nt = ff_nt(hidden_states)["hidden_states"]
+    out_hf, _ = ff_hf(hidden_states.permute(1, 0, 2).contiguous(), "")
+    out_hf = out_hf.permute(1, 0, 2)
+
+    assert out_nt.size() == out_hf.size()
+    almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.02)
+    #torch.testing.assert_close(out_nt, out_hf)
+
+
+
+def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
+    random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()})
+    input_ids = input_ids.cuda()
+    input_mask = input_mask.cuda()
+
+    # unfortunately, we can't use float64 with huggingface xglm.
+    new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE
+
+    # Get nanotron model.
+    config_nt = GPT3MoEConfig(**vars(CONFIG))
+    if new_dtype not in {torch.bfloat16, torch.float16}:
+        config_nt.use_spda = True
+    model_nt = nanotron.models.build_model(
+        model_builder=lambda: GPT3MoEForTraining(
+            config=config_nt,
+            parallel_context=parallel_context,
+            parallel_config=None,
+            random_states=random_states,
+        ),
+        parallel_context=parallel_context,
+        dtype=new_dtype,
+        device="cuda",
+    ).eval()
+    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)
+
+    # Create empty model_hf and make conversion.
+    model_hf = XGLMForCausalLM(convert_config(config_nt)).cuda().to(new_dtype).eval()
+    convert(model_hf, model_nt)
+
+    # Needed :/
+    aux_losses = {
+        "load_balancing_loss": (
+            torch.zeros(1, device=input_ids.device)
+            if not isinstance(input_ids, TensorPointer)
+            else TensorPointer(self.input_pp_rank)
+        ),
+        "z_loss": (
+            torch.zeros(1, device=input_ids.device)
+            if not isinstance(input_ids, TensorPointer)
+            else TensorPointer(self.input_pp_rank)
+        ),
+    }
+
+    # Get outputs and assert.
+    with torch.no_grad():
+        out_nt = model_nt.model(input_ids, input_mask, aux_losses)["sharded_logits"].to(new_dtype)
+        del model_nt
+        torch.cuda.empty_cache()
+        out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask, output_router_logits=False).logits.permute(1, 0, 2)
+        del model_hf
+        torch.cuda.empty_cache()
+    assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}"
+    return out_nt.cpu(), out_hf.cpu()
+
+
+def test_nt2hf_ff(hidden_states: torch.Tensor):
+    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states)
+
+
+def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
+    out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask)
+    almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.02)
+
+
+def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor):
+    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask)
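The permutes and reshapes in these tests bridge two activation layouts: the nanotron modules consume sequence-first tensors `[seq, batch, hidden]` (see the `hidden_states` fixture), while the HF XGLM modules expect batch-first `[batch, seq, hidden]`. A minimal sketch of that round trip, independent of any model code:

```python
import torch

seq, batch, hidden = 128, 4, 1024

x_nt = torch.randn(seq, batch, hidden)     # nanotron layout: sequence-first
x_hf = x_nt.permute(1, 0, 2).contiguous()  # HF layout: batch-first

# Both layouts flatten to the same set of [hidden]-sized token vectors
# (in a different order), which is what the gate test feeds to the routers.
flat_nt = x_nt.reshape(-1, hidden)
flat_hf = x_hf.reshape(-1, hidden)
assert flat_nt.shape == flat_hf.shape == (seq * batch, hidden)

# HF outputs are mapped back to the nanotron layout before comparison.
back = flat_hf.view(batch, seq, hidden).permute(1, 0, 2)
torch.testing.assert_close(back, x_nt)
```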

src/nanotron/models/moe.py

+20 -7

@@ -162,7 +162,8 @@ def forward(self, hidden_states: torch.Tensor):
         router_logits, expert_weights, top_experts = self.gate(x)
 
         # Compute the experts.
-        x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts)
+        #return self.experts(x, router_logits, expert_weights, top_experts)
+        x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) #REMOVE
         return {
             "hidden_states": x.reshape(batch_size, sequence_length, -1),
             "load_balancing_loss": lbl_loss,
@@ -300,12 +301,15 @@ def forward_once(self, x, expert_weights, top_experts): # TODO: sparse
         ) = self.indices_and_padded_bins(top_experts)
 
         # Route the tokens for MoE computation.
+        #x_pre = x.clone()
         x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok)
+        #print("forward_once a", x.shape)
 
         with torch.no_grad():
            topo = self.topology(x, padded_bins)
 
-        x = self.mlp(x, topo)
+        x = self.mlp(x, topo) #REMOVE
+        #return x_pre, self.mlp(x, topo)
 
         # Un-route the data for the MoE output.
         x = ops.padded_scatter(
@@ -422,7 +426,11 @@ def forward(self, x, router_logits, expert_weights, top_experts):
         top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok]
         """
         # Compute the experts.
-        x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten())
+        x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) #REMOVE
+        #return router_logits
+        #print("nano b", expert_weights)
+        #return expert_weights.bfloat16()
+        #return self.forward_fn(x, expert_weights.flatten(), top_experts.flatten())
        if self.training:
             lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config)
             z_loss = router_z_loss(router_logits, self.config)
@@ -595,9 +603,14 @@ def __init__(
 
     def forward(self, x, topo):
         self.w1.scale_gradients(), self.w2.scale_gradients()
-        x = self.sdd(x.contiguous(), self.w1.module.weight, topo)
-        activation_fn_out = act_fn(x, self.act)
-        return self.dsd(activation_fn_out, self.w2.module.weight)
+        x = self.sdd(x.contiguous(), self.w1.module.weight, topo) # REMOVE
+        #x1 = self.sdd(x.contiguous(), self.w1.module.weight, topo)
+        activation_fn_out = act_fn(x, self.act) # REMOVE
+        #print(x.shape, activation_fn_out.shape, self.w2.module.weight.shape)
+        #activation_fn_out = act_fn(x1, self.act)
+        return self.dsd(activation_fn_out, self.w2.module.weight) #REMOVE
+        #x2 = self.dsd(activation_fn_out, self.w2.module.weight)
+        #return x, x1, x2, topo, self.w1.module.weight, self.w2.module.weight
 
 
 class MLP(nn.Module):
@@ -718,4 +731,4 @@ def forward(self, x, topo):
         x1 = self.sdd(x, self.w1.module.weight, topo)
         x2 = self.sdd(x, self.w3.module.weight, topo)
         x = stk.ops.mul(act_fn(x1, self.act), x2)
-        return self.dsd(x, self.w2.module.weight)
+        return self.dsd(x, self.w2.module.weight)
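For orientation, the `sdd` → `act_fn` → `dsd` pipeline being instrumented here computes, for each routed token, a plain two-layer MLP using the slice of the fused `w1`/`w2` weights that belongs to the token's expert. A rough dense equivalent, leaving out the sparse kernels, padding, and gather/scatter (the names, shapes, and the GELU activation are illustrative assumptions, not the real `SparseMLP` API):

```python
import torch
import torch.nn.functional as F

hidden, inter, n_experts, tokens = 8, 16, 4, 10

w1 = torch.randn(hidden, n_experts * inter)  # fused first-layer weights
w2 = torch.randn(n_experts * inter, hidden)  # fused second-layer weights
x = torch.randn(tokens, hidden)              # token activations after routing
expert_of = torch.randint(0, n_experts, (tokens,))  # expert assigned to each token

out = torch.empty_like(x)
for t in range(tokens):
    e = int(expert_of[t])
    i0, i1 = e * inter, (e + 1) * inter
    h = F.gelu(x[t] @ w1[:, i0:i1])  # first layer + activation (activation assumed)
    out[t] = h @ w2[i0:i1, :]        # second layer back to the hidden size
```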
