Skip to content

[float8] Prevent quantize_affine_float8/dequantize_affine_float8 decomposed on inductor #2379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,46 @@ def test_preprocess_scale_3d_reshape(self):
expected_shape = (8, 1) # Flattened (2*2*2, 1)
self.assertEqual(result.shape, expected_shape)

@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
input = torch.randn(10, 10)
with torch.no_grad():
torch._dynamo.reset()
expected_scale = torch.tensor(2.0)
expected_quantized = quantize_affine_float8(
input,
expected_scale,
float8_dtype=float8_dtype,
)
expected_dequantized = dequantize_affine_float8(
expected_quantized,
expected_scale,
output_dtype=hp_dtype,
)
test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
torch.compile(quantize_affine_float8),
input,
expected_scale,
float8_dtype=float8_dtype,
)
torch.testing.FileCheck().check(
"torch.ops.torchao.quantize_affine_float8.default"
).run(code_q)
test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
torch.compile(dequantize_affine_float8),
test_q,
expected_scale,
hp_dtype,
)
torch.testing.FileCheck().check(
"torch.ops.torchao.dequantize_affine_float8.default"
).run(code_dq)
torch.testing.assert_close(expected_quantized, test_q)
torch.testing.assert_close(expected_dequantized, test_dq)


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

Expand Down
20 changes: 20 additions & 0 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,7 @@ def _expand_scale_to_tensor_shape(
return expanded_scale


@_register_custom_op(quant_lib, False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can probably make quantize_affine_float8 and choose_qparams_affine_float8 public as well since it's used in inductor lowering. cc @jainapurva @drisspg

def _quantize_affine_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
Expand All @@ -2290,6 +2291,16 @@ def _quantize_affine_float8(
return fp8_tensor


@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta")
def _quantize_affine_float8_meta(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
return torch.empty_like(tensor, dtype=float8_dtype)


@_register_custom_op(quant_lib, False)
def _dequantize_affine_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
Expand All @@ -2305,3 +2316,12 @@ def _dequantize_affine_float8(

hp_tensor = fp8_tensor * scale_expanded
return hp_tensor.to(output_dtype)


@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta")
def _dequantize_affine_float8_meta(
tensor: torch.Tensor,
scale: torch.Tensor,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
return torch.empty_like(tensor, dtype=output_dtype)
13 changes: 10 additions & 3 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
return n + k - (n % k)


def _register_custom_op(lib):
def _register_custom_op(lib, inductor_decomposed=True):
"""This decorator is used to preserve some high level operators for torch.export.export
while still allow them to be decomposed for inductor path

Expand All @@ -206,6 +206,12 @@ def _the_op_that_needs_to_be_preserved(...)
"""
from torch._inductor.decomposition import register_decomposition

dispatch_key = (
"CompositeImplicitAutograd"
if inductor_decomposed
else "CompositeExplicitAutograd"
)

def decorator(fn):
if TORCH_VERSION_AT_LEAST_2_5:
from torch._library.infer_schema import infer_schema
Expand All @@ -221,11 +227,12 @@ def decorator(fn):
op_name = fn.__name__[1:]
schema = op_name + infer_schema(fn, mutates_args={})
lib.define(schema)
lib.impl(op_name, fn, "CompositeImplicitAutograd")
lib.impl(op_name, fn, dispatch_key)

lib_namespace = lib.ns
op = getattr(getattr(torch.ops, lib_namespace), op_name)
register_decomposition([op])(fn)
if inductor_decomposed:
register_decomposition([op])(fn)
return op
else:
return fn
Expand Down
Loading