
[torchlib] Implement quantize_per_channel and dequantize_per_channel #2390


Open: wants to merge 2 commits into main

Conversation

@Copilot Copilot AI commented Jun 14, 2025

This PR implements the missing quantize_per_channel and dequantize_per_channel operations in the torchlib quantized_decomposed module.

Changes

Added two new functions to onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py:

quantized_decomposed_quantize_per_channel

  • Implements per-channel quantization using ONNX QuantizeLinear with per-axis support
  • Takes tensor inputs for scales and zero_points (one value per channel)
  • Supports axis parameter to specify the quantization dimension
  • Uses ONNX opset23 for per-axis quantization capabilities (see the sketch below)
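
A minimal sketch of what the quantize side could look like, assuming the module's existing torch_op registration pattern and a hypothetical _ONNX_DTYPE map from torch dtypes to TensorProto codes (the actual PR may resolve the dtype argument differently):

    import torch

    from onnxscript.function_libs.torch_lib.registration import torch_op
    from onnxscript.onnx_opset import opset23 as op
    from onnxscript.onnx_types import TensorType

    # Hypothetical map from torch dtypes to ONNX TensorProto codes; the real
    # module may already have an equivalent helper.
    _ONNX_DTYPE = {torch.uint8: 2, torch.int8: 3, torch.int16: 5, torch.int32: 6}

    @torch_op("quantized_decomposed::quantize_per_channel", trace_only=True)
    def quantized_decomposed_quantize_per_channel(
        input: TensorType,
        scales: TensorType,
        zero_points: TensorType,
        axis: int,
        quant_min: int,  # unused: ONNX derives the range from the target dtype
        quant_max: int,  # unused: ONNX derives the range from the target dtype
        dtype: torch.dtype,
    ) -> TensorType:
        # With 1-D scales and zero_points (one entry per channel), QuantizeLinear
        # quantizes along `axis`. The zero point must be in the target type.
        zero_points = op.Cast(zero_points, to=_ONNX_DTYPE[dtype])
        return op.QuantizeLinear(input, scales, zero_points, axis=axis)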

quantized_decomposed_dequantize_per_channel

  • Implements per-channel dequantization using ONNX DequantizeLinear with per-axis support
  • Takes tensor inputs for scales and optional zero_points
  • zero_points parameter is Optional[TensorType] matching PyTorch reference
  • Supports both the default output type and an explicit output_dtype parameter (see the sketch below)
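
A corresponding sketch for the dequantize side, under the same assumptions and imports as the sketch above, with zero_points optional as in the PyTorch reference. The PR notes it uses opset 23's output_dtype parameter; a Cast is shown here for brevity:

    from typing import Optional

    @torch_op("quantized_decomposed::dequantize_per_channel", trace_only=True)
    def quantized_decomposed_dequantize_per_channel(
        input: TensorType,
        scales: TensorType,
        zero_points: Optional[TensorType],
        axis: int,
        quant_min: int,  # unused, kept for PyTorch API compatibility
        quant_max: int,  # unused, kept for PyTorch API compatibility
        dtype: torch.dtype,
        *,
        out_dtype: Optional[torch.dtype] = None,
    ) -> TensorType:
        # DequantizeLinear likewise dequantizes along `axis` for 1-D parameters.
        if zero_points is not None:
            # zero_points must match the quantized input's type.
            zero_points = op.Cast(zero_points, to=_ONNX_DTYPE[dtype])
            output = op.DequantizeLinear(input, scales, zero_points, axis=axis)
        else:
            output = op.DequantizeLinear(input, scales, axis=axis)
        if out_dtype is not None:
            # Hypothetical float-type map (1 = FLOAT, 10 = FLOAT16).
            output = op.Cast(output, to={torch.float32: 1, torch.float16: 10}[out_dtype])
        return output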

Implementation Details

Both functions:

  • Follow the existing code patterns using @torch_op decorator with trace_only=True
  • Have function signatures that match the PyTorch reference implementations from torch.ao.quantization.fx._decomposed
  • Use ONNX opset23 to leverage the axis and output_dtype parameters for per-axis quantization
  • Are properly registered and accessible as TracedOnnxFunctions

The implementation leverages ONNX's native per-axis quantization support rather than implementing the tensor manipulation logic from the PyTorch reference, making it more efficient and aligned with ONNX best practices.
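
For intuition, the per-channel affine math that a manual decomposition would otherwise express with reshapes, broadcasting, rounding, and clamping is roughly the following (a sketch of the semantics, not the exact PyTorch code):

    import torch

    def quantize_per_channel_reference(input, scales, zero_points, axis, quant_min, quant_max, dtype):
        # Reshape the per-channel parameters so they broadcast along `axis`,
        # then apply the affine formula: q = clamp(round(x / s) + zp, min, max).
        shape = [1] * input.ndim
        shape[axis] = -1
        q = torch.round(input / scales.reshape(shape)) + zero_points.reshape(shape)
        return torch.clamp(q, quant_min, quant_max).to(dtype)

All of this collapses into a single QuantizeLinear node in the ONNX lowering.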

Testing

Validated that:

  • Functions import correctly and have proper signatures
  • Function schemas are correctly generated with appropriate parameter types
  • No existing functionality is broken (per-tensor functions still work)
  • Functions can be used to build ONNX models (see the example below)
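
For the last point, a per-axis quantize/dequantize round trip can be spelled out directly in onnxscript; this is an illustrative example, not one of the PR's tests:

    from onnxscript import script
    from onnxscript.onnx_opset import opset23
    from onnxscript.onnx_types import FLOAT, INT8

    @script()
    def per_channel_roundtrip(x: FLOAT[3, 4], scales: FLOAT[3], zero_points: INT8[3]) -> FLOAT[3, 4]:
        # One scale/zero_point per channel along axis 0.
        q = opset23.QuantizeLinear(x, scales, zero_points, axis=0)
        return opset23.DequantizeLinear(q, scales, zero_points, axis=0)

    model = per_channel_roundtrip.to_model_proto()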

Fixes #2389.



@Copilot Copilot AI changed the title [WIP] [torchlib] Implement quantize_per_channel and dequantize_per_channel [torchlib] Implement quantize_per_channel and dequantize_per_channel Jun 14, 2025
@Copilot Copilot AI requested a review from justinchuby June 14, 2025 15:50
Copilot finished work on behalf of justinchuby June 14, 2025 15:50
Comment on lines +81 to +82
quant_min: int,
quant_max: int,
Collaborator

Are these unused? Why is that?

Author

These parameters are unused because the ONNX QuantizeLinear/DequantizeLinear operators don't take explicit quant_min/quant_max values; they derive the quantization range from the target data type and the quantization parameters. The parameters are kept in the function signature for API compatibility with PyTorch's reference implementation, following the same pattern as the existing per-tensor functions above.
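
For illustration: for the common dtypes the PyTorch bounds coincide with the type's natural range, which is exactly what ONNX assumes:

    import torch

    # Dropping quant_min/quant_max loses nothing as long as callers pass the
    # standard bounds for the target type.
    assert torch.iinfo(torch.int8).min == -128 and torch.iinfo(torch.int8).max == 127
    assert torch.iinfo(torch.uint8).min == 0 and torch.iinfo(torch.uint8).max == 255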

Copilot finished work on behalf of justinchuby June 14, 2025 16:07
@justinchuby justinchuby marked this pull request as ready for review June 14, 2025 16:10

codecov bot commented Jun 14, 2025

❌ 25 Tests Failed:

Tests completed: 15725   Failed: 25   Passed: 15700   Skipped: 1697

Top 2 failed tests (by shortest run time):
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0362_test_equal_bcast
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_equal_bcast'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_equal_bcast' (e=No module named 'tests.onnx_backend_test_code.test_equal_bcast') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_equal_bcast.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_equal_bcast.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import BOOL, INT32
E   from onnxscript.onnx_opset import opset19
E   
E   @script()
E   def bck_test_equal_bcast(x: INT32[3,4,5], y: INT32[5]) -> (BOOL[3,4,5]):
E       z = opset19.Equal(x, y)
E       return z
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0788_test_quantizelinear_blocked_symmetric
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_quantizelinear_blocked_symmetric'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_quantizelinear_blocked_symmetric' (e=No module named 'tests.onnx_backend_test_code.test_quantizelinear_blocked_symmetric') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_quantizelinear_blocked_symmetric.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_quantizelinear_blocked_symmetric.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT16
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_quantizelinear_blocked_symmetric(x: FLOAT[3,4], y_scale: FLOAT[3,2]) -> (INT16[3,4]):
E       y = opset21.QuantizeLinear(x, y_scale, axis=1, block_size=2, output_dtype=5)
E       return y
Flaky tests (1):
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0759_test_or_bcast4v4d

Flake rate in main: 5.26% (Passed 36 times, Failed 2 times)

Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_or_bcast4v4d'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_or_bcast4v4d' (e=No module named 'tests.onnx_backend_test_code.test_or_bcast4v4d') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_or_bcast4v4d.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_or_bcast4v4d.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import BOOL
E   from onnxscript.onnx_opset import opset7
E   
E   @script()
E   def bck_test_or_bcast4v4d(x: BOOL[1,4,1,6], y: BOOL[3,1,5,6]) -> (BOOL[3,4,5,6]):
E       r_or = opset7.Or(x, y)
E       return r_or


from onnxscript.onnx_types import TensorType
from typing import Optional

Check notice (Code scanning / lintrunner, PYLINT/C0411): standard import "typing.Optional" should be placed before first party imports "onnxscript.function_libs.torch_lib.ops.common", "onnxscript.function_libs.torch_lib.registration.torch_op", "onnxscript.onnx_opset.opset18", "onnxscript.onnx_opset.opset23", "onnxscript.onnx_types.TensorType" (wrong-import-order). To disable, use # pylint: disable=wrong-import-order
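
A fix is to hoist the standard-library import above the first-party ones, assuming the module's import list matches what the lint message names:

    from typing import Optional

    from onnxscript.function_libs.torch_lib.ops import common
    from onnxscript.function_libs.torch_lib.registration import torch_op
    from onnxscript.onnx_opset import opset18, opset23
    from onnxscript.onnx_types import TensorType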
) -> TensorType:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values.

Check warning (Code scanning / lintrunner, RUFF/W293 and EDITORCONFIG-CHECKER/editorconfig): Trailing whitespace (whitespace on blank line).
) -> TensorType:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values.

Check warning (Code scanning / lintrunner, RUFF/W293 and EDITORCONFIG-CHECKER/editorconfig): Trailing whitespace (whitespace on blank line).