11
11
from torch .fx import Node
12
12
from torch .fx .passes .infra .pass_base import PassResult
13
13
14
+ # Operator aliases for better readability.
15
+ AddMM = exir_ops .edge .aten .addmm .default
16
+ ViewCopy = exir_ops .edge .aten .view_copy .default
17
+ MM = exir_ops .edge .aten .mm .default
18
+
14
19
15
20
def insert_qdq_pair_after_node (
16
21
graph : torch .fx .Graph , anchor : torch .fx .Node , q_params : tuple
@@ -41,15 +46,17 @@ def insert_qdq_pair_after_node(
41
46
42
47
def _is_dequantize(node_: Node) -> bool:
    """Tell whether `node_` is a call to the per-tensor dequantize edge operator.

    Objects without an `op` attribute are rejected up front, so the helper can be
    given values that are not graph nodes.

    :param node_: Object to inspect.
    :return: Result of comparing the target with
             `quantized_decomposed.dequantize_per_tensor.default` for a
             `call_function` node, False otherwise.
    """
    if not hasattr(node_, "op"):
        return False
    if node_.op != "call_function":
        return False

    dequantize_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    return node_.target == dequantize_op
48
54
49
55
50
56
def _is_quantize(node_: Node) -> bool:
    """Tell whether `node_` is a call to the per-tensor quantize edge operator.

    Objects without an `op` attribute are rejected up front, so the helper can be
    given values that are not graph nodes.

    :param node_: Object to inspect.
    :return: Result of comparing the target with
             `quantized_decomposed.quantize_per_tensor.default` for a
             `call_function` node, False otherwise.
    """
    if not hasattr(node_, "op"):
        return False
    if node_.op != "call_function":
        return False

    quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    return node_.target == quantize_op
@@ -82,20 +89,19 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
82
89
▼
83
90
"""
84
91
85
- allowed_auxiliary_nodes = [exir_ops .edge .aten .view_copy .default ]
86
-
87
- # List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
88
- allowed_main_cluster_nodes = [
89
- exir_ops .edge .aten .addmm .default ,
90
- exir_ops .edge .aten .mm .default ,
91
- ]
92
+ # Maps the target (operator) of each supported main cluster node to the auxiliary operator targets for which this optimization is applied.
93
+ main_cluster_node_to_auxiliary_nodes = {
94
+ AddMM : [
95
+ ViewCopy ,
96
+ ],
97
+ MM : [
98
+ ViewCopy ,
99
+ ],
100
+ }
92
101
93
102
def run (self , graph_module : torch .fx .GraphModule ) -> PassResult :
94
103
for aux_node in graph_module .graph .nodes :
95
- if (
96
- aux_node .op != "call_function"
97
- or aux_node .target not in self .allowed_auxiliary_nodes
98
- ):
104
+ if aux_node .op != "call_function" :
99
105
continue
100
106
101
107
dequantize_node = aux_node .args [0 ]
@@ -109,11 +115,13 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
109
115
continue
110
116
111
117
main_cluster_node = users [0 ]
112
- if (
113
- main_cluster_node .op != "call_function"
114
- or main_cluster_node .target not in self .allowed_main_cluster_nodes
118
+ if main_cluster_node .op != "call_function" :
119
+ continue
120
+
121
+ if aux_node .target not in self .main_cluster_node_to_auxiliary_nodes .get (
122
+ main_cluster_node .target , []
115
123
):
116
- # Unsupported `main_cluster_node` .
124
+ # Unsupported main cluster node and auxiliary node pair .
117
125
continue
118
126
119
127
# Make sure the nodes are part of the same QDQ cluster.
@@ -163,29 +171,33 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
163
171
▼
164
172
"""
165
173
166
- allowed_auxiliary_nodes = [exir_ops .edge .aten .view_copy .default ]
167
-
168
- # List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
169
- allowed_main_cluster_nodes = [
170
- exir_ops .edge .aten .addmm .default ,
171
- exir_ops .edge .aten .mm .default ,
172
- ]
174
+ # Maps the target (operator) of each supported main cluster node to the auxiliary operator targets for which this optimization is applied.
175
+ main_cluster_node_to_auxiliary_nodes = {
176
+ AddMM : [
177
+ ViewCopy ,
178
+ ],
179
+ MM : [
180
+ ViewCopy ,
181
+ ],
182
+ }
173
183
174
184
def run (self , graph_module : torch .fx .GraphModule ) -> PassResult :
175
185
176
186
for aux_node in graph_module .graph .nodes :
177
- if (
178
- aux_node .op != "call_function"
179
- or aux_node .target not in self .allowed_auxiliary_nodes
180
- ):
187
+ if aux_node .op != "call_function" :
181
188
continue
182
189
183
190
main_cluster_node = aux_node .args [0 ]
184
- if (
185
- main_cluster_node .op != "call_function"
186
- or main_cluster_node .target not in self .allowed_main_cluster_nodes
191
+ if not (
192
+ hasattr (main_cluster_node , "op" )
193
+ and main_cluster_node .op == "call_function"
194
+ ):
195
+ continue
196
+
197
+ if aux_node .target not in self .main_cluster_node_to_auxiliary_nodes .get (
198
+ main_cluster_node .target , []
187
199
):
188
- # Unsupported `main_cluster_node` .
200
+ # Unsupported main cluster node and auxiliary node pair .
189
201
continue
190
202
191
203
users = list (aux_node .users .keys ())
0 commit comments