
Commit 5386287

modify qwen moe model
Signed-off-by: hsliu <[email protected]>
1 parent 68d28e3 commit 5386287

File tree

1 file changed: +101 -34 lines changed

vllm/model_executor/models/qwen3_moe.py

Lines changed: 101 additions & 34 deletions
@@ -22,17 +22,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
-from collections.abc import Iterable
-from typing import Any, Optional, Union
+import typing
+from collections.abc import Iterable, Callable
+from typing import Optional, Any, Union
 
 import torch
-from torch import nn
+from torch import nn, Tensor
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,7 +52,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsPP, MixtureOfExperts
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -101,6 +103,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -110,14 +113,29 @@ def __init__(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_size = self.ep_group.size()
+
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.n_routed_experts = config.num_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.enable_eplb = enable_eplb
+
         self.experts = FusedMoE(num_experts=config.num_experts,
                                 top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.moe_intermediate_size,
                                 reduce_results=False,
                                 renormalize=config.norm_topk_prob,
                                 quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
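Note on the bookkeeping above: EPLB distinguishes logical experts (those defined by the checkpoint) from physical experts (logical plus redundant replicas), and the physical experts are split evenly across the expert-parallel ranks. A minimal standalone sketch of that arithmetic, using hypothetical sizes rather than values from this commit:

# Hypothetical sizes; in the model they come from config.num_experts,
# parallel_config.num_redundant_experts and the EP group size.
n_routed_experts = 128      # logical experts defined by the checkpoint
n_redundant_experts = 4     # extra physical replicas managed by EPLB
ep_size = 4                 # expert-parallel world size

n_logical_experts = n_routed_experts
n_physical_experts = n_logical_experts + n_redundant_experts

# Physical experts are sharded evenly across EP ranks, so the total
# should divide cleanly by ep_size.
assert n_physical_experts % ep_size == 0
n_local_physical_experts = n_physical_experts // ep_size

print(n_logical_experts, n_physical_experts, n_local_physical_experts)
# -> 128 132 33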
@@ -246,6 +264,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -277,7 +296,8 @@ def __init__(
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3MoeSparseMoeBlock(config=config,
                                               quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp")
+                                              prefix=f"{prefix}.mlp",
+                                              enable_eplb=enable_eplb)
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -323,6 +343,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        self.num_redundant_experts = (
+            vllm_config.parallel_config.num_redundant_experts)
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -336,7 +359,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             lambda prefix: Qwen3MoeDecoderLayer(config=config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
-                                                prefix=prefix),
+                                                prefix=prefix,
+                                                enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -375,15 +399,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -400,9 +415,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                           ".v_scale", "_v_scale", ".weight_scale",
                           "_weight_scale", ".input_scale", "_input_scale")
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+            num_redundant_experts=self.num_redundant_experts)
+
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -433,27 +456,37 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    is_expert_weight = True
+
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
                     # Skip loading extra parameters for GPTQ/modelopt models.
                     if name.endswith(
                             ignore_suffixes) and name not in params_dict:
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        break
                 else:
+                    if is_expert_weight:
+                        continue
                     # Skip loading extra parameters for GPTQ/modelopt models.
                     if name.endswith(
                             ignore_suffixes) and name not in params_dict:
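Note on the control flow above: the is_expert_weight flag plus Python's for/else ensures that a checkpoint tensor matching an expert pattern is never handed to the generic loading path, even when no local expert replica accepts it (which the weight_loader(..., return_success=True) call reports). A simplified, self-contained sketch of the same pattern with toy data structures, not the vLLM loader:

# Toy mapping entries: (param_name, weight_name, expert_id, shard_id).
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.gate_proj", 0, "w1"),
    ("experts.w13_weight", "experts.1.gate_proj", 1, "w1"),
]
local_expert_ids = {0}  # pretend only expert 0 is hosted on this rank

def load_generic(name):
    print("generic load:", name)

def try_load_expert(name, expert_id):
    # Stand-in for weight_loader(..., return_success=True).
    return expert_id in local_expert_ids

for name in ["experts.1.gate_proj.weight", "model.embed_tokens.weight"]:
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        is_expert_weight = True
        if try_load_expert(name, expert_id):
            break
    else:
        if is_expert_weight:
            # Matched an expert pattern but no local replica accepted it:
            # skip instead of falling through to the generic path.
            continue
        load_generic(name)  # only truly non-expert weights end up here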
@@ -482,7 +515,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -513,6 +546,43 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+        # Implement the MixtureOfExperts protocol.
+        self.expert_weights = []
+
+        self.moe_layers: list[FusedMoE] = []
+        for layer in self.model.layers:
+            assert isinstance(layer, Qwen3MoeDecoderLayer)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                self.moe_layers.append(layer.mlp.experts)
+        self.num_moe_layers = len(self.moe_layers)
+
+        example_layer = typing.cast(
+            Qwen3MoeSparseMoeBlock,
+            self.model.layers[config.num_hidden_layers - 1].mlp)
+
+        self.num_expert_groups = 1
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_shared_experts = 0
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: Tensor,
+        logical_to_physical_map: Tensor,
+        logical_replica_count: Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -540,7 +610,4 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
-
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+        return loader.load_weights(weights)
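For context on how the MixtureOfExperts wiring above can be exercised: set_eplb_state hands each FusedMoE layer an index into shared EPLB state tensors and collects the layers' expert weights. The sketch below imitates that wiring with a stub layer and plain torch tensors; the tensor shapes and the per-layer slicing are illustrative assumptions, not taken from this commit.

import torch

class StubMoELayer:
    # Minimal stand-in for FusedMoE's EPLB hooks, for illustration only.
    def __init__(self, n_local_physical_experts, hidden=8):
        self.w13_weight = torch.zeros(n_local_physical_experts, 2 * hidden, hidden)
        self.w2_weight = torch.zeros(n_local_physical_experts, hidden, hidden)

    def get_expert_weights(self):
        return [self.w13_weight, self.w2_weight]

    def set_eplb_state(self, moe_layer_idx, expert_load_view,
                       logical_to_physical_map, logical_replica_count):
        # Keep a per-layer view of the shared state (assumed layout:
        # the first dimension indexes the MoE layer).
        self.expert_load_view = expert_load_view[moe_layer_idx]
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]

num_moe_layers, n_logical, n_physical, n_local = 2, 8, 10, 5
moe_layers = [StubMoELayer(n_local) for _ in range(num_moe_layers)]

# Assumed shapes: per-layer load counters over physical experts, plus a
# logical -> physical replica lookup and a replica count per logical expert.
expert_load_view = torch.zeros(num_moe_layers, n_physical, dtype=torch.int64)
logical_to_physical_map = torch.full((num_moe_layers, n_logical, 2), -1,
                                     dtype=torch.int64)
logical_replica_count = torch.ones(num_moe_layers, n_logical, dtype=torch.int64)

expert_weights = []
for layer_idx, layer in enumerate(moe_layers):
    expert_weights.append(layer.get_expert_weights())
    layer.set_eplb_state(
        moe_layer_idx=layer_idx,
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )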
