
Commit 8042ac3

moe quant with dedicated kernels [wip]
Summary: Extends torchao MoE quantization support with more performant dedicated kernels. This PR supports both torch._scaled_grouped_mm and fbgemm's grouped_gemm_fp8_rowwise, though grouped_gemm_fp8_rowwise seems to be a bit buggy (a clear repro still needs to be made).

Todo: run benchmarks, debug the fbgemm kernel, add unit tests.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent f0f1f6c commit 8042ac3
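
For orientation, here is a minimal usage sketch of the new fused path, condensed from the tests added in this commit. It is illustrative only: the import locations for MoEQuantConfig, cond_ffn_filter, and the quantization configs are assumed from the test's import block and the usual torchao layout, not confirmed by this diff.

import torch
import torch.nn.functional as F
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_
# assumed locations for the prototype helpers used by the tests
from torchao.prototype.moe_quant.utils import MoEQuantConfig, cond_ffn_filter
from torchao.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable
from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op

# toy MoE block: hidden_dim=512, expert_dim=256, 8 experts, top_k=2 (the tests' DEFAULT_PARAMS)
model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(torch.bfloat16).to("cuda")
x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")

# quantize expert weights to fp8 with rowwise scales and dynamic activation quantization
config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
quantize_(model, config, cond_ffn_filter)

# route tokens the same way the module's forward does, then call the fused op directly
scores = F.softmax(model.router(x), dim=-1)
scores, expert_indices = torch.topk(scores, model.top_k, dim=-1)
scores /= scores.sum(dim=-1, keepdim=True).to(x.dtype)

experts = model.experts
out = fp8_dq_moe_op(x, experts.w1, experts.w2, experts.w3, expert_indices, scores)  # torch._scaled_grouped_mm path
out_fb = fp8_dq_moe_op(x, experts.w1, experts.w2, experts.w3, expert_indices, scores, use_fbgemm_kernel=True)  # fbgemm path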

File tree: 4 files changed (+290, -9 lines)


test/quantization/test_moe_quant.py

Lines changed: 112 additions & 1 deletion
@@ -25,20 +25,25 @@
     Int8WeightOnlyConfig,
     LinearActivationQuantizedTensor,
     quantize_,
+    PerRow,
+    PerTensor,
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_sm_at_least_90,
 )
+from torchao.quantization.utils import compute_error
 
 if torch.version.hip is not None:
     pytest.skip(
         "ROCm support for MoE quantization is under development",
         allow_module_level=True,
     )
+from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op
 
+torch.manual_seed(0)
 
 class TestMoEQuantCompile(unittest.TestCase):
     DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k
@@ -68,7 +73,6 @@ def _test_impl_moe_quant(
             .to(device)
         )
         input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
-
         out = model(input)
 
         quantize_(model, config, cond_ffn_filter)
@@ -363,6 +367,113 @@ def test_fp8dq_base(self, name, num_tokens, fullgraph):
             fullgraph=fullgraph,
         )
 
+class TestFusedMoEQuant(unittest.TestCase):
+    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k
+
+    @parameterized.expand(
+        [
+            ("multiple_tokens", 8),
+        ]
+    )
+    def test_pytorch_scaled_grouped_gemm(self, name, num_tokens):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+
+        model_params = self.DEFAULT_PARAMS
+
+        input_shape = (num_tokens, model_params[0])
+        input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
+
+        model = (
+            MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
+        )
+        model = model.to(dtype).to(device)
+
+        out_orig = model(input)
+
+        quantize_(model, config, cond_ffn_filter)
+
+        w1 = model.experts.w1
+        w2 = model.experts.w2
+        w3 = model.experts.w3
+
+        router = model.router
+        top_k = model.top_k
+
+        # preprocess
+        scores = router(input)  # [T, E]
+        scores = torch.nn.functional.softmax(scores, dim=-1)
+        scores, expert_indices = torch.topk(
+            scores, top_k, dim=-1
+        )  # [T, A], [T, A]
+        scores /= scores.sum(dim=-1, keepdim=True).to(input.dtype)  # [T, A]
+
+        out = fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores)
+        out2 = model(input)
+
+        self.assertTrue(compute_error(out_orig, out) > 20)
+        self.assertTrue(compute_error(out_orig, out2) > 20)
+
+class TestFusedMoEQuantFBGEMM(unittest.TestCase):
+    DEFAULT_PARAMS = (512, 256, 8, 2)  # hidden_dim, expert_dim, num_experts, top_k
+
+    @parameterized.expand(
+        [
+            ("multiple_tokens", 8),
+        ]
+    )
+    def test_fbgemm_scaled_grouped_gemm(self, name, num_tokens):
+        if not torch.cuda.is_available():
+            self.skipTest("Need CUDA available")
+        if not is_sm_at_least_90():
+            self.skipTest("Requires CUDA capability >= 9.0")
+
+        device = "cuda"
+        dtype = torch.bfloat16
+
+        config = MoEQuantConfig(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+
+        model_params = self.DEFAULT_PARAMS
+
+        input_shape = (num_tokens, model_params[0])
+        input = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
+
+        model = (
+            MOEFeedForwardAOQuantizable(*model_params, empty_init=False)
+        )
+        model = model.to(dtype).to(device)
+
+        out_orig = model(input)
+
+        quantize_(model, config, cond_ffn_filter)
+
+        w1 = model.experts.w1
+        w2 = model.experts.w2
+        w3 = model.experts.w3
+
+        router = model.router
+        top_k = model.top_k
+
+        # preprocess
+        scores = router(input)  # [T, E]
+        scores = torch.nn.functional.softmax(scores, dim=-1)
+        scores, expert_indices = torch.topk(
+            scores, top_k, dim=-1
+        )  # [T, A], [T, A]
+        scores /= scores.sum(dim=-1, keepdim=True).to(input.dtype)  # [T, A]
+
+        out = fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores, use_fbgemm_kernel=True)
+        out2 = model(input)
+
+        self.assertTrue(compute_error(out_orig, out) > 20)
+        self.assertTrue(compute_error(out_orig, out2) > 20)
 
 if __name__ == "__main__":
     unittest.main()
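
Both new tests accept the fused output when compute_error(out_orig, out) > 20. compute_error is torchao's "basic SQNR" helper (it sits right next to the utils.py hunk at the bottom of this commit); assuming it follows the standard SQNR-in-dB definition, the 20 dB threshold corresponds roughly to a 10% relative error budget, as this sketch illustrates:

import torch

def sqnr_db(ref, test):
    # signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - test||)
    # (assumed to match torchao.quantization.utils.compute_error)
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))

ref = torch.randn(8, 512)
test = ref + 0.01 * torch.randn_like(ref)  # ~1% relative error
print(sqnr_db(ref, test))  # ~40 dB; the tests above require > 20 dB (~10% error)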
torchao/prototype/moe_quant/kernels.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+import torch
+import torch.nn.functional as F
+from torchao.quantization.utils import _torchtitan_available, _fbgemm_available
+
+grouped_gemm_fp8_rowwise = None
+if _fbgemm_available:
+    try:
+        from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm_fp8_rowwise
+    except ImportError:
+        pass
+
+__all__ = ["fp8_dq_moe_op",
+    "manual_pad",
+    "torchtitan_pad",
+]
+
+
+def fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores, fast_accum=True, use_fbgemm_kernel=False):
+    # parameters
+    orig_in_shape = input.shape
+    input = input.reshape(-1, orig_in_shape[-1])
+    num_tokens, dim = input.shape
+    num_experts, expert_dim, _ = w1.shape
+    scores = scores.view(-1, scores.shape[-1])
+    top_k = scores.shape[-1]
+    total_activations = num_tokens * top_k
+
+    # preprocess indices
+    expert_indices = expert_indices.view(-1)
+    activation_shuffle = expert_indices.argsort(stable=True)
+    token_shuffle = activation_shuffle.div(top_k).floor().to(torch.int64)
+    num_tokens_per_expert = torch.histc(expert_indices, bins=num_experts, min=0, max=num_experts)
+
+    # padding
+    alignment = 16
+    if _torchtitan_available:
+        num_ranks = 1
+        padded_indices, m_offsets = torchtitan_pad(num_tokens_per_expert, alignment, num_ranks)
+    else:
+        padded_indices, m_offsets = manual_pad(num_tokens_per_expert, alignment)
+
+    pad_len = padded_indices.shape[0]
+    valid_values = padded_indices >= 0
+
+    # get data for weights
+    w1_fp8 = w1.original_weight_tensor.tensor_impl.float8_data
+    w1_scale = w1.original_weight_tensor.tensor_impl.scale.squeeze()
+    w1_qfunc = w1.input_quant_func
+    w1_quant_kwargs = w1.quant_kwargs
+
+    w3_fp8 = w3.original_weight_tensor.tensor_impl.float8_data
+    w3_scale = w3.original_weight_tensor.tensor_impl.scale.squeeze()
+
+    w2_fp8 = w2.original_weight_tensor.tensor_impl.float8_data
+    w2_scale = w2.original_weight_tensor.tensor_impl.scale.squeeze()
+    w2_qfunc = w2.input_quant_func
+    w2_quant_kwargs = w2.quant_kwargs
+
+
+    # quantize then shuffle input
+    q_input = w1_qfunc(input, **w1_quant_kwargs)
+    q_input_data = q_input.tensor_impl.float8_data
+    q_input_scale = q_input.tensor_impl.scale.squeeze()
+    input_fp8 = torch.zeros((pad_len, q_input_data.shape[-1]), dtype=q_input_data.dtype, device=q_input_data.device)
+    input_scale = torch.zeros(pad_len, dtype=q_input_scale.dtype, device=q_input_scale.device)
+    input_fp8[valid_values] = q_input_data[token_shuffle]
+    input_scale[valid_values] = q_input_scale[token_shuffle] if q_input_scale.numel() > 1 else q_input_scale
+
+    if use_fbgemm_kernel:
+        assert grouped_gemm_fp8_rowwise is not None, "fbgemm kernel requires fbgemm-gpu-genai to be installed: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/README.md"
+        y1 = grouped_gemm_fp8_rowwise(input_fp8, w1_fp8.reshape(-1, w1_fp8.shape[-1]), m_offsets, input_scale, w1_scale.reshape(-1), use_fast_accum=True)
+        y3 = grouped_gemm_fp8_rowwise(input_fp8, w3_fp8.reshape(-1, w3_fp8.shape[-1]), m_offsets, input_scale, w3_scale.reshape(-1), use_fast_accum=True)
+
+        y = F.silu(y1) * y3
+
+        y_q = w2_qfunc(y, **w2_quant_kwargs)
+
+        y_fp8 = y_q.tensor_impl.float8_data
+        y_scale = y_q.tensor_impl.scale.squeeze()
+        out = grouped_gemm_fp8_rowwise(y_fp8, w2_fp8.view(-1, w1_fp8.shape[-1]), m_offsets, y_scale, w2_scale.view(-1), use_fast_accum=fast_accum)
+        # unpad and combine output with weights
+        out = out[valid_values]
+        sorted_scores = scores.reshape(-1, 1)[activation_shuffle]
+        out = out * sorted_scores
+
+        # sum weighted outputs
+        final_out = torch.zeros_like(input)
+        final_out = final_out.scatter_add(
+            dim=0,
+            index=token_shuffle.unsqueeze(-1).expand(total_activations, dim).to(torch.int64),
+            src=out,
+        )
+        final_out = final_out.reshape(orig_in_shape)
+        return final_out
+
+    else:
+        y1 = torch._scaled_grouped_mm(input_fp8, w1_fp8.transpose(-2, -1), input_scale, w1_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        y3 = torch._scaled_grouped_mm(input_fp8, w3_fp8.transpose(-2, -1), input_scale, w3_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        y = F.silu(y1) * y3
+        y_q = w2_qfunc(y, **w2_quant_kwargs)
+
+        y_fp8 = y_q.tensor_impl.float8_data
+        y_scale = y_q.tensor_impl.scale.squeeze()
+        out = torch._scaled_grouped_mm(y_fp8, w2_fp8.transpose(-2, -1), y_scale, w2_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+
+        # unpad and combine output with weights
+        out = out[valid_values]
+        sorted_scores = scores.reshape(-1, 1)[activation_shuffle]
+        out = out * sorted_scores
+
+        # sum weighted outputs
+        final_out = torch.zeros_like(input)
+        final_out = final_out.scatter_add(
+            dim=0,
+            index=token_shuffle.unsqueeze(-1).expand(total_activations, dim).to(torch.int64),
+            src=out,
+        )
+        final_out = final_out.reshape(orig_in_shape)
+        return final_out
+
+def torchtitan_pad(num_tokens_per_expert, alignment, num_ranks):
+    from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+    num_experts = num_tokens_per_expert.shape[0]
+
+    # pad to nearest multiple of alignment that's greater than 0
+    padded_sizes = (((num_tokens_per_expert + (num_tokens_per_expert == 0)) / alignment).ceil() * alignment)
+    pad_len = int(padded_sizes.sum().item())
+
+    padded_indices, _, m_offsets = generate_permute_indices(
+        num_tokens_per_expert,
+        num_experts,
+        num_ranks,
+        pad_len,
+        alignment,
+        use_cpu=False,
+    )
+    return padded_indices, m_offsets
+
+def manual_pad(num_tokens_per_expert, alignment):
+    num_experts = num_tokens_per_expert.shape[0]
+
+    padded_sizes = (((num_tokens_per_expert + (num_tokens_per_expert == 0)) / alignment).ceil() * alignment)
+    pad_len = int(padded_sizes.sum().item())
+
+    padded_indices = torch.zeros(pad_len, dtype=torch.int32, device=num_tokens_per_expert.device) - 1
+    start_tok_index = 0
+    start_pad_index = 0
+    for i in range(num_experts):
+        end_tok_index = int(start_tok_index + num_tokens_per_expert[i].item())
+        end_pad_index = int(start_pad_index + num_tokens_per_expert[i].item())
+        padded_indices[start_pad_index:end_pad_index] = torch.arange(start_tok_index, end_tok_index, dtype=torch.int32, device=num_tokens_per_expert.device)
+        start_tok_index = end_tok_index
+        start_pad_index = start_pad_index + int(padded_sizes[i].item())
+    m_offsets = padded_sizes.cumsum(0).to(torch.int32)
+    return padded_indices, m_offsets
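
The padding above exists because both grouped GEMMs take one flattened activation matrix whose rows are grouped by expert, with group boundaries given by m_offsets and each group padded up to the 16-row alignment. A small worked example of manual_pad (alignment lowered to 4 for readability; the -1 entries mark pad rows that valid_values masks back out):

import torch
from torchao.prototype.moe_quant.kernels import manual_pad

# three experts receive 3, 0, and 5 activations respectively
num_tokens_per_expert = torch.tensor([3, 0, 5])
padded_indices, m_offsets = manual_pad(num_tokens_per_expert, alignment=4)

# padded group sizes are [4, 4, 8] (each count rounded up to a positive multiple of 4), so:
# padded_indices -> [0, 1, 2, -1,  -1, -1, -1, -1,  3, 4, 5, 6, 7, -1, -1, -1]
# m_offsets      -> [4, 8, 16]  (cumulative group ends, passed as offs= to torch._scaled_grouped_mm)
valid_values = padded_indices >= 0  # mask used by fp8_dq_moe_op to scatter real rows in and strip pad rows out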

torchao/prototype/moe_quant/quantizable_moe_modules.py

Lines changed: 20 additions & 8 deletions
@@ -1,9 +1,11 @@
 import torch
+import torchao
 import torch.nn.functional as F
 from torch import Tensor, nn
 
 from torchao.prototype.moe_quant.utils import FakeExtraDimTensor
-
+from torchao.quantization.utils import _torchtitan_available
+from torchao.prototype.moe_quant.kernels import fp8_dq_moe_op
 
 class MOEFeedForwardAOQuantizable(nn.Module):
     def __init__(
@@ -28,7 +30,7 @@ def __init__(
         self.return_scores = return_scores
 
     def forward(self, x: Tensor) -> Tensor:
-        batch_size = x.shape[0]
+        shape_no_dim = x.shape[:-1]
         x = x.view(-1, self.hidden_dim)  # x: [T, D]
         scores = self.router(x)  # [T, E]
         scores = F.softmax(scores, dim=-1)
@@ -40,11 +42,12 @@
         out = self.experts(x, expert_indices, scores, self.top_k)
         if self.shared_expert:
             out += self.shared_expert(x)
-
+        out = out.reshape(*shape_no_dim, -1)
+
         if self.return_scores:
-            return out.reshape(batch_size, -1, self.hidden_dim), scores
+            return out, scores
         else:
-            return out.reshape(batch_size, -1, self.hidden_dim)
+            return out
 
 
 class ConditionalFeedForwardAOQuantizable(nn.Module):
@@ -79,7 +82,7 @@ def forward(
         self,
         x: Tensor,  # T, D
         expert_indices: Tensor,  # T, A
-        expert_weights: Tensor,  # T, A
+        scores: Tensor,  # T, A
         top_k: int,
     ) -> Tensor:
         num_tokens, _hidden_dim = x.shape
@@ -105,11 +108,20 @@
 
             # combine outputs
             final_out = (
-                (torch.cat(outs, dim=0) * expert_weights.view(-1, 1))
+                (torch.cat(outs, dim=0) * scores.view(-1, 1))
                 .sum(dim=0)
                 .reshape(x.shape)
             )
             return final_out
+
+        # fp8 dq moe
+        elif (
+            isinstance(self.w1, torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor)
+            and isinstance(self.w1.original_weight_tensor._layout, torchao.dtypes.floatx.float8_layout.Float8Layout)
+        ):
+            final_out = fp8_dq_moe_op(x, self.w1, self.w2, self.w3, expert_indices, scores)
+            return final_out
+
         else:
             expert_list = [x for x in range(self.num_experts)]
 
@@ -172,7 +184,7 @@ def group_tokens_by_expert(
 
         # weigh outputs
         ordered_outs = torch.cat(outs, dim=0)  # [T*A, D]
-        ordered_token_activation_weights = expert_weights.view(-1, 1)[
+        ordered_token_activation_weights = scores.view(-1, 1)[
            ordered_token_activations
         ].view(-1, 1)  # [T*A, 1]
         weighted_ordered_outs = (
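
The forward() change in this file replaces the hard-coded [batch, -1, hidden_dim] reshape with shape_no_dim = x.shape[:-1], so the module now returns whatever leading shape it was given (which is what lets the new tests feed it plain 2D token batches). A quick sketch of the intent, assuming the constructor takes the same positional arguments as DEFAULT_PARAMS in the tests:

import torch
from torchao.prototype.moe_quant.quantizable_moe_modules import MOEFeedForwardAOQuantizable

moe = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False)

# 2D input (tokens, hidden) now comes back as (tokens, hidden) instead of being forced to 3D
assert moe(torch.randn(8, 512)).shape == (8, 512)

# 3D input (batch, seq, hidden) has its leading dims restored via out.reshape(*shape_no_dim, -1)
assert moe(torch.randn(2, 4, 512)).shape == (2, 4, 512)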

torchao/quantization/utils.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@
 
 _lm_eval_available = importlib.util.find_spec("lm_eval") is not None
 
+_torchtitan_available = importlib.util.find_spec("torchtitan") is not None
+
+_fbgemm_available = importlib.util.find_spec("fbgemm_gpu") is not None
 
 # basic SQNR
 def compute_error(x, y):
