28 changes: 21 additions & 7 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -390,7 +390,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
pattern.partition_types(),
)
for fused_partition in fused_partitions:
anchors = pattern.get_anchors(graph_module, fused_partition)
anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
if not anchors or anchors.empty:
continue
if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -431,11 +431,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
bias_inputs = [node.args[0] for node in dequants_biases]
other_inputs = [node.args[idx] for node, idx in anchors.others]

# The node is the first index of the list and first of the tuple
op_node = anchors.output[0][0]
# Check if there's a quantization node after the operation
quant_node = None
has_quant_node = False

assert len(op_node.users) == 1
quant_node = list(op_node.users.keys())[0]
if len(op_node.users) == 1:
potential_quant = list(op_node.users.keys())[0]
# Check if it's actually a quantization node
if (hasattr(potential_quant, 'target') and
potential_quant.target == torch.ops.quantized_decomposed.quantize_per_tensor.default):
quant_node = potential_quant
has_quant_node = True

with graph_module.graph.inserting_after(op_node):
args = tuple(
Expand Down Expand Up @@ -516,8 +522,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
args,
kwargs,
)
fused.meta = quant_node.meta
quant_node.replace_all_uses_with(fused)
if has_quant_node:
fused.meta = quant_node.meta
quant_node.replace_all_uses_with(fused)
if quant_node.op == "output":
_ = graph_module.graph.output((fused,))
else:
fused.meta = op_node.meta
op_node.replace_all_uses_with(fused)
if op_node.op == "output":
_ = graph_module.graph.output((fused,))

legalize_graph(graph_module)
graph_module.graph.eliminate_dead_code()
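The new branch in fusion_pass.py decides whether the fused node should take over from a trailing quantize node or, when no quantize follows, from the anchor op itself. Below is a hedged, self-contained sketch of that decision using plain-Python stand-ins for fx nodes; FakeNode, pick_replacement_source, and the string QUANT_TARGET are hypothetical illustration names, and the real pass compares against torch.ops.quantized_decomposed.quantize_per_tensor.default.

```python
# Hedged sketch of the new check in the fusion pass: does the anchor op have
# exactly one user, and is that user a quantize_per_tensor node? Plain-Python
# stand-ins are used so the example runs without torch; names are illustrative.

# In the real pass the comparison target is
# torch.ops.quantized_decomposed.quantize_per_tensor.default.
QUANT_TARGET = "quantized_decomposed.quantize_per_tensor.default"


class FakeNode:
    """Minimal stand-in for torch.fx.Node (only the fields the check touches)."""

    def __init__(self, target=None, users=None, op="call_function"):
        self.target = target
        self.users = users or {}  # fx stores a node's users as dict keys
        self.op = op
        self.meta = {}


def pick_replacement_source(op_node):
    """Return (node_the_fused_op_replaces, has_quant_node)."""
    if len(op_node.users) == 1:
        user = next(iter(op_node.users))
        if getattr(user, "target", None) == QUANT_TARGET:
            return user, True
    # No trailing quantize node: the fused op replaces the anchor op itself.
    return op_node, False


quant = FakeNode(target=QUANT_TARGET)
conv = FakeNode(users={quant: None})
assert pick_replacement_source(conv) == (quant, True)

bare = FakeNode(users={FakeNode(): None})
assert pick_replacement_source(bare) == (bare, False)
```

When a quantize user is found, the fused node inherits that node's meta and replaces its uses; otherwise it inherits from and replaces the anchor op. In either branch, if the replaced node is the graph's output node, the graph output is re-emitted with the fused node, which is why both branches in the diff end with the same replace/output handling.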
66 changes: 33 additions & 33 deletions backends/cadence/aot/quantizer/patterns.py
@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
@abstractmethod
def get_anchors(
self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Optional[PartitionAnchors]:
) -> Tuple[PartitionAnchors, fx.Node]:
pass

@abstractmethod
Expand All @@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
addmm_node = fused_partition[0].nodes[-1]

@@ -101,12 +101,12 @@ def get_anchors(
qscheme=torch.per_tensor_affine,
)

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(addmm_node, 1)],
weights=[(addmm_node, 2)],
biases=[(addmm_node, 0, bias_qspec)],
output=[(addmm_node,)],
)
), addmm_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default
@@ -118,7 +118,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
add_node = fused_partition[0].nodes[-1]

@@ -129,16 +129,16 @@ def get_anchors(
add_node.args[1], fx.Node
)
if not is_tensor_add or len(add_node.kwargs) > 0:
return PartitionAnchors(
return (PartitionAnchors(
empty=True,
)
), add_node)

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(add_node, 0), (add_node, 1)],
weights=[],
biases=[],
output=[(add_node,)],
)
), add_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_add.default
@@ -150,16 +150,16 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
bmm_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(bmm_node, 0), (bmm_node, 1)],
weights=[],
biases=[],
output=[(bmm_node,)],
)
), bmm_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default
@@ -171,7 +171,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
cat_node = fused_partition[0].nodes[-1]

@@ -198,14 +198,14 @@ def get_anchors(
)
)

return PartitionAnchors(
return (PartitionAnchors(
inputs=args,
weights=[],
biases=[],
output=[
(cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
],
)
), cat_node)

def replacement_op(self) -> OpOverload:
return torch.ops.aten.cat.default
@@ -217,7 +217,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv1d_node = fused_partition[0].nodes[-1]

@@ -238,13 +238,13 @@ def get_anchors(
if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
bias = [(conv1d_node, 2, bias_qspec)]

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(conv1d_node, 0)],
weights=[(conv1d_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(conv1d_node,)],
)
), conv1d_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv_nchw.default
@@ -256,7 +256,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv2d_node = fused_partition[0].nodes[-1]

@@ -277,13 +277,13 @@ def get_anchors(
if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
bias = [(conv2d_node, 2, bias_qspec)]

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(conv2d_node, 0)],
weights=[(conv2d_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(conv2d_node,)],
)
), conv2d_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv_nchw.default
@@ -295,7 +295,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
layer_norm_node = fused_partition[0].nodes[-1]

@@ -311,14 +311,14 @@ def get_anchors(

# Weights are used in quantized mode by our kernel, so they are
# passed in as others here along with the normalized shape.
return PartitionAnchors(
return (PartitionAnchors(
inputs=[(layer_norm_node, 0)],
weights=[],
biases=[],
# Ordering: normalized_shape, weights, bias
others=others,
output=[(layer_norm_node,)],
)
), layer_norm_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_layer_norm.default
@@ -330,7 +330,7 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
linear_node = fused_partition[0].nodes[-1]

@@ -351,13 +351,13 @@ def get_anchors(
if len(linear_node.args) > 2:
bias = [(linear_node, 2, bias_qspec)]

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(linear_node, 0)],
weights=[(linear_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(linear_node,)],
)
), linear_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default
@@ -369,16 +369,16 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
matmul_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(matmul_node, 0), (matmul_node, 1)],
weights=[],
biases=[],
output=[(matmul_node,)],
)
), matmul_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_matmul.default
@@ -392,16 +392,16 @@ def partition_types(self) -> List[OpOverload]:

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
) -> Tuple[PartitionAnchors, fx.Node]:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
relu_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
return (PartitionAnchors(
inputs=[(relu_node, 0)],
weights=[],
biases=[],
output=[(relu_node,)],
)
), relu_node)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_relu.default
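Across patterns.py, every get_anchors now returns a (PartitionAnchors, fx.Node) pair instead of a bare PartitionAnchors, so callers receive the anchor op directly rather than digging it out of anchors.output[0][0]. The stub below is a minimal, hypothetical illustration of that contract; PartitionAnchorsStub and get_anchors_stub are made-up names that mirror, not reproduce, the real dataclass and pattern methods.

```python
# Minimal stub of the new get_anchors contract: the anchor op node rides along
# with the anchors. Stub names are hypothetical and only mirror the real API.
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class PartitionAnchorsStub:
    inputs: List[Any] = field(default_factory=list)
    weights: List[Any] = field(default_factory=list)
    biases: List[Any] = field(default_factory=list)
    output: List[Any] = field(default_factory=list)
    others: List[Any] = field(default_factory=list)
    empty: bool = False


def get_anchors_stub(last_node: Any) -> Tuple[PartitionAnchorsStub, Any]:
    # Same shape as the relu pattern above: the anchors plus the op node itself.
    return (
        PartitionAnchorsStub(inputs=[(last_node, 0)], output=[(last_node,)]),
        last_node,
    )


# Callers unpack the pair; those that only need the anchors discard the node.
anchors, op_node = get_anchors_stub("relu_node")
assert op_node == "relu_node" and not anchors.empty
```

Call sites that only need the anchors, like the annotate pass in quantizer.py below, simply discard the second element of the pair.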
2 changes: 1 addition & 1 deletion backends/cadence/aot/quantizer/quantizer.py
@@ -112,7 +112,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
if not no_outside_users(fused_partition):
continue

anchors = self.pattern.get_anchors(model, fused_partition)
anchors, _ = self.pattern.get_anchors(model, fused_partition)
if not anchors or anchors.empty:
continue
if is_annotated(