
[float8] Prevent quantize_affine_float8/dequantize_affine_float8 decomposed on inductor #2379


Open · wants to merge 5 commits into main

Conversation


@shiyang-weng commented Jun 16, 2025

Fix #2228

What we want to do is enable FP8 quantization in PyTorch. As with INT8 quantization, we need to insert quantize and dequantize (q/dq) ops into the graph.
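For context, here is a minimal sketch of what a per-tensor FP8 q/dq pair computes. This is only an illustration of the semantics, not the torchao implementation; the helper names are made up.

```python
import torch

def quantize_fp8(x: torch.Tensor, scale: torch.Tensor,
                 dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Scale into the FP8 range, clamp to the representable maximum, then cast.
    fp8_max = torch.finfo(dtype).max
    return (x.float() / scale).clamp(-fp8_max, fp8_max).to(dtype)

def dequantize_fp8(x_fp8: torch.Tensor, scale: torch.Tensor,
                   output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Cast back to a wide dtype and undo the scaling.
    return x_fp8.to(output_dtype) * scale

x = torch.randn(4, 8)
scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
x_dq = dequantize_fp8(quantize_fp8(x, scale), scale)  # lossy round trip
```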
However, we ran into problems with these q/dq ops in both PyTorch core and Torchao.

PyTorch core:

The quantize_per_tensor op does not support FP8. We wanted to fix that via pytorch/pytorch#153601, but, as you commented, the op is deprecated.
Torchao:

In the fusion pass in Inductor, we want to match the pattern fp8_weight -> torchao.dequantize_affine_float8 -> fp32_op and fuse it into fp8_weight -> weight_pack -> fp8_op, as we already do for INT8 PT2E quantization. However, the pattern-matching pass is applied after a constant-folding pass in Inductor:
https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/fx_passes/freezing_patterns.py#L69C1-L74C1
After constant_fold(gm), the pattern is folded into fp32_weight -> fp32_op. The original pattern can no longer be matched, and the FP8 semantics are lost because the folded pattern is entirely in fp32.
For INT8, the int8_weight -> quantized_decomposed.dequantize_per_channel -> fp32_op pattern is not folded, because we mark quantized_decomposed.dequantize_per_channel as impure so that the constant folder skips it: https://github.com/pytorch/pytorch/blob/100ec0b34aeff2b948dae33937857d0c86cf1646/torch/_inductor/constant_folding.py#L139C1-L149C1. We cannot do the same for torchao.dequantize_affine_float8 (a simplified sketch of this folding check follows the list), because:
  1. it is an op from Torchao, which is unknown to the constant folder, and
  2. it is decomposed into smaller ops, so there is no single op we could put in that list.
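For illustration only, a heavily simplified sketch of the kind of check involved; the real logic lives in the linked constant_folding.py, and the op set and helper name here are assumptions made for the example.

```python
import torch
import torch.fx

# Registers the torch.ops.quantized_decomposed namespace used below.
import torch.ao.quantization.fx._decomposed  # noqa: F401

# Simplified stand-in for Inductor's "impure" op list: nodes whose target is
# in this set are skipped by constant folding, so int8_weight -> dq -> fp32_op
# survives freezing and can still be pattern-matched afterwards.
_IMPURE_TARGETS = {
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
}

def is_impure(node: torch.fx.Node) -> bool:
    # torchao.dequantize_affine_float8 cannot be added to such a list today:
    # once it is decomposed, no single node.target corresponds to it.
    return node.op == "call_function" and node.target in _IMPURE_TARGETS
```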

We need to prevent the q/dq ops from being decomposed for both FP8 and INT8. Based on the discussion in #2299, this PR first prevents the FP8 q/dq ops from being decomposed.

In this PR, we do the following (a rough registration sketch follows the list):

  1. Register quantize_affine_float8 and dequantize_affine_float8 under torch.ops.torchao.
  2. Align the dispatch key with PyTorch core so that quantize_affine_float8 and dequantize_affine_float8 are not decomposed.
  3. Register meta functions for quantize_affine_float8 and dequantize_affine_float8.
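A generic torch.library sketch of what steps 1-3 amount to. This uses a scratch namespace and an illustrative schema, not the exact torchao registration (the PR goes through torchao's _register_custom_op helper); the key points are that the kernel is registered under CompositeExplicitAutograd rather than CompositeImplicitAutograd, so Inductor keeps the op as a single node, and that a Meta kernel is provided so fake tensors can trace through it.

```python
import torch
from torch.library import Library, impl

# Scratch namespace for illustration; the PR registers the real ops under
# torch.ops.torchao via _register_custom_op(quant_lib, False).
_lib = Library("my_quant", "DEF")
_lib.define(
    "quantize_affine_float8(Tensor input, Tensor scale, ScalarType float8_dtype) -> Tensor"
)

# Real kernel. Registering it under CompositeExplicitAutograd (instead of
# CompositeImplicitAutograd) keeps the op from being decomposed during
# export/Inductor lowering.
@impl(_lib, "quantize_affine_float8", "CompositeExplicitAutograd")
def _quantize_affine_float8(input, scale, float8_dtype):
    fp8_max = torch.finfo(float8_dtype).max
    return (input.float() / scale).clamp(-fp8_max, fp8_max).to(float8_dtype)

# Meta kernel: shape/dtype propagation only, so torch.compile can trace the
# op with fake tensors without running the real implementation.
@impl(_lib, "quantize_affine_float8", "Meta")
def _quantize_affine_float8_meta(input, scale, float8_dtype):
    return torch.empty_like(input, dtype=float8_dtype)

x = torch.randn(16, 16)
scale = torch.tensor(0.1)
x_fp8 = torch.ops.my_quant.quantize_affine_float8(x, scale, torch.float8_e4m3fn)
```

dequantize_affine_float8 would be registered the same way, with a meta kernel that returns a tensor of the requested output dtype.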


@pytorch-bot (bot) commented Jun 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2379

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit d6ab45d with merge base 6a8887f:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

Hi @shiyang-weng!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@shiyang-weng marked this pull request as draft June 16, 2025 07:06
@shiyang-weng

This comment was marked as resolved.

torchao/utils.py Outdated
@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
return n + k - (n % k)


-def _register_custom_op(lib):
+def _register_custom_op(lib, decomposed=True):
Contributor:

nit: can be more explicit, e.g. inductor_decomposed

Author:

Done

@shiyang-weng marked this pull request as ready for review June 19, 2025 01:14
@facebook-github-bot
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the "CLA Signed" label Jun 19, 2025
@@ -2270,6 +2270,7 @@ def _expand_scale_to_tensor_shape(
return expanded_scale


+@_register_custom_op(quant_lib, False)
Contributor:

we can probably make quantize_affine_float8 and choose_qparams_affine_float8 public as well since it's used in inductor lowering. cc @jainapurva @drisspg

Labels: CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: None yet
Development: successfully merging this pull request may close the issue "[Quant] Can quant not be decomposed on inductor?"
3 participants