Skip to content

Commit 2a06efb

Browse files
NXP backend: Update the pass move_auxiliary_operator_into_separate_qdq_cluster. (#13841)
### Summary Update the passes in `move_auxiliary_operator_into_separate_qdq_cluster_pass.py` to only transform specific combinations of operators. ### Test plan Tested by existing tests in `test_edge_passes.py`.
1 parent dbac09c commit 2a06efb

File tree

1 file changed

+44
-32
lines changed

1 file changed

+44
-32
lines changed

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
from torch.fx import Node
1212
from torch.fx.passes.infra.pass_base import PassResult
1313

14+
# Operator aliases for better readability.
15+
AddMM = exir_ops.edge.aten.addmm.default
16+
ViewCopy = exir_ops.edge.aten.view_copy.default
17+
MM = exir_ops.edge.aten.mm.default
18+
1419

1520
def insert_qdq_pair_after_node(
1621
graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple
@@ -41,15 +46,17 @@ def insert_qdq_pair_after_node(
4146

4247
def _is_dequantize(node_: Node) -> bool:
4348
return (
44-
node_.op == "call_function"
49+
hasattr(node_, "op")
50+
and node_.op == "call_function"
4551
and node_.target
4652
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
4753
)
4854

4955

5056
def _is_quantize(node_: Node) -> bool:
5157
return (
52-
node_.op == "call_function"
58+
hasattr(node_, "op")
59+
and node_.op == "call_function"
5360
and node_.target
5461
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
5562
)
@@ -82,20 +89,19 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
8289
8390
"""
8491

85-
allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]
86-
87-
# List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
88-
allowed_main_cluster_nodes = [
89-
exir_ops.edge.aten.addmm.default,
90-
exir_ops.edge.aten.mm.default,
91-
]
92+
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
93+
main_cluster_node_to_auxiliary_nodes = {
94+
AddMM: [
95+
ViewCopy,
96+
],
97+
MM: [
98+
ViewCopy,
99+
],
100+
}
92101

93102
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
94103
for aux_node in graph_module.graph.nodes:
95-
if (
96-
aux_node.op != "call_function"
97-
or aux_node.target not in self.allowed_auxiliary_nodes
98-
):
104+
if aux_node.op != "call_function":
99105
continue
100106

101107
dequantize_node = aux_node.args[0]
@@ -109,11 +115,13 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
109115
continue
110116

111117
main_cluster_node = users[0]
112-
if (
113-
main_cluster_node.op != "call_function"
114-
or main_cluster_node.target not in self.allowed_main_cluster_nodes
118+
if main_cluster_node.op != "call_function":
119+
continue
120+
121+
if aux_node.target not in self.main_cluster_node_to_auxiliary_nodes.get(
122+
main_cluster_node.target, []
115123
):
116-
# Unsupported `main_cluster_node`.
124+
# Unsupported main cluster node and auxiliary node pair.
117125
continue
118126

119127
# Make sure the nodes are part of the same QDQ cluster.
@@ -163,29 +171,33 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
163171
164172
"""
165173

166-
allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]
167-
168-
# List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
169-
allowed_main_cluster_nodes = [
170-
exir_ops.edge.aten.addmm.default,
171-
exir_ops.edge.aten.mm.default,
172-
]
174+
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
175+
main_cluster_node_to_auxiliary_nodes = {
176+
AddMM: [
177+
ViewCopy,
178+
],
179+
MM: [
180+
ViewCopy,
181+
],
182+
}
173183

174184
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
175185

176186
for aux_node in graph_module.graph.nodes:
177-
if (
178-
aux_node.op != "call_function"
179-
or aux_node.target not in self.allowed_auxiliary_nodes
180-
):
187+
if aux_node.op != "call_function":
181188
continue
182189

183190
main_cluster_node = aux_node.args[0]
184-
if (
185-
main_cluster_node.op != "call_function"
186-
or main_cluster_node.target not in self.allowed_main_cluster_nodes
191+
if not (
192+
hasattr(main_cluster_node, "op")
193+
and main_cluster_node.op == "call_function"
194+
):
195+
continue
196+
197+
if aux_node.target not in self.main_cluster_node_to_auxiliary_nodes.get(
198+
main_cluster_node.target, []
187199
):
188-
# Unsupported `main_cluster_node`.
200+
# Unsupported main cluster node and auxiliary node pair.
189201
continue
190202

191203
users = list(aux_node.users.keys())

0 commit comments

Comments
 (0)