Skip to content

[Inductor] Support scaled mm on inductor #2411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions test/float8/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,5 +392,59 @@ def test_dynamic_scale_numeric_parity(
assert torch.equal(float8_eager._data, float8_compile._data)


@pytest.mark.parametrize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the training float8 test file, float8 inference is using https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_float.py

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the training float8 test file, float8 inference is using https://github.com/pytorch/ao/blob/main/test/dtypes/test_affine_quantized_float.py

Ok. I change the ut path on last pr #2379

"float8_dtype",
[
torch.float8_e4m3fn,
torch.float8_e5m2,
],
)
@pytest.mark.parametrize(
"hp_dtype",
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
def test_quantize_dequantize_fp8_inductor(float8_dtype, hp_dtype):
quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
input = torch.randn(10, 10)
with torch.no_grad():
torch._dynamo.reset()
expected_scale = torch.tensor(2.0)
expected_quantized = quantize_affine_float8(
input,
expected_scale,
float8_dtype=float8_dtype,
)
expected_dequantized = dequantize_affine_float8(
expected_quantized,
expected_scale,
output_dtype=hp_dtype,
)
test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
torch.compile(quantize_affine_float8),
input,
expected_scale,
float8_dtype=float8_dtype,
)
torch.testing.FileCheck().check(
"torch.ops.torchao.quantize_affine_float8.default"
).run(code_q)
test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
torch.compile(dequantize_affine_float8),
test_q,
expected_scale,
hp_dtype,
)
torch.testing.FileCheck().check(
"torch.ops.torchao.dequantize_affine_float8.default"
).run(code_dq)
torch.testing.assert_close(expected_quantized, test_q)
torch.testing.assert_close(expected_dequantized, test_dq)


if __name__ == "__main__":
pytest.main([__file__])
64 changes: 64 additions & 0 deletions test/quantization/pt2e/test_x86inductor_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,70 @@ def matcher_check_fn():
if test_for_pointwise_binary:
self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)

@skipIfNoONEDNN
@parametrize("has_bias", [True, False])
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("input_dim_exceeds_two", [True, False])
def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two):
class FP8QDQLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.qtype = torch.float8_e4m3fn
self.weight = torch.randn((out_features, in_features)).to(self.qtype)
self.weight_scale = 2.0
self.scale = 2.0
self.bias = None
if has_bias:
self.bias = torch.randn((out_features,)).to(dtype)

def forward(self, input):
weight = torch.ops.torchao.dequantize_affine_float8(
tensor=self.weight.data,
scale=torch.tensor(self.weight_scale),
output_dtype=torch.float,
)
if dtype != torch.float:
weight = weight.to(dtype)

q_input = torch.ops.torchao.quantize_affine_float8(
tensor=input,
scale=torch.tensor(self.scale),
float8_dtype=self.qtype,
)
dq_input = torch.ops.torchao.dequantize_affine_float8(
tensor=q_input,
scale=torch.tensor(self.scale),
output_dtype=torch.float,
)
if dtype != torch.float:
dq_input = dq_input.to(dtype)

out = torch.nn.functional.linear(dq_input, weight, self.bias)
return out

class Mod(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.l0 = FP8QDQLinear(in_features, out_features)

def forward(self, x):
y = self.l0(x)
return y

M1, M2, N, K = 2, 3, 13, 16
M = M1 * M2
mod = Mod(N, K)
if input_dim_exceeds_two:
v = torch.randn(M1, M2, N)
else:
v = torch.randn(M, N)
v = v.to(dtype)

def matcher_check_fn():
self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], 1)

self._test_common(mod, (v,), matcher_check_fn)


@dynamo_config.patch(
{
Expand Down
237 changes: 237 additions & 0 deletions torchao/quantization/pt2e/inductor_passes/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -2740,6 +2740,241 @@ def _register_qlinear_binary_fusion():
)


def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two):
# + - - - - | - - - - - - | - - - - +
# | dq_per_tensor dq_per_tensor |
# | | | |
# | OPT(to_bf16) OPT(to_bf16) |
# | | | |
# | OPT(reshape) permute |
# | \ / |
# | addmm/mm |
# | | |
# | OPT(quant_per_tensor) |
# | | |
# | OPT(reshape) |
assert dtype in [torch.float32, torch.bfloat16]
dequant_wgt_pattern = CallFunction(
torch.ops.torchao.dequantize_affine_float8.default,
KeywordArg("q_weight"),
KeywordArg("w_scale"),
output_dtype=KeywordArg("w_dtype"),
)
t_pattern = CallFunction(
aten.permute.default,
_may_generate_pattern_with_dtype_convert(
dequant_wgt_pattern,
KeywordArg("autocast_wgt_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("permute_axes"),
)
dequantize_per_tensor_activation_pattern = CallFunction(
torch.ops.torchao.dequantize_affine_float8.default,
KeywordArg("x"),
KeywordArg("x_scale"),
output_dtype=KeywordArg("x_dq_dtype"),
)

dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape(
CallFunction(
aten.addmm.default,
KeywordArg("b"),
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("act_reshape_size"),
input_dim_exceeds_two,
),
t_pattern,
),
KeywordArg("output_reshape_size"),
input_dim_exceeds_two,
)
dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
CallFunction(
aten.mm.default,
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("act_reshape_size"),
input_dim_exceeds_two,
),
t_pattern,
),
KeywordArg("output_reshape_size"),
input_dim_exceeds_two,
)
return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern


