
Commit 501d118

Authored by jerryzh168 and committed by pytorchmergebot
[quant][pt2e] Add transform_for_annotation method in Quantizer (pytorch#113115)
Summary: Adds a transform_for_annotation method to Quantizer so that users can run transformations before annotation, making the graph easier to annotate.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_transform_for_annotation

Differential Revision: D51141080 (https://our.internmc.facebook.com/intern/diff/D51141080)

Pull Request resolved: pytorch#113115

Approved by: https://github.com/kimishpatel
1 parent e53da90 · commit 501d118
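Before the per-file diffs, a quick orientation: the commit adds an overridable hook on Quantizer that runs on the captured graph before annotate(). A minimal usage sketch, assuming the standard PT2E entry points; the quantizer class and model below are illustrative, not part of this commit:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer

class MyQuantizer(Quantizer):  # hypothetical subclass, for illustration only
    def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # rewrite the graph here so annotate() sees easier-to-match patterns
        return model

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        return model  # a real quantizer attaches quantization annotations here

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 3

example_inputs = (torch.randn(1, 2, 3, 3),)
m = capture_pre_autograd_graph(Toy().eval(), example_inputs)
m = prepare_pt2e(m, MyQuantizer())  # transform_for_annotation runs inside, before annotate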

9 files changed: +96 -3 lines

docs/source/conf.py (-1)

@@ -358,7 +358,6 @@
     # torch.ao.quantization.quantizer.embedding_quantizer
     "get_embedding_operators_config",
     # torch.ao.quantization.quantizer.xnnpack_quantizer_utils
-    "convert_scalars_to_attrs",
     "get_bias_qspec",
     "get_input_act_qspec",
     "get_output_act_qspec",

test/allowlist_for_publicAPI.json (-1)

@@ -1567,7 +1567,6 @@
         "propagate_annotation"
     ],
     "torch.ao.quantization.quantizer.xnnpack_quantizer_utils": [
-        "convert_scalars_to_attrs",
         "register_annotator"
     ],
     "torch.backends.xeon.run_cpu": [

test/quantization/pt2e/test_quantize_pt2e.py (+30)

@@ -1358,6 +1358,36 @@ def validate(self, model: torch.fx.GraphModule) -> None:
             ),
         )

+    def test_transform_for_annotation(self):
+        class TestQuantizer(Quantizer):
+            def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for n in model.graph.nodes:
+                    if n.target == torch.ops.aten.add.Tensor:
+                        n.target = torch.ops.aten.mul.Tensor
+                return model
+
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                return model
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x + 3
+
+        m = M().eval()
+        quantizer = TestQuantizer()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        m = capture_pre_autograd_graph(m, example_inputs)
+        m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        node_occurrence = {
+            ns.call_function(torch.ops.aten.add.Tensor): 0,
+            ns.call_function(torch.ops.aten.mul.Tensor): 1,
+        }
+        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+
     def test_embedding_quantizer(self):
         m_eager = TestHelperModules.EmbeddingModule().eval()
         indices = torch.tensor(

test/quantization/pt2e/test_xnnpack_quantizer.py (+33)

@@ -665,6 +665,39 @@ def test_mul_and_inplace_mul(self):
             node_list,
         )

+    def test_add_mul_scalar(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # two inputs and one output for the first add, and an output for the second add
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.add.Tensor,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.add_.Tensor,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.mul_.Tensor,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.AddMulScalar(),
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+

 # TODO: express this using self._test_quantizer, add test for inception_v4
 class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):

torch/ao/quantization/quantize_pt2e.py (+2)

@@ -103,6 +103,7 @@ def calibrate(model, data_loader):
     # to be quantized before fusion
     # TODO: (maybe) rewrite this with subgraph_rewriter
     _fuse_conv_bn_(model)
+    quantizer.transform_for_annotation(model)
     quantizer.annotate(model)
     quantizer.validate(model)
     model = prepare(model, node_name_to_scope, is_qat=False)

@@ -170,6 +171,7 @@ def train_loop(model, train_data):
     torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
     original_graph_meta = model.meta
     node_name_to_scope = _get_node_name_to_scope(model)
+    quantizer.transform_for_annotation(model)
     quantizer.annotate(model)
     quantizer.validate(model)
     # Perform fusion after annotate to avoid quantizing ops in the new
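The ordering is the point of this change: in both prepare_pt2e and prepare_qat_pt2e the new hook fires immediately before annotate(). A condensed sketch of the resulting call sequence (illustrative, not the verbatim function bodies):

# Hook ordering common to both entry points after this commit:
def _hook_order_sketch(model, quantizer):
    quantizer.transform_for_annotation(model)  # new: user-defined graph rewrites
    quantizer.annotate(model)                  # existing: attach annotations
    quantizer.validate(model)                  # existing: sanity-check annotations
    return model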

torch/ao/quantization/quantizer/quantizer.py (+14)

@@ -157,6 +157,20 @@ class QuantizationAnnotation:


 class Quantizer(ABC):
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Allows for user-defined transforms to run before annotating the graph.
+
+        This lets a quantizer quantize parts of the model that would otherwise
+        not be quantizable. For example, a quantizer can
+        a) decompose a compound operator like scaled dot product attention (SDPA)
+           into bmm and softmax, if it knows how to quantize bmm/softmax but not SDPA, or
+        b) transform scalars to tensors, to allow quantizing scalars.
+
+        Note: this is an optional method.
+        """
+        return model
+
     # annotate nodes in the graph with observer or fake quant constructors
     # to convey the desired way of quantization
     @abstractmethod
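In the spirit of the docstring's examples, a hedged sketch of a subclass that uses the hook to rewrite in-place adds into their functional form before annotation, so a simple annotator that only matches aten.add.Tensor still sees them; the class name and rewrite are illustrative only, not part of this commit:

import torch
from torch.ao.quantization.quantizer import Quantizer

class InplaceAddFixupQuantizer(Quantizer):  # hypothetical example
    def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for n in model.graph.nodes:
            if n.op == "call_function" and n.target == torch.ops.aten.add_.Tensor:
                n.target = torch.ops.aten.add.Tensor  # functionalize in-place add
        model.recompile()
        return model

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        return model  # annotation elided in this sketch

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass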

torch/ao/quantization/quantizer/xnnpack_quantizer.py (+7)

@@ -23,6 +23,7 @@
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    _convert_scalars_to_attrs,
     OP_TO_ANNOTATOR,
     OperatorConfig,
     OperatorPatternType,

@@ -341,6 +342,12 @@ def set_module_name(
         self.module_name_config[module_name] = quantization_config
         return self

+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Transforms scalar values to tensor attributes"""
+        return _convert_scalars_to_attrs(model)
+
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """just handling global spec for now"""
         # hacked for handling dynamic linear quant. will fix later.
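From the user's perspective nothing extra needs to be called: XNNPACKQuantizer users now get scalar handling automatically inside prepare_pt2e. A hedged usage sketch; the module and shapes are illustrative:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def forward(self, x):
        return x + 3  # scalar operand; converted to an attr before annotation

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
example_inputs = (torch.randn(1, 3, 5, 5),)
m = capture_pre_autograd_graph(M().eval(), example_inputs)
m = prepare_pt2e(m, quantizer)  # _convert_scalars_to_attrs runs via transform_for_annotation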

torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py (+2 -1)

@@ -861,7 +861,8 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
     )


-def convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+# TODO: make the list of ops customizable
+def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
     for n in model.graph.nodes:
         if n.op != "call_function" or n.target not in [
             torch.ops.aten.add.Tensor,
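The helper's body is truncated in the diff above. As a rough illustration of the scalar-to-attribute idea it implements, here is a self-contained sketch; the function name, attribute naming scheme, and op list are assumptions for illustration, not the actual implementation:

import torch

def _scalars_to_attrs_sketch(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Lift Python-scalar operands of binary aten ops into tensor attributes
    # referenced by get_attr nodes, so observers can be attached to them.
    for i, n in enumerate(model.graph.nodes):
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
        ]:
            continue
        new_args = []
        for arg in n.args:
            if isinstance(arg, (int, float)):
                attr_name = f"_scalar_attr_{i}"  # hypothetical naming scheme
                model.register_buffer(attr_name, torch.tensor(float(arg)))
                with model.graph.inserting_before(n):
                    arg = model.graph.get_attr(attr_name)
            new_args.append(arg)
        n.args = tuple(new_args)
    model.recompile()
    return model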

torch/testing/_internal/common_quantization.py (+8)

@@ -2628,6 +2628,14 @@ def forward(self, x, y):
         x *= y
         return x

+class AddMulScalar(torch.nn.Module):
+    def forward(self, x):
+        x = x + 3
+        x = x * 3
+        x += 3
+        x *= 3
+        return x
+
 class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
     def __init__(self):
         super().__init__()
