|
| 1 | +import torch |
| 2 | +import pytest |
| 3 | + |
| 4 | +import nanotron |
| 5 | +from nanotron.config.parallelism_config import ParallelismArgs |
| 6 | +from nanotron.config.models_config import GPT3MoEConfig |
| 7 | +from nanotron.parallel import ParallelContext |
| 8 | +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer |
| 9 | +from nanotron.trainer import mark_tied_parameters |
| 10 | +from nanotron.models.gpt3_moe import GPT3MoEBlock, GPT3MoEForTraining |
| 11 | +from nanotron.models.moe import LearnedRouter, dMoE |
| 12 | + |
| 13 | +from tests.helpers.utils import init_distributed |
| 14 | + |
| 15 | +from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert |
| 16 | +from examples.xglm.tests.test_implementation import almost_close |
| 17 | + |
| 18 | +from models.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM |
| 19 | +from models.gating import BasicGate |
| 20 | + |
| 21 | + |
# Shared dimensions and dtype for every test in this module.
MAX_SEQUENCE_LENGTH: int = 2048
# A very large test sequence length makes precision errors more significant
# independently of whether the implementation is correct, so the tests use a
# shorter sequence. Set this to MAX_SEQUENCE_LENGTH to exercise the full length.
TEST_SEQUENCE_LENGTH: int = 128
BATCH_SIZE: int = 4
HIDDEN_SIZE: int = 1024
# torch.float32 is the alternative dtype the tests can be switched to.
DTYPE: torch.dtype = torch.bfloat16
# Sample prompt; the pieces below concatenate into a single string.
TEXT: str = (
    "Hello. This is a relatively long text. "
    "I will use this text to test the conversion scripts. "
    "Let's finish this text soon because I don't have much more to say. "
    "Final note:"
)
| 30 | + |
# Nanotron-side model configuration shared by all conversion tests.
CONFIG = GPT3MoEConfig(
    # Model dimensions and vocabulary.
    hidden_size=HIDDEN_SIZE,
    intermediate_size=4096,
    num_hidden_layers=24,
    num_attention_heads=16,
    vocab_size=256008,
    eos_token_id=2,
    layer_norm_epsilon=1e-05,
    scale_attn_weights=True,
    # Position embeddings.
    max_position_embeddings=MAX_SEQUENCE_LENGTH,
    sinusoidal_position_embedding=True,
    position_embedding_offset=2,
    # Every dropout probability is set to zero.
    attn_pdrop=0.0,
    embd_pdrop=0.0,
    resid_pdrop=0.0,
    act_pdrop=0.0,
    # Enabled for every DTYPE except bfloat16.
    use_spda=DTYPE is not torch.bfloat16,
    # Mixture-of-experts settings.
    is_moe=True,
    moe_num_experts=4,
    num_experts_per_tok=4,
    moe_loss_weight=0.01,
    moe_z_loss_weight=0.0,
    moe_glu=False,
)
| 56 | +#PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts) |
| 57 | + |
| 58 | + |
@pytest.fixture
def hidden_states() -> torch.Tensor:
    """Return random activations shaped (TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) in DTYPE."""
    shape = (TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE)
    return torch.randn(shape, dtype=DTYPE)
| 62 | + |
| 63 | + |
@pytest.fixture
def input_mask() -> torch.Tensor:
    """Return an all-True boolean mask shaped (BATCH_SIZE, TEST_SEQUENCE_LENGTH)."""
    shape = (BATCH_SIZE, TEST_SEQUENCE_LENGTH)
    return torch.ones(shape, dtype=torch.bool)
| 67 | + |
| 68 | + |
@pytest.fixture
def input_ids() -> torch.Tensor:
    """Return random token ids in [0, CONFIG.vocab_size) shaped (BATCH_SIZE, TEST_SEQUENCE_LENGTH)."""
    shape = (BATCH_SIZE, TEST_SEQUENCE_LENGTH)
    return torch.randint(low=0, high=CONFIG.vocab_size, size=shape)
| 72 | + |
| 73 | + |
def _test_nt2hf_gate(parallel_context: ParallelContext, hidden_states: torch.Tensor):
    """Check that the converted HF gate produces the same router logits as the nanotron router.

    Args:
        parallel_context: Unused in the body; part of the signature the
            ``init_distributed`` wrapper calls (presumably it injects it —
            confirm in tests.helpers.utils).
        hidden_states: Activations shaped
            (TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE).
    """
    hidden_states = hidden_states.cuda()

    # Build both gates, then run the conversion helper on the pair.
    hf_config = convert_config(CONFIG)
    router_nt = LearnedRouter(CONFIG).cuda().to(DTYPE)
    router_hf = BasicGate(hf_config).cuda().to(DTYPE)
    convert_gate(router_hf, router_nt)

    # The nanotron router is fed tokens flattened sequence-first, the HF gate
    # is fed the same tokens flattened batch-first.
    tokens_seq_first = hidden_states.view(-1, HIDDEN_SIZE)
    tokens_batch_first = hidden_states.transpose(0, 1).reshape(-1, HIDDEN_SIZE)
    logits_nt, _, _ = router_nt(tokens_seq_first)
    logits_hf = router_hf(tokens_batch_first, "")

    # Bring both results back to (sequence, batch, experts) before comparing.
    logits_nt = logits_nt.view(TEST_SEQUENCE_LENGTH, BATCH_SIZE, -1)
    logits_hf = logits_hf.view(BATCH_SIZE, TEST_SEQUENCE_LENGTH, -1).transpose(0, 1)

    assert logits_nt.shape == logits_hf.shape
    torch.testing.assert_close(logits_nt, logits_hf)
