Commit c595cbb

Enable cutlass fp8 kernels
Signed-off-by: Amir Klein <[email protected]>
1 parent: 90515e5

File tree: 2 files changed (+21, -17)

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -140,10 +140,12 @@ def apply(
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool | None,
     ):
-        assert activation == "silu", (
-            "Only activation silu is supported in FlashInferExperts"
-        )
+        from flashinfer.fused_moe.core import ActivationType
 
+        activation_str_to_value_map = {
+            "silu": ActivationType.Swiglu,  # This is the default
+            "relu2_no_mul": ActivationType.Relu2,
+        }
         if self.quant_dtype == torch.float8_e4m3fn:
             quant_scales = [
                 self.g1_alphas,
@@ -193,6 +195,7 @@ def apply(
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             output=output,
+            activation_type=activation_str_to_value_map[activation],
         )
```
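Instead of asserting that the activation is `silu`, `FlashInferExperts.apply` now maps the activation string to a FlashInfer `ActivationType` and forwards it to the kernel; an unsupported name fails the dict lookup with a `KeyError`. A minimal sketch of the dispatch — the `resolve_activation` helper is hypothetical, not part of the diff, but the import and enum members are taken verbatim from it:

```python
from flashinfer.fused_moe.core import ActivationType

# Mirrors the map added in FlashInferExperts.apply.
_ACTIVATION_MAP = {
    "silu": ActivationType.Swiglu,        # the default, gated activation
    "relu2_no_mul": ActivationType.Relu2,
}

def resolve_activation(name: str) -> ActivationType:
    # Hypothetical helper: same lookup, with a clearer error than a bare KeyError.
    try:
        return _ACTIVATION_MAP[name]
    except KeyError as exc:
        raise ValueError(
            f"FlashInferExperts supports {sorted(_ACTIVATION_MAP)}, got {name!r}"
        ) from exc
```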

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -354,11 +354,7 @@ def __init__(
 
         self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if (
-            envs.VLLM_USE_FLASHINFER_MOE_FP8
-            and has_flashinfer_moe()
-            and self.moe.is_act_and_mul
-        ):
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
```
```diff
@@ -557,7 +553,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             )
 
         if self.flashinfer_moe_backend is not None:
-            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+            if self.moe.is_act_and_mul:
+                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             register_moe_scaling_factors(layer)
             if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                 rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
```
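The gate/up swap is now skipped for non-gated MoE layers, where the fused `w13` weight holds only a single projection. A minimal sketch of what the swap does, assuming `w13` stores the two halves concatenated along the intermediate dimension with the gate (`w1`) half first:

```python
import torch

def swap_w13_to_w31_sketch(w13: torch.Tensor) -> torch.Tensor:
    # Assumed layout: (num_experts, 2 * intermediate_size, hidden_size).
    # FlashInfer's kernels are assumed to expect the w3 (up) half first,
    # hence the one-time reorder at load time.
    w1, w3 = w13.chunk(2, dim=1)
    return torch.cat([w3, w1], dim=1)
```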
```diff
@@ -570,13 +567,21 @@ def get_fused_moe_quant_config(
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
-            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+            g1_alphas=layer.output1_scales_gate_scalar.squeeze()
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             w2_scale=layer.w2_weight_scale,
-            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
+            g2_alphas=layer.output2_scales_scalar.squeeze()
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             a1_scale=layer.w13_input_scale,
-            a1_gscale=layer.w13_input_scale,
+            a1_gscale=layer.w13_input_scale
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             a2_scale=layer.w2_input_scale,
-            a2_gscale=1.0 / layer.w2_input_scale,
+            a2_gscale=layer.w2_input_scale_inv
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            else None,
             per_act_token_quant=False,
         )
```
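The CUTLASS-only alphas are no longer computed inline from weight and input scales; they are read from layer attributes that are assumed to be precomputed at load time (e.g. by `register_moe_scaling_factors(layer)`), and `None` is passed on the non-CUTLASS path. The correspondence implied by the replaced expressions:

```python
import torch

# Illustrative scalars; in vLLM these are per-expert tensors on the layer.
w13_weight_scale = torch.tensor([0.5])
w13_input_scale = torch.tensor([0.25])
w2_weight_scale = torch.tensor([0.75])
w2_input_scale = torch.tensor([0.125])

# What the old code computed inline on every config build:
g1_alphas = (w13_weight_scale * w13_input_scale).squeeze()
g2_alphas = (w2_weight_scale * w2_input_scale).squeeze()
a2_gscale = 1.0 / w2_input_scale

# The new attributes are assumed to hold the same values, computed once:
#   layer.output1_scales_gate_scalar ~ g1_alphas
#   layer.output2_scales_scalar      ~ g2_alphas
#   layer.w2_input_scale_inv         ~ a2_gscale
```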

```diff
@@ -660,10 +665,6 @@ def apply(
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert not renormalize
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
-            )
             return flashinfer_cutlass_moe_fp8(
                 x,
                 layer,
```
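The CUTLASS branch no longer rejects `renormalize=True` or non-silu activations: the activation is now resolved through the `ActivationType` map in the first file, and renormalized routing weights are passed through. For context, a sketch of what the previously rejected `renormalize` flag refers to, assuming it means the usual per-token renormalization of top-k router weights in vLLM's fused MoE:

```python
import torch

def renormalize_topk_weights(router_logits: torch.Tensor, k: int):
    # After selecting the top-k experts per token, rescale their routing
    # weights so they sum to 1 for each token.
    weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids
```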
