Commit 1e64ab6

Marco Giordano authored and facebook-github-bot committed
Adding mixed quantization support
Summary:

# Context

This diff adds support for mixed quantization operators in ExecuTorch: weights and biases can now be quantized while inputs and activations are kept in floating point.

# In this diff

1. Op nodes are returned from each pattern match.
2. Dequantize nodes are bypassed if not needed in the final graph.

Differential Revision: D81519735
1 parent 8973eeb commit 1e64ab6
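
To see what changes for the fusion pass, compare the two graph shapes it must now handle. Previously it assumed the fused op's single user was always a quantize node; with mixed quantization the output can stay in floating point. An illustrative sketch (not actual graph output):

# Fully quantized (old assumption): the op's single user is a quantize node.
#
#   dequantize(x) --+
#   dequantize(w) --+--> linear --> quantize --> ...
#   dequantize(b) --+
#
# Mixed quantization (new): weights and bias are dequantized, activations
# stay in float, and there is no trailing quantize node to replace.
#
#   x ------------- +
#   dequantize(w) --+--> linear --> ...
#   dequantize(b) --+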

File tree

3 files changed: +55 -41 lines changed


backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 21 additions & 7 deletions
@@ -390,7 +390,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                 pattern.partition_types(),
             )
             for fused_partition in fused_partitions:
-                anchors = pattern.get_anchors(graph_module, fused_partition)
+                anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
                 if not anchors or anchors.empty:
                     continue
                 if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -431,11 +431,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                 bias_inputs = [node.args[0] for node in dequants_biases]
                 other_inputs = [node.args[idx] for node, idx in anchors.others]

-                # The node is the first index of the list and first of the tuple
-                op_node = anchors.output[0][0]
+                # Check if there's a quantization node after the operation
+                quant_node = None
+                has_quant_node = False

-                assert len(op_node.users) == 1
-                quant_node = list(op_node.users.keys())[0]
+                if len(op_node.users) == 1:
+                    potential_quant = list(op_node.users.keys())[0]
+                    # Check if it's actually a quantization node
+                    if (hasattr(potential_quant, 'target') and
+                            potential_quant.target == torch.ops.quantized_decomposed.quantize_per_tensor.default):
+                        quant_node = potential_quant
+                        has_quant_node = True

                 with graph_module.graph.inserting_after(op_node):
                     args = tuple(
@@ -516,8 +522,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
                         args,
                         kwargs,
                     )
-                fused.meta = quant_node.meta
-                quant_node.replace_all_uses_with(fused)
+                if has_quant_node:
+                    fused.meta = quant_node.meta
+                    quant_node.replace_all_uses_with(fused)
+                    if quant_node.op == "output":
+                        _ = graph_module.graph.output((fused,))
+                else:
+                    fused.meta = op_node.meta
+                    op_node.replace_all_uses_with(fused)
+                    if op_node.op == "output":
+                        _ = graph_module.graph.output((fused,))

         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
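
The net effect in fusion_pass.py is that the old assert is gone: a trailing quantize node is detected rather than assumed. A standalone sketch of that probe, assuming only that op_node is an fx.Node in a graph that uses the quantized_decomposed ops (the helper name is hypothetical):

from typing import Optional

import torch
import torch.fx as fx

def find_trailing_quant(op_node: fx.Node) -> Optional[fx.Node]:
    # Mirror the check in the diff above: exactly one user, and that user
    # must be a quantize_per_tensor call; otherwise the output stays float.
    if len(op_node.users) != 1:
        return None
    user = next(iter(op_node.users))
    if (
        getattr(user, "target", None)
        == torch.ops.quantized_decomposed.quantize_per_tensor.default
    ):
        return user
    return None

# The pass then rewires either the quant node (quantized output) or the op
# node itself (floating-point output) onto the fused replacement, and emits
# a fresh graph output if the replaced node fed the output directly.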

backends/cadence/aot/quantizer/patterns.py

Lines changed: 33 additions & 33 deletions
@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
     @abstractmethod
     def get_anchors(
         self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> Optional[PartitionAnchors]:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         pass

     @abstractmethod
@@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         addmm_node = fused_partition[0].nodes[-1]

@@ -101,12 +101,12 @@ def get_anchors(
             qscheme=torch.per_tensor_affine,
         )

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(addmm_node, 1)],
             weights=[(addmm_node, 2)],
             biases=[(addmm_node, 0, bias_qspec)],
             output=[(addmm_node,)],
-        )
+        ), addmm_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
@@ -118,7 +118,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         add_node = fused_partition[0].nodes[-1]

@@ -129,16 +129,16 @@ def get_anchors(
             add_node.args[1], fx.Node
         )
         if not is_tensor_add or len(add_node.kwargs) > 0:
-            return PartitionAnchors(
+            return (PartitionAnchors(
                 empty=True,
-            )
+            ), add_node)

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(add_node, 0), (add_node, 1)],
             weights=[],
             biases=[],
             output=[(add_node,)],
-        )
+        ), add_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_add.default
@@ -150,16 +150,16 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         bmm_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(bmm_node, 0), (bmm_node, 1)],
             weights=[],
             biases=[],
             output=[(bmm_node,)],
-        )
+        ), bmm_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
@@ -171,7 +171,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         cat_node = fused_partition[0].nodes[-1]

@@ -198,14 +198,14 @@ def get_anchors(
             )
         )

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=args,
             weights=[],
             biases=[],
             output=[
                 (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
             ],
-        )
+        ), cat_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.aten.cat.default
@@ -217,7 +217,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv1d_node = fused_partition[0].nodes[-1]

@@ -238,13 +238,13 @@ def get_anchors(
         if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
             bias = [(conv1d_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv1d_node, 0)],
             weights=[(conv1d_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv1d_node,)],
-        )
+        ), conv1d_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv_nchw.default
@@ -256,7 +256,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv2d_node = fused_partition[0].nodes[-1]

@@ -277,13 +277,13 @@ def get_anchors(
         if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
             bias = [(conv2d_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(conv2d_node, 0)],
             weights=[(conv2d_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv2d_node,)],
-        )
+        ), conv2d_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv_nchw.default
@@ -295,7 +295,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         layer_norm_node = fused_partition[0].nodes[-1]

@@ -311,14 +311,14 @@ def get_anchors(

         # Weights are used in quantized mode by our kernel, so they are
         # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(layer_norm_node, 0)],
             weights=[],
             biases=[],
             # Ordering: normalized_shape, weights, bias
             others=others,
             output=[(layer_norm_node,)],
-        )
+        ), layer_norm_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default
@@ -330,7 +330,7 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

@@ -351,13 +351,13 @@ def get_anchors(
         if len(linear_node.args) > 2:
             bias = [(linear_node, 2, bias_qspec)]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(linear_node, 0)],
             weights=[(linear_node, 1)],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
-        )
+        ), linear_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
@@ -369,16 +369,16 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         matmul_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(matmul_node, 0), (matmul_node, 1)],
             weights=[],
             biases=[],
             output=[(matmul_node,)],
-        )
+        ), matmul_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
@@ -392,16 +392,16 @@ def partition_types(self) -> List[OpOverload]:

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
+    ) -> Tuple[PartitionAnchors, fx.Node]:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         relu_node = fused_partition[0].nodes[-1]

-        return PartitionAnchors(
+        return (PartitionAnchors(
             inputs=[(relu_node, 0)],
             weights=[],
             biases=[],
             output=[(relu_node,)],
-        )
+        ), relu_node)

     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
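
Every pattern now returns the tuple unconditionally, including AddPattern's empty-anchors early return, so call sites can always unpack two values without special-casing. A sketch of what a new pattern would look like under this contract (the op, the class name, and the QuantizationPattern base-class name are illustrative assumptions; PartitionAnchors and the anchor fields come from this file):

from typing import List, Tuple

import torch
import torch.fx as fx
from torch._ops import OpOverload

class TanhPattern(QuantizationPattern):  # hypothetical, not part of this diff
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.tanh.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        tanh_node = fused_partition[0].nodes[-1]
        # The matched node rides along with the anchors, so the fusion pass
        # no longer digs it out of anchors.output[0][0].
        return (PartitionAnchors(
            inputs=[(tanh_node, 0)],
            weights=[],
            biases=[],
            output=[(tanh_node,)],
        ), tanh_node)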

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if not no_outside_users(fused_partition):
                 continue

-            anchors = self.pattern.get_anchors(model, fused_partition)
+            anchors, _ = self.pattern.get_anchors(model, fused_partition)
             if not anchors or anchors.empty:
                 continue
             if is_annotated(
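
On the annotation side only the anchors matter, so the returned op node is discarded here; the fusion pass is the sole consumer. The two call sites from this commit, side by side:

# quantizer.py (annotation): the op node is not needed yet.
anchors, _ = self.pattern.get_anchors(model, fused_partition)

# fusion_pass.py (fusion): the op node anchors the graph rewrite.
anchors, op_node = pattern.get_anchors(graph_module, fused_partition)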