def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two):
def _inner(match):
input_contiguous = True
# Check dequant pattern has only 1 user.
(
linear_node,
_,
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)

input_index = 1 if linear_node.target is aten.addmm.default else 0
assert dtype in [torch.float32, torch.bfloat16]
(
dequant_node,
_,
_,
_,
) = _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
)
assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default

# only support float8_e4m3 input
if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn:
return False

if len(list(dequant_node.users)) != 1:
# Ensure the dequant pattern only has 1 user
# since we will delete the dequant pattern here
return False

return True

return _inner


def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two),
pass_number=0,
)
def scaled_mm_fusion(match: Match, *args, **kwargs):
input_contiguous = True
assert dtype in [torch.float32, torch.bfloat16]
(
linear_node,
output_reshape_node,
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
input_index = 1 if linear_node.target is aten.addmm.default else 0
weight_index = input_index + 1

(
dequant_node,
act_reshape_node,
activation_to_bf16_node,
act_expand_node,
) = _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
)

if input_dim_exceeds_two and not input_contiguous:
wgt_expand_node = linear_node.args[weight_index]
assert wgt_expand_node.target is aten.expand.default
t_node = wgt_expand_node.args[0]
else:
t_node = linear_node.args[weight_index]

if dtype == torch.float32:
dequant_per_tensor = t_node.args[0]
else:
weight_to_bf16_node = t_node.args[0]
dequant_per_tensor = weight_to_bf16_node.args[0]
assert (
dequant_per_tensor.target
is torch.ops.torchao.dequantize_affine_float8.default
)

# Activation QParams
qx, x_scale = (
kwargs["x"],
kwargs["x_scale"],
)

# Weight QParams
qw, w_scale = (
kwargs["q_weight"],
kwargs["w_scale"],
)

# Params
bias = kwargs["b"] if "b" in kwargs else None

x_shape = qx.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
# For dynamic shape case, we can't get activation shape ahead of runtime.
x_shape = None
graph = match.graph
with graph.inserting_before(linear_node):
scaled_mm_input_node = qx
if input_dim_exceeds_two:
new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1])
new_act_reshape_node = graph.call_function(
torch.ops.aten.reshape.default, args=new_reshape_args
)
scaled_mm_input_node = new_act_reshape_node
# Insert weight prepack node and the qlinear node
permute_weight_inputs = (
qw,
t_node.args[1],
)
permute_weight_op = torch.ops.aten.permute.default
permute_weight_node = graph.call_function(
permute_weight_op, args=permute_weight_inputs
)
output_scale = torch.tensor(1.0)
new_args: tuple[Any, ...] = (
scaled_mm_input_node,
permute_weight_node,
x_scale,
w_scale,
bias,
output_scale, # output_scale
dtype, # output_dtype
False, # use_fast_accum
)
new_linear_node = graph.call_function(
torch.ops.aten._scaled_mm.default, args=new_args
)

linear_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(linear_node.meta)

graph.erase_node(linear_node)
if input_dim_exceeds_two:
graph.erase_node(act_reshape_node)
if dtype == torch.bfloat16:
graph.erase_node(activation_to_bf16_node)
# Erase the dequant pattern
graph.erase_node(dequant_node)
# Erase the dequant per channel pattern
graph.erase_node(t_node)
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
graph.erase_node(dequant_per_tensor)

counters["inductor"]["scaled_mm_matcher_count"] += 1
counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes)


def _register_scaled_mm():
fp8_linear_weight_prepack_cases = itertools.product(
[torch.float32, torch.bfloat16], [False, True]
)
for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases:
patterns = _generate_dequant_fp8_linear_node_pattern(
dtype, input_dim_exceeds_two
)
for pattern in patterns:
_register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two)


@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
Expand All @@ -2763,6 +2998,8 @@ def _register_quantization_weight_pack_pass():
_register_qlinear_unary_fusion()
_register_qlinear_binary_fusion()

_register_scaled_mm()


def quant_lift_up(module_graph: torch.fx.graph.Graph):
"""
Expand Down
Loading