
Commit 6cd5e5b

[Misc] Fused MoE Marlin support for GPTQ (vllm-project#8217)
1 parent c7cb5c3 commit 6cd5e5b

File tree

19 files changed: +912, -204 lines

.buildkite/test-pipeline.yaml (+12, -1)

@@ -386,7 +386,18 @@ steps:
   - vllm/
   - tests/weight_loading
   commands:
-    - bash weight_loading/run_model_weight_loading_test.sh
+    - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
+
+- label: Weight Loading Multiple GPU Test - Large Models # optional
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  gpu: a100
+  optional: true
+  source_file_dependencies:
+  - vllm/
+  - tests/weight_loading
+  commands:
+    - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
 
 
 ##### multi gpus test #####

csrc/moe/marlin_moe_ops.cu (+1, -1)

@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
       moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
       thread_n, sms, max_par, replicate_input, apply_weights);
   return c;
-}
+}

(The removed and added lines are visually identical; only trailing whitespace or the end-of-file newline differs.)

csrc/moe/marlin_moe_ops.h (+1, -1)

@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
     const torch::Tensor& g_idx, const torch::Tensor& perm,
     torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
     bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
-    bool replicate_input, bool apply_weights);
+    bool replicate_input, bool apply_weights);

(As above, a whitespace/end-of-file-newline-only change.)

csrc/moe/torch_bindings.cpp (-1)

@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
       "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
       "bool replicate_input, bool apply_weights) -> Tensor");
-
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }

tests/kernels/test_moe.py (+217, -4)

@@ -2,14 +2,22 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+from typing import List
+
 import pytest
 import torch
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    fused_marlin_moe, single_marlin_moe)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
+    marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
+from vllm.scalar_type import scalar_types
 
 
 def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
+def torch_moe_single(a, w, score, topk):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    _, topk_ids = torch.topk(score, topk)
+    topk_ids = topk_ids.view(-1)
+    for i in range(w.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = a[mask] @ w[i].transpose(0, 1)
+    return (out.view(B, -1, w.shape[1])).sum(dim=1)
+
+
 @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
     topk: int,
     dtype: torch.dtype,
 ):
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
-    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
                                vllm_states,
                                rtol=mixtral_moe_tol[dtype],
                                atol=mixtral_moe_tol[dtype])
+
+
+def stack_and_dev(tensors: List[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref))
+
+
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+):
+    torch.manual_seed(7)
+
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+
+    quant_type = scalar_types.uint4b8
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    for i in range(w2.shape[0]):
+        w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+            w1[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref1_l.append(w_ref1)
+        qweight1_l.append(qweight1)
+        scales1_l.append(scales1)
+        g_idx1_l.append(g_idx1)
+        sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l)
+    sort_indices1 = stack_and_dev(sort_indices1_l)
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        test_perm = torch.randperm(n)
+        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+            w2[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref2_l.append(w_ref2)
+        qweight2_l.append(qweight2)
+        scales2_l.append(scales2)
+        g_idx2_l.append(g_idx2)
+        sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l)
+    sort_indices2 = stack_and_dev(sort_indices2_l)
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    topk_weights, topk_ids = fused_topk(a, score, topk, False)
+
+    triton_output = fused_moe(
+        a,
+        w_ref1.transpose(1, 2).contiguous(),
+        w_ref2.transpose(1, 2).contiguous(),
+        score,
+        topk,
+        renormalize=False,
+    )
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        score,
+        g_idx1,
+        g_idx2,
+        sort_indices1,
+        sort_indices2,
+        topk_weights,
+        topk_ids,
+        w1_scale=scales1,
+        w2_scale=scales2,
+    )
+
+    assert compute_max_diff(marlin_output, triton_output) < 4e-2
+
+
+@pytest.mark.skip("This test is here for the sake of debugging, "
+                  "don't run it in automated tests.")
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+def test_marlin_moe_mmm(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+):
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size == k:
+            return
+
+    quant_type = scalar_types.uint4b8
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
+
+    w_ref_l = []
+    qweights_l = []
+    scales_l = []
+    g_idx_l = []
+    sort_indices_l = []
+
+    for i in range(w.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
+            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
+        w_ref_l.append(w_ref)
+        qweights_l.append(qweight)
+        scales_l.append(scales)
+        g_idx_l.append(g_idx)
+        sort_indices_l.append(sort_indices)
+
+    w_ref = stack_and_dev(w_ref_l)
+    qweight = stack_and_dev(qweights_l).contiguous()
+    scales = stack_and_dev(scales_l)
+    g_idx = stack_and_dev(g_idx_l)
+    sort_indices = stack_and_dev(sort_indices_l)
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    marlin_output = single_marlin_moe(a,
+                                      qweight,
+                                      scales,
+                                      score,
+                                      g_idx,
+                                      sort_indices,
+                                      topk,
+                                      renormalize=False)
+    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
+
+    assert compute_max_diff(marlin_output, torch_output) < 1e-2
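
For orientation, below is a minimal, hedged usage sketch (not part of the commit) that drives the new fused_marlin_moe entry point the same way the test above does. Only the imports and the fused_marlin_moe / fused_topk call signatures are taken from the test; the shapes, the group size, and the helper marlin_quantize_experts are illustrative assumptions.

# Hedged sketch only: mirrors tests/kernels/test_moe.py above; the helper
# marlin_quantize_experts and the chosen sizes are illustrative, not vllm API.
import torch

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.scalar_type import scalar_types

m, n, k, e, topk, group_size = 64, 128, 128, 8, 2, -1
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10  # gate/up proj
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10  # down proj


def marlin_quantize_experts(w, perm_size):
    # Quantize each expert to 4-bit Marlin format and stack the per-expert
    # tensors along dim 0, as the test above does with stack_and_dev.
    q_l, s_l, g_l, p_l = [], [], [], []
    for i in range(w.shape[0]):
        _, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
            w[i].transpose(1, 0), scalar_types.uint4b8, group_size, False,
            torch.randperm(perm_size))
        q_l.append(qweight)
        s_l.append(scales)
        g_l.append(g_idx)
        p_l.append(sort_indices)
    return (torch.stack(q_l, dim=0).contiguous(), torch.stack(s_l, dim=0),
            torch.stack(g_l, dim=0), torch.stack(p_l, dim=0))


qweight1, scales1, g_idx1, sort_indices1 = marlin_quantize_experts(w1, k)
qweight2, scales2, g_idx2, sort_indices2 = marlin_quantize_experts(w2, n)

score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)

# Fused GPTQ-Marlin MoE matmul over the two expert projections.
out = fused_marlin_moe(a, qweight1, qweight2, score, g_idx1, g_idx2,
                       sort_indices1, sort_indices2, topk_weights, topk_ids,
                       w1_scale=scales1, w2_scale=scales2)
assert out.shape == (m, k)

The sketch only asserts the output shape; the test above additionally compares the result against the Triton fused_moe path within a 4e-2 relative-difference bound.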

tests/weight_loading/models-large.txt (+3)

@@ -0,0 +1,3 @@
+compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
+compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
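
For context, a hedged sketch (not part of the commit) of what the new gptq_marlin entry exercises end to end: loading the GPTQ Mixtral checkpoint through vllm's offline LLM API. The model name comes from the list above; the prompt, sampling parameters, and two-way tensor parallelism (mirroring the num_gpus: 2 A100 stage added to the pipeline) are illustrative.

from vllm import LLM, SamplingParams

# Assumption: a multi-GPU node, matching the optional 2-GPU A100 CI stage.
llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ",
          quantization="gptq_marlin",
          tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)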

tests/weight_loading/models.txt (-2)

@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
 compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
 compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
-compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
-compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main

vllm/model_executor/layers/fused_moe/__init__.py (+10, -4)

@@ -2,16 +2,22 @@
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON
 
-__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
+__all__ = [
+    "FusedMoE",
+    "FusedMoEMethodBase",
+    "FusedMoeWeightScaleSupported",
+]
 
 if HAS_TRITON:
-
+    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+        fused_marlin_moe, single_marlin_moe)
     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        fused_experts, fused_marlin_moe, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
 
     __all__ += [
         "fused_marlin_moe",
+        "single_marlin_moe",
         "fused_moe",
         "fused_topk",
         "fused_experts",