[torchlib] Implement quantize_per_channel and dequantize_per_channel #2390

Open: wants to merge 3 commits into main
63 changes: 63 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
@@ -11,9 +11,12 @@

from __future__ import annotations

from typing import Optional

from onnxscript.function_libs.torch_lib.ops import common
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_opset import opset23 as op23
from onnxscript.onnx_types import TensorType


@@ -61,3 +64,63 @@ def quantized_decomposed_dequantize_per_tensor(
        return dequantized
    assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
    return op.Cast(dequantized, to=out_dtype)


@torch_op(
    (
        "quantized_decomposed::quantize_per_channel",
        "quantized_decomposed::quantize_per_channel.tensor",
        "quantized_decomposed::quantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_quantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: TensorType,
    axis: int,
    quant_min: int,
    quant_max: int,
Comment on lines +82 to +83

Collaborator: Are these unused? Why is that?

Contributor Author: These parameters are unused because the ONNX QuantizeLinear/DequantizeLinear operators don't require explicit quant_min/quant_max parameters; they determine the quantization range from the data type and quantization parameters. The parameters are kept in the function signature for API compatibility with PyTorch's reference implementation, following the same pattern as the existing per-tensor functions above.
    dtype: int,
) -> TensorType:
    """Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values.

    Uses ONNX QuantizeLinear with per-axis quantization support.
    """
    # Use opset23 for per-axis quantization support
    return op23.QuantizeLinear(input, scales, zero_points, axis=axis, output_dtype=dtype)
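
For context on the review thread above, here is a minimal NumPy sketch (not part of this PR) of the per-channel affine mapping that op23.QuantizeLinear applies along axis. It makes explicit where quant_min/quant_max enter PyTorch's reference implementation as a clamp, a range that ONNX instead derives from output_dtype. The helper name quantize_per_channel_ref and the sample data are illustrative.

import numpy as np

def quantize_per_channel_ref(x, scales, zero_points, axis, quant_min, quant_max, np_dtype):
    # Reshape the per-channel parameters so they broadcast along `axis`.
    shape = [1] * x.ndim
    shape[axis] = -1
    s = scales.reshape(shape)
    zp = zero_points.reshape(shape)
    # Affine quantization: q = clamp(round(x / s) + zp, quant_min, quant_max).
    # np.rint rounds half to even, matching ONNX QuantizeLinear; ONNX infers
    # the saturation range from output_dtype, which is why quant_min/quant_max
    # are unused in the torchlib function above.
    q = np.clip(np.rint(x / s) + zp, quant_min, quant_max)
    return q.astype(np_dtype)

x = np.random.randn(2, 3, 4).astype(np.float32)
scales = np.array([0.1, 0.2, 0.05], dtype=np.float32)
zero_points = np.zeros(3, dtype=np.int64)
q = quantize_per_channel_ref(x, scales, zero_points, axis=1,
                             quant_min=-128, quant_max=127, np_dtype=np.int8)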


@torch_op(
    (
        "quantized_decomposed::dequantize_per_channel",
        "quantized_decomposed::dequantize_per_channel.tensor",
        "quantized_decomposed::dequantize_per_channel.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorType,
    scales: TensorType,
    zero_points: Optional[TensorType],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
    out_dtype: int = -1,
) -> TensorType:
    """Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values.

    Uses ONNX DequantizeLinear with per-axis quantization support.
    """
    # Use opset23 for per-axis quantization support with optional output_dtype
    if out_dtype in (-1, None):
        # Use default output type (same as scales type)
        return op23.DequantizeLinear(input, scales, zero_points, axis=axis)
    else:
        assert out_dtype > 0, f"out_dtype must be -1 or > 0 not {out_dtype}"
        return op23.DequantizeLinear(
            input, scales, zero_points, axis=axis, output_dtype=out_dtype
        )
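
Likewise, a minimal NumPy sketch (again illustrative, not part of the diff) of the matching per-channel dequantization y = (q - zp) * s that op23.DequantizeLinear performs along axis; dequantize_per_channel_ref and the sample values are hypothetical.

import numpy as np

def dequantize_per_channel_ref(q, scales, zero_points, axis):
    # Broadcast the per-channel parameters along `axis`, mirroring the
    # axis attribute of ONNX DequantizeLinear.
    shape = [1] * q.ndim
    shape[axis] = -1
    s = scales.reshape(shape)
    # A missing zero_points input is treated as all zeros.
    zp = zero_points.reshape(shape).astype(s.dtype) if zero_points is not None else 0
    # Affine dequantization in the scales' float type: y = (q - zp) * s.
    return (q.astype(s.dtype) - zp) * s

q = np.array([[10, -4], [20, 8]], dtype=np.int8)  # channels on axis 0
scales = np.array([0.1, 0.05], dtype=np.float32)
zero_points = np.array([0, 2], dtype=np.int64)
y = dequantize_per_channel_ref(q, scales, zero_points, axis=0)
# y ≈ [[1.0, -0.4], [0.9, 0.3]]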