[torchlib] Implement quantize_per_channel and dequantize_per_channel #2390
Conversation
Co-authored-by: justinchuby <[email protected]>
    quant_min: int,
    quant_max: int,
Are these unused? Why is that?
These parameters are unused because the ONNX QuantizeLinear/DequantizeLinear operators don't take explicit quant_min/quant_max inputs; they derive the quantization range from the target data type. The parameters are kept in the function signature for API compatibility with PyTorch's reference implementation, following the same pattern as the existing per-tensor functions above.
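For context, a minimal sketch of what such a delegation could look like. The imports and `@torch_op` registration follow torchlib's usual pattern, but this is an assumption for illustration, not the PR's exact code:

```python
# Sketch: why quant_min/quant_max can be accepted but ignored. ONNX
# QuantizeLinear saturates to the range implied by the zero-point tensor's
# dtype, so the explicit bounds from the PyTorch signature are redundant.
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@torch_op("quantized_decomposed::quantize_per_channel", trace_only=True)
def quantized_decomposed_quantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: TensorType,
    axis: int,
    quant_min: int,  # unused: the range follows from `dtype` in ONNX
    quant_max: int,  # unused: the range follows from `dtype` in ONNX
    dtype: int,
) -> TensorType:
    # With 1-D scales/zero_points and the `axis` attribute, QuantizeLinear
    # performs per-axis quantization natively. Real code would also cast
    # zero_points to the target `dtype` so the output type is correct.
    return op.QuantizeLinear(input, scales, zero_points, axis=axis)
```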
❌ 25 Tests Failed (1 flaky).
from onnxscript.onnx_types import TensorType
from typing import Optional
Check notice
Code scanning / lintrunner: PYLINT/C0411 (wrong-import-order). To disable, use # pylint: disable=wrong-import-order
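The fix for C0411 is ordering alone: standard-library imports precede first-party ones. A sketch of the corrected order for the flagged lines:

```python
from typing import Optional  # standard library first

from onnxscript.onnx_types import TensorType  # first-party afterwards
```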
) -> TensorType:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values.

Check warning
Code scanning / lintrunner: RUFF/W293 (blank line contains whitespace). See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace
Check warning
Code scanning / lintrunner: EDITORCONFIG-CHECKER/editorconfig (same blank line)
) -> TensorType:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values.

Check warning
Code scanning / lintrunner: RUFF/W293 (blank line contains whitespace). See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace
Check warning
Code scanning / lintrunner: EDITORCONFIG-CHECKER/editorconfig (same blank line)
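All four warnings point at the blank line inside the docstrings; the fix is simply to strip its whitespace. An illustrative stub (not the PR's code):

```python
# W293 and the editorconfig check are satisfied once the blank line inside
# the docstring contains no spaces or tabs.
def _dequantize_per_channel_stub() -> None:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values.

    Args and body omitted; only the whitespace-clean blank line above matters here.
    """
```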
This PR implements the missing quantize_per_channel and dequantize_per_channel operations in the torchlib quantized_decomposed module.

Changes
Added two new functions to onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py:

- quantized_decomposed_quantize_per_channel: quantizes a float tensor using per-channel scales and zero_points (one value per channel), with an axis parameter to specify the quantization dimension
- quantized_decomposed_dequantize_per_channel: dequantizes using per-channel scales and optional zero_points; the zero_points parameter is Optional[TensorType], matching the PyTorch reference, and an optional output_dtype parameter is supported

Implementation Details
Both functions:

- Use the @torch_op decorator with trace_only=True
- Follow the PyTorch reference semantics in torch.ao.quantization.fx._decomposed
- Support axis and output_dtype parameters for per-axis quantization

The implementation leverages ONNX's native per-axis quantization support rather than reimplementing the tensor-manipulation logic from the PyTorch reference, making it more efficient and better aligned with ONNX best practices; a sketch of this delegation follows.
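A hedged sketch of that delegation for the dequantize side, under assumed torchlib imports and simplified dtype handling; the merged code may differ:

```python
from typing import Optional

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@torch_op("quantized_decomposed::dequantize_per_channel", trace_only=True)
def quantized_decomposed_dequantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: Optional[TensorType],
    axis: int,
    quant_min: int,  # unused, see the review thread above
    quant_max: int,  # unused, see the review thread above
    dtype: int,  # unused here: the input tensor already carries its dtype
    output_dtype: int = -1,  # illustrative sentinel meaning "default"
) -> TensorType:
    # DequantizeLinear computes (input - zero_point) * scale along `axis`;
    # when zero_points is absent, ONNX defaults the zero point to 0.
    if zero_points is None:
        output = op.DequantizeLinear(input, scales, axis=axis)
    else:
        output = op.DequantizeLinear(input, scales, zero_points, axis=axis)
    if output_dtype != -1:
        output = op.Cast(output, to=output_dtype)
    return output
```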
Testing
Validated that:
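As one illustration of the kind of numeric check involved (a hypothetical sketch, not the PR's test code), the reference per-channel math can be reproduced in plain PyTorch and round-tripped:

```python
import torch

x = torch.randn(2, 3, 4)           # float input, channels on axis 1
scales = torch.tensor([0.1, 0.2, 0.3])
s = scales.reshape(1, -1, 1)       # broadcast per-channel scale along axis 1

# Reference-style quantize (zero_points = 0): scale, round, clamp, cast.
q = torch.clamp(torch.round(x / s), -128, 127).to(torch.int8)

# Dequantize and check the round-trip error stays within half a step.
roundtrip = q.to(torch.float32) * s
assert (x - roundtrip).abs().max() <= scales.max() / 2 + 1e-6
```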
Fixes #2389.