
Commit 6ac176f

Cortex_m backend: Add IO quantizers + tests of non rescaling ops
A number of ops only handle shape/meta-data without changing the dynamic range. In these cases, no rescaling needs to be performed and the int8 portable_ops kernel can be used directly. A new test is added to ensure this behaviour, as well as a test showing that operators which do change the dynamic range (SUB) are not supported.

To support quantization of graphs with no-rescale ops at the beginning/end of the graph, two new quantizers, InputQuantizer and OutputQuantizer, are introduced. By explicitly stating the dtype of the input/output, no-rescale ops inherit dtypes from them as with any other op.

This change exposes the issue of mixing dtypes within the graph, which adds back xfails for the broadcasted add and mul tests. This can be fixed in a future patch after pytorch#15300 is resolved.

Signed-off-by: Adrian Lundell <[email protected]>
Change-Id: I8f79b86b633f9ad8d9f183c914754b0ee2f7a87c
1 parent 1812c81 commit 6ac176f
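To illustrate the intended flow, here is a minimal sketch (not part of the commit) of how the two new quantizers could be exercised on their own: the graph input and output are annotated with an explicit int8 per-tensor config, and the no-rescale ops in between can then inherit that dtype. INT8_PER_TENSOR_CONFIG and the annotate() signature are taken from the diff below; the import path of the quantizers, the toy module, and the use of torch.export here are assumptions.

    import torch

    # Assumed import paths: the config import matches the diff below; the
    # quantizer module path is inferred from backends/cortex_m/quantizer/quantizer.py.
    from executorch.backends.cortex_m.quantizer.quantization_configs import (
        INT8_PER_TENSOR_CONFIG,
    )
    from executorch.backends.cortex_m.quantizer.quantizer import (
        InputQuantizer,
        OutputQuantizer,
    )


    class ReshapeOnly(torch.nn.Module):
        # Hypothetical model containing only shape/meta-data ops: nothing here
        # changes the dynamic range, so no rescaling is needed anywhere.
        def forward(self, x):
            return x.reshape(1, -1)


    example_inputs = (torch.randn(1, 3, 4, 5),)
    graph_module = torch.export.export(ReshapeOnly(), example_inputs).module()

    # Explicitly state the dtype of the graph input and output; the no-rescale
    # ops in between inherit it like any other op.
    InputQuantizer(INT8_PER_TENSOR_CONFIG).annotate(graph_module)
    OutputQuantizer(INT8_PER_TENSOR_CONFIG).annotate(graph_module)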

File tree

4 files changed, +225 -0 lines changed

backends/cortex_m/quantizer/quantizer.py

Lines changed: 60 additions & 0 deletions
@@ -15,6 +15,9 @@
     INT8_BINARY_OPS_OPERATOR_CONFIG,
     INT8_LINEAR_OPERATOR_CONFIG,
 )
+from executorch.backends.cortex_m.quantizer.quantization_configs import (
+    INT8_PER_TENSOR_CONFIG,
+)
 from torch._ops import OpOverload
 from torch.fx import GraphModule, Node
 from torchao.quantization.pt2e.quantizer import (

@@ -49,6 +52,8 @@ def __init__(self) -> None:
                 INT8_BINARY_OPS_OPERATOR_CONFIG, filter_fn=self.broadcasting_filter
             ),
             OperatorConfigQuantizer(INT8_LINEAR_OPERATOR_CONFIG),
+            InputQuantizer(INT8_PER_TENSOR_CONFIG),
+            OutputQuantizer(INT8_PER_TENSOR_CONFIG),
         ]
         super().__init__(quantizers)

@@ -196,3 +201,58 @@ def annotate(self, model: GraphModule) -> None:
 
     def validate(self, model: GraphModule) -> bool:
         return True
+
+
+class InputQuantizer(Quantizer):
+    """
+    Quantizes only the input activations of the graph.
+    """
+
+    def __init__(
+        self,
+        quantization_config: QuantizationConfig,
+        filter_fn: Callable[[Node], bool] = lambda node: False,
+    ) -> None:
+        self.quantization_config = quantization_config
+        self.filter_fn = filter_fn
+
+    def annotate(self, model: GraphModule) -> None:
+        for node in model.graph.nodes:
+            is_placeholder = node.op == "placeholder"
+            is_filtered = self.filter_fn(node)
+            if is_placeholder and not is_filtered:
+                node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                    {}, self.quantization_config.output_activation
+                )
+
+    def validate(self, model: GraphModule) -> bool:
+        return True
+
+
+class OutputQuantizer(Quantizer):
+    """
+    Quantizes only the output activations of the graph.
+    """
+
+    def __init__(
+        self,
+        quantization_config: QuantizationConfig,
+        filter_fn: Callable[[Node], bool] = lambda node: False,
+    ) -> None:
+        self.quantization_config = quantization_config
+        self.filter_fn = filter_fn
+
+    def annotate(self, model: GraphModule) -> None:
+        output_node = model.graph.output_node()
+        input_qspec_map = {
+            n: self.quantization_config.input_activation
+            for n in output_node.all_input_nodes
+            if not self.filter_fn(n)
+        }
+        output_qspec = self.quantization_config.output_activation
+        output_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map, output_qspec
+        )
+
+    def validate(self, model: GraphModule) -> bool:
+        return True
Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import pytest
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.test.common import parametrize
+from executorch.backends.cortex_m.test.tester import (
+    CortexMTester,
+    McuTestCase,
+    ramp_tensor,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class CortexMInheritAllOps(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def forward(self, x):
+        # x shape: (1, 3, 4, 5)
+        x = x + x
+        x = torch.ops.aten.squeeze.default(x)  # Remove dim 0: (3, 4, 5)
+        x = torch.ops.aten.unsqueeze.default(x, 0)  # Add dim at 0: (1, 3, 4, 5)
+        x = torch.ops.aten.squeeze_copy.default(x)  # (3, 4, 5)
+        x = torch.ops.aten.unsqueeze_copy.default(x, 0)  # (1, 3, 4, 5)
+        x = torch.ops.aten.squeeze.dims(x, [0])  # (3, 4, 5)
+        x = torch.ops.aten.squeeze_copy.dim(
+            x, 0
+        )  # Remove first dim if size 1, otherwise same
+        x = torch.ops.aten.squeeze.dim(x, 0)  # Same
+        x = torch.ops.aten.unbind.int(x, 0)[0]  # Unbind and take first: (4, 5)
+        x = torch.ops.aten.reshape.default(x, (1, 4, 5, 1))  # (1, 4, 5, 1)
+        x = torch.ops.aten.repeat.default(x, [1, 1, 1, 2])  # (1, 4, 5, 2)
+        x = torch.ops.aten.view.default(x, (1, 4, 10))  # (1, 4, 10)
+        target_shape = torch.zeros(1, 4, 10)
+        x = torch.ops.aten.view_as.default(x, target_shape)  # (1, 4, 10)
+        x = torch.ops.aten.view_copy.default(x, (1, 2, 20))  # (1, 2, 20)
+        x = torch.ops.aten.unflatten.int(x, 2, [4, 5])  # (1, 2, 4, 5)
+        x = torch.ops.aten.flatten.using_ints(x, 1, 3)  # (1, 40)
+        x = torch.ops.aten.repeat_interleave.self_int(x, 2, 1)  # (1, 80)
+        x = torch.ops.aten.expand_copy.default(x, (2, 80))  # (2, 80)
+        x = torch.ops.aten.expand.default(x, (2, 80))  # (2, 80)
+        x = torch.ops.aten.tile.default(x, [1, 1])  # (2, 80)
+        x = torch.ops.aten.split.Tensor(x, 40, 1)[0]  # (2, 40)
+        x = torch.ops.aten.split_with_sizes.default(x, [20, 20], 1)[0]  # (2, 20)
+        x = torch.ops.aten.split_copy.Tensor(x, 10, 1)[0]  # (2, 10)
+        x = torch.ops.aten.chunk.default(x, 2, 1)[0]  # (2, 5)
+        x = torch.ops.aten.pad.default(x, [1, 1, 0, 0], "constant", 0.0)  # (2, 7)
+        x = torch.ops.aten.select.int(x, 1, 0)  # (2,)
+        x = torch.ops.aten.select_copy.int(x, 0, 0)  # scalar -> need to reshape
+        x = torch.ops.aten.unsqueeze.default(x, 0)  # (1,)
+        x = torch.ops.aten.unsqueeze.default(x, 1)  # (1, 1)
+        x = torch.ops.aten.slice.Tensor(x, 0, 0, 1)  # (1, 1)
+        x = torch.ops.aten.slice_copy.Tensor(x, 1, 0, 1)  # (1, 1)
+        x = torch.ops.aten.reshape.default(x, (1, 1))  # Ensure shape for transpose
+        x = torch.ops.aten.transpose.int(x, 0, 1)  # (1, 1)
+        x = torch.ops.aten.transpose_copy.int(x, 0, 1)  # (1, 1)
+        x = torch.ops.aten.t_copy.default(x)  # (1, 1)
+        x = torch.ops.aten.contiguous.default(x)  # (1, 1)
+        x = torch.ops.aten.permute.default(x, [1, 0])  # (1, 1)
+        x = torch.ops.aten.permute_copy.default(x, [0, 1])  # (1, 1)
+        x = torch.ops.aten.flip.default(x, [0])  # (1, 1)
+        y = torch.zeros_like(x)
+        x = torch.ops.aten.copy_.default(y, x)  # (1, 1)
+        x = torch.ops.aten.clone.default(x)  # (1, 1)
+        return x
+
+
+class CortexMOnlyInheritOps(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def forward(self, x):
+        return torch.permute(torch.clone(x), (0, 1, 3, 2))
+
+
+class CortexMQuantizeNonSupportedSub(torch.nn.Module):
+    ops_before_transforms = {}
+
+    ops_after_transforms = {}
+
+    def forward(self, x, y):
+        return y - x
+
+
+test_cases = {
+    "all_ops": McuTestCase(
+        CortexMInheritAllOps(),
+        (ramp_tensor(0, 10, (1, 3, 4, 5)),),
+    ),
+    "only_inherit_ops": McuTestCase(
+        CortexMOnlyInheritOps(),
+        (ramp_tensor(0, 10, (1, 3, 4, 5)),),
+    ),
+}
+
+
+@parametrize("test_case", test_cases)
+def test_inherit_int8_dtype(test_case):
+    """
+    Test that ops which do not change dynamic range are able to use int8 portable kernels.
+    """
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+    )
+
+    # Check that all nodes in the graph are in int8
+    artifact = tester.get_artifact()
+    for node in artifact.exported_program().module().graph.nodes:
+        if node.op != "call_function":
+            continue
+        if node.target == exir_ops.edge.cortex_m.dequantize_per_tensor.default:
+            continue
+
+        assert get_first_fake_tensor(node).dtype == torch.int8, f"{node.name}"
+
+
+test_cases = {
+    "sub": McuTestCase(
+        CortexMQuantizeNonSupportedSub(),
+        (ramp_tensor(0, 10, (1, 3, 4, 5)), ramp_tensor(0, 1, (1, 3, 4, 5))),
+    ),
+}
+
+
+@pytest.mark.xfail(
+    reason="Non-handled ops which change dynamic range are currently not supported."
+)
+@parametrize("test_case", test_cases)
+def test_quantize_unsupported_op(test_case):
+    """
+    Test an op which does change dynamic range and which is not supported by CMSIS-NN. Currently not supported.
+    """
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+    )

backends/cortex_m/test/ops/test_add.py

Lines changed: 6 additions & 0 deletions
@@ -163,6 +163,9 @@ class CortexMAlphaAdd(ModelAlpha):
         "Expecting kwargs for aten op IR to be empty - alpha arg not supported.",
         AssertionError,
     ),
+    "broadcast_1": "Mixed fp/quantized ops currently not supported.",
+    "broadcast_2": "Mixed fp/quantized ops currently not supported.",
+    "broadcast_3": "Mixed fp/quantized ops currently not supported.",
 }
 
 
@@ -187,6 +190,9 @@ def test_dialect_add(test_case):
         "Expecting kwargs for aten op IR to be empty - alpha arg not supported.",
         AssertionError,
     ),
+    "broadcast_1": "Mixed fp/quantized ops currently not supported.",
+    "broadcast_2": "Mixed fp/quantized ops currently not supported.",
+    "broadcast_3": "Mixed fp/quantized ops currently not supported.",
 }
 
 
backends/cortex_m/test/ops/test_mul.py

Lines changed: 3 additions & 0 deletions
@@ -127,6 +127,9 @@ class CortexMTensorMulBroadCast(Model):
 xfail_cases = {
     "self_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars",
     "scalar_scalar": "lift_constant_tensor_pass assumes fake tensors for scalars",
+    "broadcast_1": "Mixed fp/quantized ops currently not supported.",
+    "broadcast_2": "Mixed fp/quantized ops currently not supported.",
+    "broadcast_3": "Mixed fp/quantized ops currently not supported.",
 }
 
 