Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import torch

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
Expand All @@ -22,10 +25,10 @@
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100
90
):
pytest.skip(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
"Supported for sm >= 90",
allow_module_level=True,
)

Expand Down Expand Up @@ -133,6 +136,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk: int,
monkeypatch,
):
if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100")
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
Expand Down Expand Up @@ -186,9 +191,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)


@pytest.mark.skip(
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
Expand Down Expand Up @@ -218,9 +220,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(

quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
w2_scale=td.w2_weight_scale,
g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
a1_scale=td.a1_scale,
a1_gscale=td.a1_scale,
a2_scale=td.a2_scale,
a2_gscale=1.0 / td.a2_scale,
per_act_token_quant=False,
)

Expand All @@ -240,6 +246,12 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(

td.layer.dp_size = 1

def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
return quant_config

td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer

flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
td.hidden_states,
td.layer,
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,24 @@ def fp8_w8a8_moe_quant_config(
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: list[int] | None = None,
a1_gscale: torch.Tensor | None = None,
a2_gscale: torch.Tensor | None = None,
g1_alphas: torch.Tensor | None = None,
g2_alphas: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for fp8 activations and fp8 weights.
"""
return FusedMoEQuantConfig.make(
torch.float8_e4m3fn,
w1_scale=w1_scale,
g1_alphas=g1_alphas,
w2_scale=w2_scale,
g2_alphas=g2_alphas,
a1_scale=a1_scale,
a1_gscale=a1_gscale,
a2_scale=a2_scale,
a2_gscale=a2_gscale,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,12 @@ def apply(
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
assert activation == "silu", (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe worth asserting that activation is one of the keys in activation_str_to_value_map?

"Only activation silu is supported in FlashInferExperts"
)
from flashinfer.fused_moe.core import ActivationType

activation_str_to_value_map = {
"silu": ActivationType.Swiglu, # This is the default
"relu2_no_mul": ActivationType.Relu2,
}
if self.quant_dtype == torch.float8_e4m3fn:
quant_scales = [
self.g1_alphas,
Expand Down Expand Up @@ -193,6 +195,7 @@ def apply(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
output=output,
activation_type=activation_str_to_value_map[activation],
)


Expand Down
33 changes: 19 additions & 14 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,7 @@ def __init__(

self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if (
envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
and self.moe.is_act_and_mul
):
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
Expand Down Expand Up @@ -557,7 +553,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
)

if self.flashinfer_moe_backend is not None:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if self.moe.is_act_and_mul:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
register_moe_scaling_factors(layer)
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
Expand All @@ -570,9 +567,21 @@ def get_fused_moe_quant_config(

return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
g1_alphas=layer.output1_scales_gate_scalar.squeeze()
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
else None,
w2_scale=layer.w2_weight_scale,
g2_alphas=layer.output2_scales_scalar.squeeze()
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
else None,
a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
else None,
a2_scale=layer.w2_input_scale,
a2_gscale=layer.w2_input_scale_inv
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
else None,
per_act_token_quant=False,
)

Expand Down Expand Up @@ -656,10 +665,6 @@ def apply(
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not renormalize
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. why did you remove the assert that renormalize is not True?
  2. Maybe worth asserting that activation is either "silu" or "relu2"?

assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
return flashinfer_cutlass_moe_fp8(
x,
layer,
Expand Down Expand Up @@ -1159,8 +1164,8 @@ def __init__(
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
detect_nvfp4_moe_support, # noqa: E501
)

super().__init__(moe)
Expand Down Expand Up @@ -1773,8 +1778,8 @@ def apply(
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
flashinfer_cutlass_moe_fp4, # noqa: E501
)

assert self.moe_quant_config is not None
Expand Down