| 90 | + |
| 91 | + |
def test_nt2hf_gate(hidden_states: torch.Tensor):
    """Run the gate conversion check under a tp=1, dp=1, pp=1 distributed setup."""
    launch = init_distributed(tp=1, dp=1, pp=1)
    wrapped = launch(_test_nt2hf_gate)
    wrapped(hidden_states=hidden_states)
| 94 | + |
| 95 | + |
def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor):
    """Check that the converted HF sparse MoE block matches the nanotron dMoE output.

    Args:
        parallel_context: Passed to the nanotron ``dMoE`` constructor.
        hidden_states: Activations shaped
            (TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE).
    """
    hidden_states = hidden_states.cuda()

    # Build both feed-forward blocks, then run the conversion helper on the pair.
    hf_config = convert_config(CONFIG)
    moe_nt = dMoE(CONFIG, parallel_context, None).cuda().to(DTYPE)
    moe_hf = XGLMSparseMoeBlock(hf_config).cuda().to(DTYPE)
    convert_ff(moe_hf, moe_nt)

    result_nt = moe_nt(hidden_states)["hidden_states"]

    # The HF block receives the input with the first two axes swapped; its
    # output is swapped back so both results share one layout.
    batch_first_input = hidden_states.transpose(0, 1).contiguous()
    result_hf, _ = moe_hf(batch_first_input, "")
    result_hf = result_hf.transpose(0, 1)

    assert result_nt.shape == result_hf.shape
    # Compared with almost_close (tolerances defined by that helper) rather
    # than torch.testing.assert_close.
    almost_close(result_nt, result_hf, max_far=0.1, far_atol=0.02)
| 111 | + |
| 112 | + |
| 113 | + |
def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
    """Build a nanotron MoE model, convert it to HF, and return both models' logits.

    Args:
        parallel_context: Parallel context used to build the nanotron model.
        input_ids: Token ids shaped (BATCH_SIZE, TEST_SEQUENCE_LENGTH).
        input_mask: Boolean attention mask with the same shape as ``input_ids``.

    Returns:
        A ``(out_nt, out_hf)`` tuple of CPU tensors with identical sizes; the
        HF logits have their first two axes swapped to match the nanotron
        output.

    Raises:
        AssertionError: If the two outputs differ in size.
    """
    random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()})
    input_ids = input_ids.cuda()
    input_mask = input_mask.cuda()

    # Unfortunately, we can't use float64 with huggingface xglm.
    new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE

    # Get nanotron model (on a private copy of the shared config).
    config_nt = GPT3MoEConfig(**vars(CONFIG))
    if new_dtype not in {torch.bfloat16, torch.float16}:
        config_nt.use_spda = True
    model_nt = nanotron.models.build_model(
        model_builder=lambda: GPT3MoEForTraining(
            config=config_nt,
            parallel_context=parallel_context,
            parallel_config=None,
            random_states=random_states,
        ),
        parallel_context=parallel_context,
        dtype=new_dtype,
        device="cuda",
    ).eval()
    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)

    # Create empty model_hf and make conversion.
    model_hf = XGLMForCausalLM(convert_config(config_nt)).cuda().to(new_dtype).eval()
    convert(model_hf, model_nt)

    # The nanotron forward pass requires auxiliary-loss accumulators.
    # input_ids is always a real tensor here (it went through .cuda() above),
    # so each accumulator is simply a fresh zero tensor on the same device.
    aux_losses = {
        loss_name: torch.zeros(1, device=input_ids.device)
        for loss_name in ("load_balancing_loss", "z_loss")
    }

    # Get outputs and assert. Each model is released right after its forward
    # pass so its GPU memory can be reclaimed.
    with torch.no_grad():
        out_nt = model_nt.model(input_ids, input_mask, aux_losses)["sharded_logits"].to(new_dtype)
        del model_nt
        torch.cuda.empty_cache()
        out_hf = model_hf(
            input_ids=input_ids,
            attention_mask=input_mask,
            output_router_logits=False,
        ).logits.permute(1, 0, 2)
        del model_hf
        torch.cuda.empty_cache()
    assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}"
    return out_nt.cpu(), out_hf.cpu()
| 167 | + |
| 168 | + |
def test_nt2hf_ff(hidden_states: torch.Tensor):
    """Run the feed-forward conversion check under a tp=1, dp=1, pp=1 distributed setup."""
    launch = init_distributed(tp=1, dp=1, pp=1)
    wrapped = launch(_test_nt2hf_ff)
    wrapped(hidden_states=hidden_states)
| 171 | + |
| 172 | + |
def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
    """Compare full-model logits of the nanotron model and its HF conversion.

    The tolerances are forwarded to ``almost_close``; their exact meaning is
    defined by that helper.
    """
    logits_pair = _test_nt2hf_model(parallel_context, input_ids, input_mask)
    logits_nt, logits_hf = logits_pair
    almost_close(logits_nt, logits_hf, max_far=0.1, far_atol=0.02)
| 176 | + |
| 177 | + |
def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor):
    """Run the full-model conversion check under a tp=1, dp=1, pp=1 distributed setup."""
    launch = init_distributed(tp=1, dp=1, pp=1)
    wrapped = launch(_test_nt2hf_dummy_xglm)
    wrapped(input_ids=input_ids, input_mask=input_mask)
0 commit comments