@@ -87,3 +87,163 @@ def fuse_conv_bn(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
                 node.replace_all_uses_with(node.args[0])
                 new_graph.erase_node(node)
     return fx.GraphModule(fx_model, new_graph)
+
+class Conv2dReLU(torch.nn.Module):
+    def __init__(self,
+                 weight,
+                 bias,
+                 stride,
+                 padding,
+                 dilation,
+                 groups):
+        super(Conv2dReLU, self).__init__()
+        self.weight = weight
+        self.weight_is_channels_last = False
+        self.bias = bias
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        # 7x7 kernels (e.g. a ResNet stem conv) take the contiguous-format
+        # path in forward(); everything else runs in channels_last.
+        self.slow_fusion = False
+        if self.weight.size(2) == 7 and self.weight.size(3) == 7:
+            self.slow_fusion = True
+
+    def forward(self, inp):
+        # NOTE: This will be faster once https://github.com/pytorch/pytorch/pull/62482 lands
+        if not self.slow_fusion and inp.is_contiguous(memory_format=torch.contiguous_format):
+            inp = inp.to(memory_format=torch.channels_last)
+        if self.slow_fusion and inp.is_contiguous(memory_format=torch.channels_last):
+            inp = inp.to(memory_format=torch.contiguous_format)
+        # Convert the weight to channels_last once, on first use.
+        if not self.slow_fusion and not self.weight_is_channels_last:
+            self.weight.data = self.weight.to(memory_format=torch.channels_last)
+            inp = inp.to(memory_format=torch.channels_last)
+            self.weight_is_channels_last = True
+        return torch.cudnn_convolution_relu(inp,
+                                            self.weight,
+                                            self.bias,
+                                            self.stride,
+                                            self.padding,
+                                            self.dilation,
+                                            self.groups)
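+
+# Memory-format note (illustrative, not part of the pass): cudnn's fused
+# kernels are fed channels_last inputs on the fast path above, and a typical
+# dense 4D tensor is contiguous in only one of the two layouts checked, e.g.
+#
+#     x = torch.randn(8, 64, 56, 56)
+#     x.is_contiguous(memory_format=torch.contiguous_format)  # True
+#     x = x.to(memory_format=torch.channels_last)
+#     x.is_contiguous(memory_format=torch.channels_last)      # True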
+
+class Conv2dAddReLU(torch.nn.Module):
+    def __init__(self,
+                 weight,
+                 bias,
+                 stride,
+                 padding,
+                 dilation,
+                 groups):
+        super(Conv2dAddReLU, self).__init__()
+        self.weight = weight
+        self.weight_is_channels_last = False
+        self.bias = bias
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.slow_fusion = False
+        if self.weight.size(2) == 7 and self.weight.size(3) == 7:
+            self.slow_fusion = True
+
+    def forward(self, inp, add_input):
+        # TODO: Reactivate this once cudnn_convolution_add_relu is fixed.
+        # weight = self.weight.to(memory_format=torch.contiguous_format)
+        # if not self.slow_fusion and inp.is_contiguous(memory_format=torch.contiguous_format):
+        #     inp = inp.to(memory_format=torch.channels_last)
+        #     add_input = add_input.to(memory_format=torch.channels_last)
+        # if self.slow_fusion and inp.is_contiguous(memory_format=torch.channels_last):
+        #     inp = inp.to(memory_format=torch.contiguous_format)
+        #     add_input = add_input.to(memory_format=torch.contiguous_format)
+        # if not self.slow_fusion and not self.weight_is_channels_last:
+        #     self.weight.data = self.weight.to(memory_format=torch.channels_last)
+        #     inp = inp.to(memory_format=torch.channels_last)
+        #     add_input = add_input.to(memory_format=torch.channels_last)
+        #     self.weight_is_channels_last = True
+        # return torch.cudnn_convolution_add_relu(inp,
+        #                                         self.weight,
+        #                                         add_input,
+        #                                         1.0,
+        #                                         self.bias,
+        #                                         self.stride,
+        #                                         self.padding,
+        #                                         self.dilation,
+        #                                         self.groups)
+        # Unfused fallback: conv, then in-place add and ReLU.
+        out = torch.conv2d(inp,
+                           self.weight,
+                           self.bias,
+                           self.stride,
+                           self.padding,
+                           self.dilation,
+                           self.groups)
+        out.add_(add_input)
+        out.relu_()
+        return out
+
+def fuse_conv_relu(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    """
+    Fuses Conv2d/ReLU and Conv2d/add/ReLU patterns for inference purposes.
+    Will deepcopy your model by default, but can modify the model inplace
+    as well.
+    """
+    patterns = [(torch.nn.Conv2d, torch.nn.ReLU)]
+    if not inplace:
+        model = copy.deepcopy(model)
+    fx_model = fx.symbolic_trace(model)
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    # First pass: replace every exclusive Conv2d -> ReLU pair with a single
+    # Conv2dReLU module.
+    for pattern in patterns:
+        for node in new_graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
+                    continue
+                conv = modules[node.args[0].target]
+                relu = modules[node.target]
+                fused_conv = Conv2dReLU(conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)
+                replace_node_module(node.args[0], modules, fused_conv)
+                node.replace_all_uses_with(node.args[0])
+                new_graph.erase_node(node)
+
+    # Second pass: match conv -> add -> ReLU over a sliding window of the
+    # last three call nodes.
+    last_nodes = []
+    count = 0
+    for node in new_graph.nodes:
+        if count == 31:  # Hardcoded cap on the number of fusions performed.
+            break
+        if node.op == "call_function" or node.op == "call_module":
+            last_nodes.append(node)
+        if len(last_nodes) == 4:
+            last_nodes = last_nodes[1:]
+        if len(last_nodes) < 3:
+            continue
+        is_match = True
+        is_match = is_match and (last_nodes[0].op == "call_module")
+        is_match = is_match and (last_nodes[1].op == "call_function")
+        is_match = is_match and (last_nodes[2].op == "call_module")
+        is_match = is_match and isinstance(modules[last_nodes[0].target], torch.nn.Conv2d)
+        is_match = is_match and (str(last_nodes[1]).split("_")[0] == "add")
+        is_match = is_match and isinstance(modules[last_nodes[2].target], torch.nn.ReLU)
+        if is_match:
+            conv = modules[last_nodes[1].args[0].target]
+            fused_conv = Conv2dAddReLU(conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)
+            replace_node_module(last_nodes[2], modules, fused_conv)
+            # Rewire the ReLU node to call Conv2dAddReLU(conv_input, add_input).
+            last_nodes[2].args = (last_nodes[0].args[0], last_nodes[1].args[1])
+            new_graph.erase_node(last_nodes[1])
+            new_graph.erase_node(last_nodes[0])
+            count += 1
+    return fx.GraphModule(fx_model, new_graph)
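+
+# Example (sketch; the torchvision model below is illustrative, not part of
+# this PR):
+#
+#     model = torchvision.models.resnet50().cuda().eval()
+#     model = fuse_conv_bn(model)
+#     model = fuse_conv_relu(model)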
+
+
+def fuse_conv_add_relu(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    """
+    Placeholder for a standalone Conv2d/add/ReLU fusion pass. Will deepcopy
+    your model by default, but can modify the model inplace as well. The
+    conv/add/ReLU rewrite itself currently happens in fuse_conv_relu above,
+    so this function only re-traces the model and returns it unchanged.
+    """
+    if not inplace:
+        model = copy.deepcopy(model)
+    fx_model = fx.symbolic_trace(model)
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    new_graph.lint()
+    return fx.GraphModule(fx_model, new_graph)
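
A minimal smoke test for the pass might look like the following. This is a sketch, not part of the diff: it assumes a CUDA build with cudnn, that torchvision is installed, and that the file above is importable as fuse_utils (a hypothetical module name); the model choice and tolerances are placeholders.

    import torch
    import torchvision
    from fuse_utils import fuse_conv_bn, fuse_conv_relu  # hypothetical module name

    model = torchvision.models.resnet18().cuda().eval()
    inp = torch.randn(8, 3, 224, 224, device="cuda")

    with torch.no_grad():
        ref = model(inp)
        fused = fuse_conv_relu(fuse_conv_bn(model))
        out = fused(inp)

    # The fused cudnn kernels may not be bit-identical to the eager ops,
    # so compare with a loose tolerance.
    print(torch.allclose(ref, out, atol=1e-4, rtol=1e-4))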