Commit 2ee8d3b

[Model] use FusedMoE layer in Jamba (vllm-project#6935)

1 parent daed30c commit 2ee8d3b
File tree

vllm/model_executor/models/jamba.py

1 file changed: 49 additions, 108 deletions
@@ -1,5 +1,5 @@
 # coding=utf-8
-"""Inference-only Jurassic model."""
+"""Inference-only Jamba model."""
 from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional, Tuple
 
@@ -15,10 +15,9 @@
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -282,108 +281,50 @@ def forward(self, x):
 
 
 class JambaMoE(nn.Module):
-    """A tensor-parallel MoE implementation for Mixtral that shards each expert
-    across all ranks.
 
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
-
-    def __init__(
-        self,
-        config: JambaConfig,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 config: JambaConfig,
+                 num_experts: Optional[int] = None,
+                 top_k: Optional[int] = None,
+                 params_dtype: Optional[torch.dtype] = None,
+                 tp_size: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
+        self.num_total_experts = num_experts or config.num_experts
+        self.top_k = top_k or config.num_experts_per_tok
         self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size // self.tp_size
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = ReplicatedLinear(self.hidden_size,
-                                       self.num_total_experts,
-                                       bias=False,
-                                       params_dtype=self.params_dtype)
-
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
+        self.intermediate_size = config.intermediate_size
 
-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("gate_proj.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("up_proj.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("down_proj.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if self.num_total_experts > 1:
+            self.router = ReplicatedLinear(self.hidden_size,
+                                           self.num_total_experts,
+                                           bias=False,
+                                           quant_config=None,
+                                           params_dtype=params_dtype)
+
+        self.experts = FusedMoE(self.num_total_experts,
+                                self.top_k,
+                                self.hidden_size,
+                                self.intermediate_size,
+                                tp_size=tp_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=False,
+                                use_grouped_topk=False,
+                                quant_config=quant_config)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits, _ = self.router(hidden_states)
-
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=
-            False,  # Mixtral normalize the expert probs to 1. We don't!
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        if self.num_total_experts > 1:
+            router_logits, _ = self.router(hidden_states)
+        else:
+            router_logits = torch.ones((hidden_states.shape[0], 1),
+                                       device=hidden_states.device,
+                                       dtype=hidden_states.dtype)
+        hidden_states = self.experts(hidden_states, router_logits)
+        return hidden_states.view(orig_shape)
 
 
 class JambaMambaDecoderLayer(nn.Module):
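
For reference, the unfused math that the FusedMoE call above replaces can be written out in plain PyTorch. The sketch below is illustrative only: the helper name and the per-expert loop are assumptions, not vLLM's implementation, which dispatches a fused kernel and, with reduce_results=True, all-reduces across tensor-parallel ranks. It shows top-k routing over SwiGLU experts with renormalize=False, i.e. the selected softmax weights are not rescaled to sum to 1, matching the old comment "Mixtral normalize the expert probs to 1. We don't!".

# Illustrative, unfused reference for the computation delegated to FusedMoE.
# Weight layouts follow the old JambaMoE parameters: w13 stacks gate_proj and
# up_proj per expert, w2 holds down_proj per expert.
import torch
import torch.nn.functional as F


def moe_reference(hidden_states: torch.Tensor,   # [num_tokens, hidden_size]
                  w13: torch.Tensor,             # [num_experts, 2 * intermediate, hidden_size]
                  w2: torch.Tensor,              # [num_experts, hidden_size, intermediate]
                  router_logits: torch.Tensor,   # [num_tokens, num_experts]
                  top_k: int) -> torch.Tensor:
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    # renormalize=False: keep the softmax weights of the chosen experts as-is.
    intermediate = w2.shape[-1]
    out = torch.zeros_like(hidden_states)
    for expert in range(w13.shape[0]):
        token_idx, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        gate_up = x @ w13[expert].t()                      # [n, 2 * intermediate]
        gate, up = gate_up[:, :intermediate], gate_up[:, intermediate:]
        expert_out = (F.silu(gate) * up) @ w2[expert].t()  # SwiGLU expert MLP
        weight = topk_weights[token_idx, slot].to(expert_out.dtype).unsqueeze(-1)
        out.index_add_(0, token_idx, expert_out * weight)
    return out
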
@@ -917,15 +858,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # (param_name, weight_name, expert_id)
-            (
-                "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
-                f"experts.{expert_id}.{weight_name}.weight",
-                expert_id,
-            ) for expert_id in range(self.config.num_experts)
-            for weight_name in ["down_proj", "up_proj", "gate_proj"]
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
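
The loop in the next hunk unpacks 4-tuples of (param_name, weight_name, expert_id, shard_id). Below is a toy sketch of what such a mapping can look like; the fused parameter names and shard ids are assumptions made for illustration, and the real mapping (including any fp8 scale entries) comes from FusedMoE.make_expert_params_mapping.

# Hypothetical illustration of the 4-tuple shape consumed by the loader loop:
# (param_name, weight_name, expert_id, shard_id). The exact strings produced
# by FusedMoE.make_expert_params_mapping may differ.
def toy_expert_params_mapping(num_experts: int):
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name, shard_id in (("gate_proj", "w1"),
                                    ("down_proj", "w2"),
                                    ("up_proj", "w3")):
            fused_param = ("experts.w13_weight"
                           if shard_id in ("w1", "w3") else "experts.w2_weight")
            mapping.append((fused_param,
                            f"experts.{expert_id}.{ckpt_name}.weight",
                            expert_id,
                            shard_id))
    return mapping
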
@@ -952,7 +891,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -961,6 +901,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                else:
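
The new shard_id argument tells the FusedMoE weight loader where each checkpoint tensor lands inside the fused expert parameters: gate_proj and up_proj share one stacked tensor while down_proj has its own, which is the layout the deleted JambaMoE.weight_loader maintained by hand. A simplified single-device sketch of that placement follows; the function and parameter names are illustrative, and tensor-parallel slicing is omitted.

import torch

# Simplified sketch of how an (expert_id, shard_id) pair places a checkpoint
# tensor into fused expert weights. Mirrors the removed weight_loader but
# omits tensor-parallel sharding; the real logic lives inside FusedMoE.
def place_expert_shard(w13: torch.Tensor,   # [num_experts, 2 * intermediate, hidden]
                       w2: torch.Tensor,    # [num_experts, hidden, intermediate]
                       loaded_weight: torch.Tensor,
                       expert_id: int,
                       shard_id: str) -> None:
    intermediate = w2.shape[-1]
    if shard_id == "w1":      # gate_proj -> first half of the stacked tensor
        w13[expert_id, :intermediate, :] = loaded_weight
    elif shard_id == "w3":    # up_proj -> second half of the stacked tensor
        w13[expert_id, intermediate:, :] = loaded_weight
    elif shard_id == "w2":    # down_proj -> its own tensor
        w2[expert_id] = loaded_weight
    else:
        raise ValueError(f"unknown shard_id: {shard_id}")
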

0 commit comments