
[Kernel] Enable FP16 and BF16 CUTLASS MoE kernels #15932

Open · wants to merge 21 commits into base: main
11 changes: 6 additions & 5 deletions CMakeLists.txt
@@ -287,6 +287,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/cutlass_moe/moe_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
@@ -471,21 +472,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
set(SRCS "csrc/cutlass_moe/moe_mm_c3x.cu"
"csrc/cutlass_moe/moe_data.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
message(STATUS "Building moe_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
message(STATUS "Not building moe_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
message(STATUS "Not building moe_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
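As a side note, the same precondition can be probed at runtime. The following is a hedged Python sketch (the helper name `sm90_cutlass_moe_buildable` is made up and this is not vLLM's actual capability check): the SM90 grouped MoE path targets Hopper (compute capability 9.0) and a CUDA toolkit of at least 12.3.

```python
import torch

# Illustrative only: a runtime mirror of the CMake guard above.
# torch.version.cuda reports the toolkit torch was built against, which may
# differ from the toolkit used to build vLLM's extensions.
def sm90_cutlass_moe_buildable() -> bool:
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, minor = torch.cuda.get_device_capability()
    cuda_version = tuple(int(v) for v in torch.version.cuda.split("."))
    return (major, minor) >= (9, 0) and cuda_version >= (12, 3)
```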
@@ -6,7 +6,7 @@

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
fused_experts,
fused_topk)
from vllm.utils import FlexibleArgumentParser
@@ -18,8 +18,8 @@
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
PER_ACT_TOKEN_OPTS = [False, True]
PER_OUT_CH_OPTS = [False, True]


def to_fp8(tensor: torch.Tensor):
@@ -48,18 +48,21 @@ def bench_run(results: list[benchmark.Measurement], model: str,
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

_, a_scale = ops.scaled_fp8_quant(a)
_, a_scale = ops.scaled_fp8_quant(a,
use_per_token_if_dynamic=per_act_token)

w1_q = torch.empty((num_experts, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2_q = torch.empty((num_experts, k, n),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1),
n_b_scales = 2 * n if per_out_ch else 1
k_b_scales = k if per_out_ch else 1
w1_scale = torch.empty((num_experts, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1),
w2_scale = torch.empty((num_experts, k_b_scales, 1),
device="cuda",
dtype=torch.float32)

@@ -81,8 +84,10 @@ def bench_run(results: list[benchmark.Measurement], model: str,
dtype=torch.int64)

for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_ch)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_ch)
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
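The new `per_act_token` / `per_out_ch` options change how many FP8 scales are produced: one scale per tensor versus one per token (activations) or per output channel (weights), which is why `w1_scale` and `w2_scale` grow from `(num_experts, 1, 1)` to `(num_experts, 2*n, 1)` and `(num_experts, k, 1)`. A minimal plain-PyTorch sketch of the difference (only the scale math; this is not vLLM's `ops.scaled_fp8_quant`):

```python
import torch

def dynamic_fp8_quant(x: torch.Tensor, per_channel: bool):
    # Dynamic FP8 quantization sketch: the scale is either a single value for
    # the whole tensor, or one value per leading row (token / output channel).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax(dim=-1, keepdim=True) if per_channel else x.abs().amax()
    scale = (amax / fp8_max).to(torch.float32)
    x_q = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale

w = torch.randn(2 * 128, 256)
_, s_tensor = dynamic_fp8_quant(w, per_channel=False)  # scale.shape == ()
_, s_chan = dynamic_fp8_quant(w, per_channel=True)     # scale.shape == (rows, 1)
```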
@@ -105,7 +110,8 @@ def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)
a1_scale=a_scale,
a2_scale=a_scale)

def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
@@ -115,18 +121,18 @@ def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
num_repeats: int):
for _ in range(num_repeats):
cutlass_moe_fp8(a,
w1,
w2,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale)
cutlass_moe(a,
w1,
w2,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)

def run_cutlass_from_graph(
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
return cutlass_moe_fp8(a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale)
return cutlass_moe(a,
w1_q,
w2_q,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)
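Both benchmarked paths compute the same MoE result; for orientation, here is a dense reference in plain PyTorch (a sketch that assumes the usual SiLU-and-mul expert activation used by vLLM's fused MoE; `reference_moe` is not part of the benchmark):

```python
import torch
import torch.nn.functional as F

def reference_moe(a, w1, w2, topk_weights, topk_ids):
    # a:            (m, k)                  activations
    # w1:           (num_experts, 2*n, k)   gate/up projection per expert
    # w2:           (num_experts, k, n)     down projection per expert
    # topk_weights, topk_ids: (m, topk)     routing produced by fused_topk
    m, topk = topk_ids.shape
    out = torch.zeros_like(a)
    for token in range(m):
        for j in range(topk):
            e = int(topk_ids[token, j])
            h = a[token] @ w1[e].t()             # (2*n,)
            gate, up = h.chunk(2)                # assumed SiLU-and-mul gating
            expert_out = w2[e] @ (F.silu(gate) * up)   # (k,)
            out[token] += (topk_weights[token, j] * expert_out).to(out.dtype)
    return out
```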

def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_weights: torch.Tensor,
@@ -165,7 +171,8 @@ def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale)
a1_scale=a_scale,
a2_scale=a_scale)

def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
@@ -180,12 +187,16 @@ def replay_graph(graph, num_repeats):
ab_strides2, c_strides2)
torch.cuda.synchronize()

triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
topk_ids, w1_scale, w2_scale, a_scale)
torch.cuda.synchronize()
if not per_act_token and not per_out_ch:
triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp,
topk_weights, topk_ids, w1_scale, w2_scale,
a_scale)
torch.cuda.synchronize()
else:
triton_graph = []
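The triton path is only captured into a CUDA graph when neither per-act-token nor per-out-channel scaling is requested. For readers unfamiliar with the pattern that `replay_graph` relies on, a generic capture-and-replay sketch (assumes a CUDA device; not the benchmark code itself):

```python
import torch

# Work launched inside torch.cuda.graph() is recorded once; graph.replay()
# re-issues the recorded kernels without relaunch overhead, so inputs must be
# updated in place in the same buffers that were captured.
x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)

stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
    y.copy_(torch.nn.functional.silu(x))

x.copy_(torch.randn_like(x))  # mutate the captured input buffer in place
graph.replay()                # replays the recorded kernels into y
torch.cuda.synchronize()
```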

min_run_time = 5
num_warmup = 5
@@ -223,31 +234,32 @@ def replay_graph(graph, num_repeats):
"replay_graph": replay_graph,
}

# Warmup
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
w1_scale, w2_scale, a_scale, num_warmup)

results.append(
benchmark.Timer(
stmt=
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time))

# Warmup
replay_graph(triton_graph, num_warmup)

results.append(
benchmark.Timer(
stmt="replay_graph(triton_graph, num_runs)",
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))
if not per_act_token and not per_out_ch:
# Warmup
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
w1_scale, w2_scale, a_scale, num_warmup)

results.append(
benchmark.Timer(
stmt=
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time))

# Warmup
replay_graph(triton_graph, num_warmup)

results.append(
benchmark.Timer(
stmt="replay_graph(triton_graph, num_runs)",
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time))

# Warmup
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
@@ -7,7 +7,7 @@

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
@@ -45,16 +45,28 @@ __global__ void compute_expert_offsets(
}
}

__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
int32_t* atomic_buffer, const int topk_length,
const int topk) {
int expert_id = blockIdx.x;
int blk_expert_id = blockIdx.x;
int const num_experts = gridDim.x;
int32_t const num_tokens = expert_offsets[num_experts];

for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
if (topk_ids[i] == expert_id) {
int start = atomicAdd(&atomic_buffer[expert_id], 1);
int const expert_id = topk_ids[i];
if (expert_id == -1 && blockIdx.x == 0) {
// output_permutation is used to re-order the MoE outputs. It is
// used as c2 = c2[c_map], where c2 is a torch.Tensor holding the
// output of the CUTLASS kernels and c_map is the output_permutation.
// Since c2 is initialized to zeros, pointing the output_permutation
// entry at num_tokens (a row no expert ever writes) guarantees that
// the MoE output rows for "invalid" topk_ids are filled with zeros.
output_permutation[i] = num_tokens;
} else if (expert_id == blk_expert_id) {
int start = atomicAdd(&atomic_buffer[blk_expert_id], 1);
input_permutation[start] = i / topk;
output_permutation[i] = start;
}
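The new `-1` branch relies on `c2` being zero-initialized: invalid slots are pointed at row `num_tokens` (= `expert_offsets[num_experts]`), a row the experts never write, so the later gather returns zeros. A small PyTorch sketch of that trick with assumed shapes (`num_written` plays the role of the kernel's `num_tokens`):

```python
import torch

# Assumed setup: c2 holds the grouped-GEMM outputs and is pre-initialized to
# zeros, so any c_map entry that points at a row no expert wrote gathers zeros.
hidden = 4
num_written = 5                                  # rows actually produced by the experts
c2 = torch.zeros(num_written + 1, hidden)
c2[:num_written] = torch.randn(num_written, hidden)

# One invalid slot is mapped past the written rows (index num_written).
c_map = torch.tensor([0, 2, 1, num_written, 3, 4])
out = c2[c_map]                                  # row 3 of `out` is all zeros
```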
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),