This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 4cc2a37

cpuhrsch authored and facebook-github-bot committed

NestedTensor import 20210805

Summary: Import from GH

Reviewed By: mthrok

Differential Revision: D30133587

fbshipit-source-id: 6b054a74bf05a13235b4c6dbb8e464ea7f595518
1 parent d00fd7c commit 4cc2a37

File tree

5 files changed: 177 additions, 12 deletions


nestedtensor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 from .nested.nested import transpose_nhwc_nchw
 
 from .nested.fuser import fuse_conv_bn
+from .nested.fuser import fuse_conv_relu
+from .nested.fuser import fuse_conv_add_relu
 
 from . import nested
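
The two new exports are the FX fusion passes added in nestedtensor/nested/fuser.py below. A minimal usage sketch (not part of the commit; the toy model, shapes, and tolerance are illustrative assumptions, and the fused forward path requires CUDA/cuDNN):

import torch
import nestedtensor

# A small eval-mode model containing the Conv2d -> BatchNorm2d -> ReLU chain both passes target.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()

fused = nestedtensor.fuse_conv_bn(model)    # fold BatchNorm statistics into the conv weights
fused = nestedtensor.fuse_conv_relu(fused)  # rewrite Conv2d -> ReLU into a single Conv2dReLU module

if torch.cuda.is_available():
    data = torch.randn(2, 3, 32, 32, device="cuda")
    ref = model.cuda()(data)
    out = fused.cuda()(data)
    print(torch.allclose(ref, out, atol=1e-4))  # expect True up to numerical tolerance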

nestedtensor/csrc/UnaryOps.cpp

Lines changed: 5 additions & 5 deletions
@@ -103,14 +103,14 @@ Tensor NestedTensor_clamp_max(const Tensor& self, const c10::Scalar& min) {
 
 Tensor& NestedTensor_clamp_max_out(
     const Tensor& self,
-    const Scalar& min,
+    const Scalar& max,
     Tensor& result) {
   apply_nested_tensor(
-      [min](const Tensor self, Tensor result) {
-        at::native::clamp_max_out(self, min, result);
+      [max](Tensor result, const Tensor tensor) {
+        at::clamp_max_out(result, tensor, max);
       },
-      self,
-      result);
+      result,
+      self);
   return result;
 }
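
This hunk fixes the out variant: the scalar argument is a maximum rather than a minimum, the lambda passed to apply_nested_tensor now takes the output constituent first, and the dispatcher-level at::clamp_max_out is called instead of at::native::clamp_max_out. As a rough illustration of the intended per-constituent semantics (an assumption for clarity, written against plain torch tensors, not code from the repo):

import torch

# Stand-ins for the constituents of a nested tensor and of the preallocated result.
constituents = [torch.randn(3, 4), torch.randn(2, 5)]
results = [torch.empty_like(t) for t in constituents]

for t, r in zip(constituents, results):
    # Mirrors at::clamp_max_out(result, tensor, max) applied to each constituent pair.
    torch.clamp_max(t, 0.5, out=r)

print(all(r.max().item() <= 0.5 for r in results))  # expect True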

nestedtensor/nested/fuser.py

Lines changed: 160 additions & 0 deletions
@@ -87,3 +87,163 @@ def fuse_conv_bn(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
                 node.replace_all_uses_with(node.args[0])
                 new_graph.erase_node(node)
     return fx.GraphModule(fx_model, new_graph)
+
+class Conv2dReLU(torch.nn.Module):
+    def __init__(self,
+                 weight,
+                 bias,
+                 stride,
+                 padding,
+                 dilation,
+                 groups):
+        super(Conv2dReLU, self).__init__()
+        self.weight = weight
+        self.weight_is_channels_last = False
+        self.bias = bias
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.slow_fusion = False
+        if self.weight.size(2) == 7 and self.weight.size(3) == 7:
+            self.slow_fusion = True
+
+    def forward(self, inp):
+        # NOTE: This will be faster once https://github.com/pytorch/pytorch/pull/62482 lands
+        if not self.slow_fusion and inp.is_contiguous(memory_format=torch.contiguous_format):
+            inp = inp.to(memory_format=torch.channels_last)
+        if self.slow_fusion and inp.is_contiguous(memory_format=torch.channels_last):
+            inp = inp.to(memory_format=torch.contiguous_format)
+        if not self.slow_fusion and not self.weight_is_channels_last:
+            self.weight.data = self.weight.to(memory_format=torch.channels_last)
+            inp = inp.to(memory_format=torch.channels_last)
+            self.weight_is_channels_last = True
+        return torch.cudnn_convolution_relu(inp,
+                                            self.weight,
+                                            self.bias,
+                                            self.stride,
+                                            self.padding,
+                                            self.dilation,
+                                            self.groups)
+
+class Conv2dAddReLU(torch.nn.Module):
+    def __init__(self,
+                 weight,
+                 bias,
+                 stride,
+                 padding,
+                 dilation,
+                 groups):
+        super(Conv2dAddReLU, self).__init__()
+        self.weight = weight
+        self.weight_is_channels_last = False
+        self.bias = bias
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.slow_fusion = False
+        if self.weight.size(2) == 7 and self.weight.size(3) == 7:
+            self.slow_fusion = True
+
+    def forward(self, inp, add_input):
+        # TODO: Reactivate this once cudnn_convolution_add_relu is fixed.
+        # weight = self.weight.to(memory_format=torch.contiguous_format)
+        # if not self.slow_fusion and inp.is_contiguous(memory_format=torch.contiguous_format):
+        #     inp = inp.to(memory_format=torch.channels_last)
+        #     add_input = add_input.to(memory_format=torch.channels_last)
+        # if self.slow_fusion and inp.is_contiguous(memory_format=torch.channels_last):
+        #     inp = inp.to(memory_format=torch.contiguous_format)
+        #     add_input = add_input.to(memory_format=torch.contiguous_format)
+        # if not self.slow_fusion and not self.weight_is_channels_last:
+        #     self.weight.data = self.weight.to(memory_format=torch.channels_last)
+        #     inp = inp.to(memory_format=torch.channels_last)
+        #     add_input = add_input.to(memory_format=torch.channels_last)
+        #     self.weight_is_channels_last = True
+        # return torch.cudnn_convolution_add_relu(inp,
+        #                                         self.weight,
+        #                                         add_input,
+        #                                         1.0,
+        #                                         self.bias,
+        #                                         self.stride,
+        #                                         self.padding,
+        #                                         self.dilation,
+        #                                         self.groups)
+        out = torch.conv2d(inp,
+                           self.weight,
+                           self.bias,
+                           self.stride,
+                           self.padding,
+                           self.dilation,
+                           self.groups)
+        out.add_(add_input)
+        out.relu_()
+        return out
+
+def fuse_conv_relu(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    """
+    Fuses convolution/BN layers for inference purposes. Will deepcopy your
+    model by default, but can modify the model inplace as well.
+    """
+    patterns = [(torch.nn.Conv2d, torch.nn.ReLU)]
+    if not inplace:
+        model = copy.deepcopy(model)
+    fx_model = fx.symbolic_trace(model)
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    for pattern in patterns:
+        for node in new_graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
+                    continue
+                conv = modules[node.args[0].target]
+                relu = modules[node.target]
+                fused_conv = Conv2dReLU(conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)
+                replace_node_module(node.args[0], modules, fused_conv)
+                node.replace_all_uses_with(node.args[0])
+                new_graph.erase_node(node)
+
+
+    last_nodes = []
+    count = 0
+    for node in new_graph.nodes:
+        if count == 31:
+            break
+        if (node.op == "call_function" or node.op == "call_module"):
+            last_nodes.append(node)
+        if len(last_nodes) == 4:
+            last_nodes = last_nodes[1:]
+        if len(last_nodes) < 3:
+            continue
+        is_match = True
+        is_match = is_match and (last_nodes[0].op == "call_module")
+        is_match = is_match and (last_nodes[1].op == "call_function")
+        is_match = is_match and (last_nodes[2].op == "call_module")
+        is_match = is_match and isinstance(modules[last_nodes[0].target], torch.nn.Conv2d)
+        is_match = is_match and (str(last_nodes[1]).split("_")[0] == "add")
+        is_match = is_match and isinstance(modules[last_nodes[2].target], torch.nn.ReLU)
+        if (is_match):
+            conv = modules[last_nodes[1].args[0].target]
+            fused_conv = Conv2dAddReLU(conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)
+            replace_node_module(last_nodes[2], modules, fused_conv)
+            last_nodes[2].args = (last_nodes[0].args[0], last_nodes[1].args[1])
+            new_graph.erase_node(last_nodes[1])
+            new_graph.erase_node(last_nodes[0])
+            count += 1
+    return fx.GraphModule(fx_model, new_graph)
+
+
+def fuse_conv_add_relu(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    """
+    Fuses convolution/BN layers for inference purposes. Will deepcopy your
+    model by default, but can modify the model inplace as well.
+    """
+    if not inplace:
+        model = copy.deepcopy(model)
+    fx_model = fx.symbolic_trace(model)
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    new_graph.lint()
+    return fx.GraphModule(fx_model, new_graph)
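
fuse_conv_relu has two stages: a module-pattern pass that replaces Conv2d -> ReLU pairs with Conv2dReLU, and a sliding three-node window that looks for a call_module (Conv2d), a call_function add, and a call_module (ReLU) in sequence, replacing them with Conv2dAddReLU (capped at 31 rewrites). A minimal sketch of the kind of residual block that second matcher targets (a hypothetical toy module, not code from the repo):

import torch

class TinyResidual(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Traces to call_module (conv) -> call_function (add) -> call_module (relu),
        # exactly the three-node window Conv2dAddReLU is substituted for.
        return self.relu(self.conv(x) + x)

Under these assumptions, nestedtensor.fuse_conv_relu(TinyResidual().eval()) would rewire the ReLU node into a Conv2dAddReLU call on the conv weights, taking (x, x) as its inputs, and erase the original conv and add nodes. Note that fuse_conv_add_relu itself currently only lints and returns the retraced graph; the Conv+Add+ReLU rewrite lives in fuse_conv_relu.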

nestedtensor/version.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-__version__ = '0.1.4+66764fd'
-git_version = '66764fd10e9b6f9c0710840d0cb17369b9d994be'
+__version__ = '0.1.4+da883d9'
+git_version = 'da883d94a7cb250db7ec7d6d152764e6e8e8788a'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION

test/test_nested_tensor_integration.py

Lines changed: 8 additions & 5 deletions
@@ -2,6 +2,7 @@
 import nestedtensor
 import unittest
 from utils_test_case import TestCase
+from utils import debug_on
 
 try:
     import classy_vision
@@ -194,22 +195,24 @@ def _test(dtype, use_channels_last):
             from torch.fx import symbolic_trace
             model = build_model({"name": "resnext101_32x4d"}).eval().cuda()
             model._initialize_weights(False)
-            fused = symbolic_trace(model)
-            fused = nestedtensor.fuse_conv_bn(fused)
+            # This is needed to allow tracing, but for makes no difference for resnext
+            model = model.classy_model
+            fused = nestedtensor.fuse_conv_bn(model)
+            fused = nestedtensor.fuse_conv_relu(fused)
             model = model.to(dtype)
             fused = fused.to(dtype)
             data = torch.randn(2, 3, 50, 50, device=torch.device('cuda'), dtype=dtype)
+            ref_output = model(data)
             if use_channels_last:
                 data = data.contiguous(memory_format=torch.channels_last)
-            ref_output = model(data)
             new_output = fused(data)
             if dtype == torch.float16:
                 self.assertEqual(ref_output, new_output, prec=2e-3)
             else:
                 self.assertEqual(ref_output, new_output)
-        _test(torch.float16, False)
         _test(torch.float32, False)
-        # _test(torch.float16, True)
+        _test(torch.float16, False)
+        _test(torch.float16, True)
         _test(torch.float32, True)