Run `pytest tests/kernels/test_moe.py`.
"""
+from typing import List
+
import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    fused_marlin_moe, single_marlin_moe)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
+    marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
+from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


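+# Plain-PyTorch reference for a MoE layer with a single weight matrix per
+# expert; used below to check single_marlin_moe.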
+def torch_moe_single(a, w, score, topk):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    _, topk_ids = torch.topk(score, topk)
+    topk_ids = topk_ids.view(-1)
+    for i in range(w.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = a[mask] @ w[i].transpose(0, 1)
+    return (out.view(B, -1, w.shape[1])).sum(dim=1)
+
+
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
    topk: int,
    dtype: torch.dtype,
):
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

-    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
        vllm_states,
        rtol=mixtral_moe_tol[dtype],
        atol=mixtral_moe_tol[dtype])
+
+
+def stack_and_dev(tensors: List[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
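+# Mean relative error between output and reference (not a true max,
+# despite the name).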
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref))
+
+
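+# Compare the int4 fused Marlin MoE kernel against the fp16 Triton fused_moe
+# kernel run on the dequantized reference weights.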
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+):
+    torch.manual_seed(7)
+
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+
+    quant_type = scalar_types.uint4b8
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    # Make every expert's down-projection identity-like (likely to keep the
+    # comparison numerically well-behaved).
+    for i in range(w2.shape[0]):
+        w2[i] = torch.eye(k, n, device="cuda", dtype=dtype)
+
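+    # Quantize each expert of w1 separately; marlin_quantize also returns a
+    # dequantized reference copy (w_ref1) used for the fp16 baseline.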
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+            w1[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref1_l.append(w_ref1)
+        qweight1_l.append(qweight1)
+        scales1_l.append(scales1)
+        g_idx1_l.append(g_idx1)
+        sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l)
+    sort_indices1 = stack_and_dev(sort_indices1_l)
+
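+    # Repeat the per-expert Marlin quantization for w2.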
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        test_perm = torch.randperm(n)
+        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+            w2[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref2_l.append(w_ref2)
+        qweight2_l.append(qweight2)
+        scales2_l.append(scales2)
+        g_idx2_l.append(g_idx2)
+        sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l)
+    sort_indices2 = stack_and_dev(sort_indices2_l)
+
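+    # fused_moe derives its routing internally from the same scores, so both
+    # kernels see matching expert assignments.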
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    topk_weights, topk_ids = fused_topk(a, score, topk, False)
+
+    triton_output = fused_moe(
+        a,
+        w_ref1.transpose(1, 2).contiguous(),
+        w_ref2.transpose(1, 2).contiguous(),
+        score,
+        topk,
+        renormalize=False,
+    )
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        score,
+        g_idx1,
+        g_idx2,
+        sort_indices1,
+        sort_indices2,
+        topk_weights,
+        topk_ids,
+        w1_scale=scales1,
+        w2_scale=scales2,
+    )
+
+    assert compute_max_diff(marlin_output, triton_output) < 4e-2
+
+
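+# Sanity check for the single Marlin MoE matmul (single_marlin_moe) against
+# the torch_moe_single reference; kept for manual debugging only.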
+@pytest.mark.skip("This test is here for the sake of debugging; "
+                  "don't run it in automated tests.")
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+def test_marlin_moe_mmm(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+):
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size == k:
+            return
+
+    quant_type = scalar_types.uint4b8
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
+
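+    # Marlin-quantize each expert, keeping the dequantized copy (w_ref) as
+    # the reference weight.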
+    w_ref_l = []
+    qweights_l = []
+    scales_l = []
+    g_idx_l = []
+    sort_indices_l = []
+
+    for i in range(w.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
+            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
+        w_ref_l.append(w_ref)
+        qweights_l.append(qweight)
+        scales_l.append(scales)
+        g_idx_l.append(g_idx)
+        sort_indices_l.append(sort_indices)
+
+    w_ref = stack_and_dev(w_ref_l)
+    qweight = stack_and_dev(qweights_l).contiguous()
+    scales = stack_and_dev(scales_l)
+    g_idx = stack_and_dev(g_idx_l)
+    sort_indices = stack_and_dev(sort_indices_l)
+
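+    # Compare the quantized single-matmul MoE against the plain PyTorch
+    # reference on the dequantized weights.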
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    marlin_output = single_marlin_moe(a,
+                                      qweight,
+                                      scales,
+                                      score,
+                                      g_idx,
+                                      sort_indices,
+                                      topk,
+                                      renormalize=False)
+    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
+
+    assert compute_max_diff(marlin_output, torch_output) < 1e-2