@@ -354,11 +354,7 @@ def __init__(
 
         self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if (
-            envs.VLLM_USE_FLASHINFER_MOE_FP8
-            and has_flashinfer_moe()
-            and self.moe.is_act_and_mul
-        ):
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
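With this hunk, the FlashInfer FP8 MoE path is selected purely from the environment flag plus a FlashInfer availability check; whether the layer is a gated (act-and-mul) MoE is handled later, during weight processing. A minimal usage sketch, assuming a vLLM install with FlashInfer present; the engine entry point and model name below are illustrative, not tied to this file:

```python
import os

# Opt in to the FlashInfer FP8 MoE kernels before constructing the engine
# (equivalently, export VLLM_USE_FLASHINFER_MOE_FP8=1 in the shell).
# The concrete backend (CUTLASS or TENSORRT_LLM) is then picked in
# Fp8MoEMethod.__init__ via get_flashinfer_moe_backend().
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"

from vllm import LLM

llm = LLM(model="some/fp8-moe-checkpoint")  # hypothetical checkpoint name
```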
@@ -557,7 +553,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             )
 
         if self.flashinfer_moe_backend is not None:
-            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+            if self.moe.is_act_and_mul:
+                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             register_moe_scaling_factors(layer)
             if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                 rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
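The w13-to-w31 swap is now applied only to gated (act-and-mul) layers, since only those carry a fused gate/up projection whose halves need reordering for FlashInfer. A minimal sketch of what such a swap amounts to, assuming `w13` has shape `[num_experts, 2 * intermediate_size, hidden_size]` with the gate half stacked above the up half; the real `swap_w13_to_w31` helper lives in vLLM's FlashInfer utilities and may be implemented differently:

```python
import torch


def swap_halves_along_dim1(w13: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for swap_w13_to_w31.

    Reorders a fused [w1; w3] (gate; up) expert weight into [w3; w1] along
    the row dimension, the layout the FlashInfer kernels expect. Shape is
    assumed to be [num_experts, 2 * intermediate_size, hidden_size].
    """
    half = w13.shape[1] // 2
    w1, w3 = w13[:, :half, :], w13[:, half:, :]
    return torch.cat([w3, w1], dim=1).contiguous()
```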
@@ -570,13 +567,21 @@ def get_fused_moe_quant_config(
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
-            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+            g1_alphas=layer.output1_scales_gate_scalar.squeeze()
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             w2_scale=layer.w2_weight_scale,
-            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
+            g2_alphas=layer.output2_scales_scalar.squeeze()
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             a1_scale=layer.w13_input_scale,
-            a1_gscale=layer.w13_input_scale,
+            a1_gscale=layer.w13_input_scale
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             a2_scale=layer.w2_input_scale,
-            a2_gscale=1.0 / layer.w2_input_scale,
+            a2_gscale=layer.w2_input_scale_inv
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             per_act_token_quant=False,
         )
 
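The updated config passes the per-layer scalars registered by `register_moe_scaling_factors` instead of recomputing scale products inline, and only for the CUTLASS backend; every other backend now receives `None` for the alpha/gscale fields. A small sketch of the repeated conditional, where `cutlass_only` is a hypothetical helper rather than anything in vLLM:

```python
from enum import Enum


class FlashinferMoeBackend(Enum):
    # Stand-in for vLLM's enum of the same name.
    CUTLASS = "cutlass"
    TENSORRT_LLM = "trtllm"


def cutlass_only(value, backend: FlashinferMoeBackend):
    """Return `value` only when the CUTLASS backend is active, else None.

    Captures the pattern in the updated fp8_w8a8_moe_quant_config call:
    the alpha/gscale tensors are consumed only by the FlashInfer CUTLASS
    MoE path.
    """
    return value if backend == FlashinferMoeBackend.CUTLASS else None


# Example:
# g1_alphas = cutlass_only(layer.output1_scales_gate_scalar.squeeze(),
#                          self.flashinfer_moe_backend)
```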
@@ -660,10 +665,6 @@ def apply(
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert not renormalize
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
-            )
             return flashinfer_cutlass_moe_fp8(
                 x,
                 layer,