 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
-from collections.abc import Iterable
-from typing import Any, Optional, Union
+import typing
+from collections.abc import Iterable, Callable
+from typing import Optional, Any, Union
 
 import torch
-from torch import nn
+from torch import nn, Tensor
 
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsPP, MixtureOfExperts
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -101,6 +103,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -110,14 +113,29 @@ def __init__(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_size = self.ep_group.size()
+
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.n_routed_experts = config.num_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.enable_eplb = enable_eplb
+
         self.experts = FusedMoE(num_experts=config.num_experts,
                                 top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.moe_intermediate_size,
                                 reduce_results=False,
                                 renormalize=config.norm_topk_prob,
                                 quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
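Annotation (not part of the diff): the bookkeeping added to Qwen3MoeSparseMoeBlock above distinguishes logical (routed) experts from physical expert slots. A minimal sketch with illustrative numbers, assuming the physical experts divide evenly across the expert-parallel group:

# Illustrative only: the sizes below are examples, not values from this diff.
n_logical_experts = 128      # e.g. config.num_experts for a Qwen3 MoE model
n_redundant_experts = 32     # parallel_config.num_redundant_experts (EPLB)
ep_size = 8                  # ranks in the expert-parallel group

n_physical_experts = n_logical_experts + n_redundant_experts      # 160 slots
n_local_physical_experts = n_physical_experts // ep_size          # 20 per rank
assert n_physical_experts % ep_size == 0, "EPLB assumes an even split"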
@@ -246,6 +264,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -277,7 +296,8 @@ def __init__(
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3MoeSparseMoeBlock(config=config,
                                               quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp")
+                                              prefix=f"{prefix}.mlp",
+                                              enable_eplb=enable_eplb)
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -323,6 +343,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        self.num_redundant_experts = (
+            vllm_config.parallel_config.num_redundant_experts)
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -336,7 +359,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             lambda prefix: Qwen3MoeDecoderLayer(config=config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
-                                                prefix=prefix),
+                                                prefix=prefix,
+                                                enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -375,15 +399,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -400,9 +415,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                            ".v_scale", "_v_scale", ".weight_scale",
                            "_weight_scale", ".input_scale", "_input_scale")
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+            num_redundant_experts=self.num_redundant_experts)
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -433,27 +456,37 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    is_expert_weight = True
+
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
                     # Skip loading extra parameters for GPTQ/modelopt models.
                     if name.endswith(
                             ignore_suffixes) and name not in params_dict:
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        break
                 else:
+                    if is_expert_weight:
+                        continue
                     # Skip loading extra parameters for GPTQ/modelopt models.
                     if name.endswith(
                             ignore_suffixes) and name not in params_dict:
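Annotation (not part of the diff): with redundant experts, a checkpoint expert weight may not map onto any physical expert hosted on this rank, so the expert loader above is asked for a success flag (return_success=True) and weights that were recognised as expert weights but not loaded locally are skipped instead of falling through to the generic path. A standalone sketch of that control flow; the helper, mapping names, and callback below are hypothetical:

def try_load_expert(name, expert_mappings, params_dict, load_weight) -> bool:
    """Return True if `name` was consumed: loaded here or owned elsewhere."""
    is_expert_weight = False
    for param_name, weight_name in expert_mappings:
        if weight_name not in name:
            continue
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        param = params_dict.get(name_mapped)
        if param is not None and load_weight(param, name_mapped):
            return True          # a local physical expert accepted the weight
    # Recognised as an expert weight but not resident on this rank:
    # report it as handled so the dense/stacked loaders do not touch it.
    return is_expert_weight

# Hypothetical names, empty params_dict: recognised, nothing loaded locally.
handled = try_load_expert("layers.0.mlp.experts.0.gate_proj.weight",
                          [("experts.w13_", "experts.0.gate_proj.")],
                          {}, lambda p, n: False)
assert handled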
@@ -482,7 +515,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -513,6 +546,43 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+        # Implement the MixtureOfExperts protocol.
+        self.expert_weights = []
+
+        self.moe_layers: list[FusedMoE] = []
+        for layer in self.model.layers:
+            assert isinstance(layer, Qwen3MoeDecoderLayer)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                self.moe_layers.append(layer.mlp.experts)
+        self.num_moe_layers = len(self.moe_layers)
+
+        example_layer = typing.cast(
+            Qwen3MoeSparseMoeBlock,
+            self.model.layers[config.num_hidden_layers - 1].mlp)
+
+        self.num_expert_groups = 1
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_shared_experts = 0
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: Tensor,
+        logical_to_physical_map: Tensor,
+        logical_replica_count: Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
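Annotation (not part of the diff): set_eplb_state above hands every FusedMoE layer the same shared EPLB tensors together with its own layer index. The sketch below mirrors only that fan-out pattern; the toy layer class and the tensor shapes are assumptions made for the example, not vLLM's actual EPLB layout:

import torch


class _ToyMoELayer:

    def set_eplb_state(self, moe_layer_idx: int,
                       expert_load_view: torch.Tensor,
                       logical_to_physical_map: torch.Tensor,
                       logical_replica_count: torch.Tensor) -> None:
        # Keep per-layer views into the shared buffers (assumed layout:
        # the first dimension indexes the MoE layer).
        self.expert_load_view = expert_load_view[moe_layer_idx]
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]


num_moe_layers, num_physical, num_logical = 4, 160, 128
moe_layers = [_ToyMoELayer() for _ in range(num_moe_layers)]
expert_load = torch.zeros(num_moe_layers, num_physical, dtype=torch.int64)
log_to_phys = torch.full((num_moe_layers, num_logical, 2), -1, dtype=torch.int64)
replica_count = torch.ones(num_moe_layers, num_logical, dtype=torch.int64)

for idx, layer in enumerate(moe_layers):
    layer.set_eplb_state(idx, expert_load, log_to_phys, replica_count)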
@@ -540,7 +610,4 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
-
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+        return loader.load_weights(weights)