 # coding=utf-8
-"""Inference-only Jurassic model."""
+"""Inference-only Jamba model."""
 from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional, Tuple

 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -282,108 +281,50 @@ def forward(self, x):


 class JambaMoE(nn.Module):
-    """A tensor-parallel MoE implementation for Mixtral that shards each expert
-    across all ranks.

-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
-
-    def __init__(
-        self,
-        config: JambaConfig,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 config: JambaConfig,
+                 num_experts: Optional[int] = None,
+                 top_k: Optional[int] = None,
+                 params_dtype: Optional[torch.dtype] = None,
+                 tp_size: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
+        self.num_total_experts = num_experts or config.num_experts
+        self.top_k = top_k or config.num_experts_per_tok
         self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size // self.tp_size
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = ReplicatedLinear(self.hidden_size,
-                                       self.num_total_experts,
-                                       bias=False,
-                                       params_dtype=self.params_dtype)
-
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
+        self.intermediate_size = config.intermediate_size

-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("gate_proj.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("up_proj.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("down_proj.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if self.num_total_experts > 1:
+            self.router = ReplicatedLinear(self.hidden_size,
+                                           self.num_total_experts,
+                                           bias=False,
+                                           quant_config=None,
+                                           params_dtype=params_dtype)
+
+        self.experts = FusedMoE(self.num_total_experts,
+                                self.top_k,
+                                self.hidden_size,
+                                self.intermediate_size,
+                                tp_size=tp_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=False,
+                                use_grouped_topk=False,
+                                quant_config=quant_config)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits, _ = self.router(hidden_states)
-
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=
-            False,  # Mixtral normalize the expert probs to 1. We don't!
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        if self.num_total_experts > 1:
+            router_logits, _ = self.router(hidden_states)
+        else:
+            router_logits = torch.ones((hidden_states.shape[0], 1),
+                                       device=hidden_states.device,
+                                       dtype=hidden_states.dtype)
+        hidden_states = self.experts(hidden_states, router_logits)
+        return hidden_states.view(orig_shape)


 class JambaMambaDecoderLayer(nn.Module):
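For reference, the contract behind the new self.experts(hidden_states, router_logits) call can be sketched in plain PyTorch. The sketch below is a hypothetical reference, not vLLM's fused kernel; naive_jamba_moe, w13 and w2 are illustrative names. It mirrors what the layer above assumes: activations flattened to (num_tokens, hidden_size), router logits of shape (num_tokens, num_experts), softmax plus top-k routing with renormalize=False (the selected probabilities are not rescaled to sum to one), and SwiGLU experts whose first weight stacks gate_proj over up_proj.

import torch
import torch.nn.functional as F


def naive_jamba_moe(hidden_states: torch.Tensor,  # (num_tokens, hidden_size)
                    router_logits: torch.Tensor,  # (num_tokens, num_experts)
                    w13: torch.Tensor,  # (num_experts, 2 * intermediate, hidden)
                    w2: torch.Tensor,  # (num_experts, hidden, intermediate)
                    top_k: int) -> torch.Tensor:
    # Hypothetical reference only; vLLM's FusedMoE runs a fused kernel instead.
    probs = F.softmax(router_logits, dim=-1)
    topk_probs, topk_ids = probs.topk(top_k, dim=-1)  # renormalize=False: keep as-is
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for p, e in zip(topk_probs[t], topk_ids[t]):
            gate, up = (hidden_states[t] @ w13[e].t()).chunk(2, dim=-1)
            out[t] += p * ((F.silu(gate) * up) @ w2[e].t())
    return out

With a single expert and top_k=1, the all-ones logits built in forward() softmax to a probability of exactly 1.0, so a one-expert JambaMoE reduces to a plain SwiGLU MLP.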
@@ -917,15 +858,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]

-        expert_params_mapping = [
-            # (param_name, weight_name, expert_id)
-            (
-                "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
-                f"experts.{expert_id}.{weight_name}.weight",
-                expert_id,
-            ) for expert_id in range(self.config.num_experts)
-            for weight_name in ["down_proj", "up_proj", "gate_proj"]
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
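For orientation, the hunks below unpack four-element tuples from expert_params_mapping. A hedged sketch of the kind of mapping FusedMoE.make_expert_params_mapping is expected to return is given here; the fused parameter names and shard ids are assumptions for illustration only, since the real strings are owned by FusedMoE.

# Hypothetical illustration only: one (param_name, weight_name, expert_id,
# shard_id) entry per expert and per checkpoint projection. gate/up shards are
# assumed to target one fused weight and down_proj another.
def sketch_expert_params_mapping(num_experts: int):
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name, shard_id in (("gate_proj", "w1"),
                                    ("down_proj", "w2"),
                                    ("up_proj", "w3")):
            fused_param = ("experts.w13_weight" if shard_id in ("w1", "w3")
                           else "experts.w2_weight")  # assumed names
            mapping.append((fused_param,
                            f"experts.{expert_id}.{ckpt_name}.weight",
                            expert_id, shard_id))
    return mapping

The extra shard_id is what the later hunk threads into weight_loader, telling the fused loader which slice of the combined expert weight a given checkpoint tensor fills.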
@@ -952,7 +891,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -961,6 +901,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else: