This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 9b6dcbc

cpuhrsch authored and facebook-github-bot committed
NestedTensor import 20210730
Summary: Code import

Reviewed By: janeyx99

Differential Revision: D30016452

fbshipit-source-id: 211add77274159d590343e05920211b2a886b58f
1 parent 423a9e9 commit 9b6dcbc

File tree

8 files changed: +139 −24 lines changed


nestedtensor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@
 from .nested.nested import transpose_nchw_nhwc
 from .nested.nested import transpose_nhwc_nchw
 
+from .nested.fuser import fuse_conv_bn
+
 from . import nested
 
 from . import _C

nestedtensor/csrc/UnaryOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ Tensor& NestedTensor_clamp_out(
     Tensor& result) {
   apply_nested_tensor(
       [min, max](const at::Tensor self, at::Tensor result) {
-        at::native::clamp_out(self, min, max, result);
+        at::clamp_out(result, self, min, max);
      },
      self,
      result);
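The fix switches from the deprecated at::native::clamp_out(self, min, max, result) call to the dispatcher's out-variant, which takes the result tensor first. A rough Python-level sketch of the same per-constituent semantics (the tensor names are illustrative assumptions, not part of the commit):

import torch

# Clamp each constituent into a pre-allocated output buffer; torch.clamp's
# out= keyword mirrors the ATen out-variant called by the lambda above.
t0, t1 = torch.randn(3), torch.randn(5)
out0, out1 = torch.empty_like(t0), torch.empty_like(t1)
for src, dst in ((t0, out0), (t1, out1)):
    torch.clamp(src, min=-1.0, max=1.0, out=dst)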

nestedtensor/csrc/conv2d.cpp

Lines changed: 5 additions & 3 deletions
@@ -28,7 +28,9 @@ Tensor NestedTensor_conv2d(
   TORCH_CHECK(get_dim(input) == 4, "Expected input to be dim 4, but got ", get_dim(input), ".");
 #ifdef WITH_CUDA
   auto self_opt_sizes = get_opt_sizes(input);
-  if (is_nested_tensor_impl(input) && !is_nested_tensor_impl(weight) && input.dtype() == torch::kFloat16) {
+  if (is_nested_tensor_impl(input) &&
+      !is_nested_tensor_impl(weight) &&
+      (input.dtype() == torch::kFloat16 || input.dtype() == torch::kFloat32)) {
     if (get_dim(input) == 4 && !bias && weight.size(2) == 1 && weight.size(3) == 1 &&
         stride[0] == 1 && stride[1] == 1 &&
         padding[0] == 0 && padding[1] == 0 &&
@@ -38,7 +40,7 @@ Tensor NestedTensor_conv2d(
         *self_opt_sizes[1] &&
         get_is_cuda(input)
        ) {
-      if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast) && input.dtype() == torch::kHalf) {
+      if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
         Tensor input_buffer = get_buffer(input);
         input_buffer = input_buffer.view({-1, weight.size(1)});
         at::Tensor result_buffer = at::matmul(input_buffer,
@@ -56,7 +58,7 @@ Tensor NestedTensor_conv2d(
         }, new_sizes);
         return wrap_buffer(result_buffer.view(-1), new_sizes, new_strides);
       }
-      if (get_is_contiguous(input) && input.dtype() == torch::kHalf) {
+      if (get_is_contiguous(input)) {
         input = transpose_nchw_nhwc(input);
         Tensor input_buffer = get_buffer(input);
         input_buffer = input_buffer.reshape({-1, weight.size(1)});
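This change lets the 1x1 fast path run for float32 as well as float16 inputs. The path itself relies on the fact that a 1x1 convolution with stride 1, no padding and no bias is a matmul over the channel dimension; a small dense-tensor sketch of that equivalence (the shapes are illustrative assumptions, not taken from the commit):

import torch

x = torch.randn(2, 8, 5, 5)                  # N, C_in, H, W
w = torch.randn(16, 8, 1, 1)                 # C_out, C_in, 1, 1
ref = torch.nn.functional.conv2d(x, w)       # 1x1 conv, stride 1, no padding, no bias
via_matmul = torch.matmul(
    x.permute(0, 2, 3, 1).reshape(-1, 8),    # NHWC, flattened to (N*H*W, C_in)
    w.view(16, 8).t(),                       # (C_in, C_out)
).reshape(2, 5, 5, 16).permute(0, 3, 1, 2)   # back to NCHW
assert torch.allclose(ref, via_matmul, atol=1e-4)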

nestedtensor/csrc/mha.cpp

Lines changed: 4 additions & 4 deletions
@@ -55,14 +55,14 @@ at::Tensor min_mha(
 
   q = q * torch::tensor(scaling);
 
-  q = q.reshape({-1, -1, num_heads, head_dim}).transpose(1, 2);
-  k = k.reshape({-1, -1, num_heads, head_dim}).transpose(1, 2);
-  v = v.reshape({-1, -1, num_heads, head_dim}).transpose(1, 2);
+  q = q.reshape({*opt_sizes[0], -1, num_heads, head_dim}).transpose(1, 2);
+  k = k.reshape({*opt_sizes[0], -1, num_heads, head_dim}).transpose(1, 2);
+  v = v.reshape({*opt_sizes[0], -1, num_heads, head_dim}).transpose(1, 2);
   auto attn_output_weights = at::matmul(q, k.transpose(2, 3));
   attn_output_weights = at::softmax(attn_output_weights, -1);
   attn_output_weights = at::dropout(attn_output_weights, dropout_p, training);
   auto attn_output = at::matmul(attn_output_weights, v);
-  attn_output = attn_output.transpose(1, 2).reshape({-1, -1, edim});
+  attn_output = attn_output.transpose(1, 2).reshape({*opt_sizes[0], -1, edim});
   attn_output = at::matmul(attn_output, out_proj_weight.t());
   attn_output = attn_output + out_proj_bias;
   return attn_output;
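The reshapes now take the batch size from opt_sizes[0] instead of passing -1 for both leading dimensions. A plausible reading (an assumption, not stated in the commit) is that reshape can infer at most one dimension, as this dense-tensor sketch shows:

import torch

batch, num_heads, head_dim = 2, 2, 4
q = torch.randn(batch * 7, num_heads * head_dim)   # (batch * seq, embed_dim)
try:
    q.reshape(-1, -1, num_heads, head_dim)          # two inferred dims: rejected
except RuntimeError as err:
    print("reshape refused:", err)
q = q.reshape(batch, -1, num_heads, head_dim)       # explicit batch, seq inferred
print(q.shape)                                      # torch.Size([2, 7, 2, 4])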

nestedtensor/csrc/shape.cpp

Lines changed: 4 additions & 14 deletions
@@ -13,13 +13,8 @@ Tensor NestedTensor_view(const Tensor& self, IntArrayRef size) {
   TORCH_CHECK(
       int64_t(size.size()) > self_data->nested_dim(),
       "view cannot be exclusive to nested dimensions.");
-  for (int64_t i = 0; i < self_data->nested_dim(); i++) {
-    if (size[i] >= 0) {
-      throw std::runtime_error(
-          "Cannot view explicitly along irregular dimension " +
-          std::to_string(i) + ". Please use -1 as a placeholder.");
-    }
-  }
+  auto self_opt_sizes = get_opt_sizes(self);
+  TORCH_CHECK(*self_opt_sizes[0] == size[0], "First dimension must be unchanged.");
   int64_t nested_dim = self_data->nested_dim();
   std::vector<int64_t> target_shape;
   for (int64_t i = nested_dim; i < int64_t(size.size()); i++) {
@@ -38,13 +33,8 @@ Tensor NestedTensor_reshape(const Tensor& self, IntArrayRef size) {
   TORCH_CHECK(
       int64_t(size.size()) > self_data->nested_dim(),
       "Reshape cannot be exclusive to nested dimensions.");
-  for (int64_t i = 0; i < self_data->nested_dim(); i++) {
-    if (size[i] >= 0) {
-      throw std::runtime_error(
-          "Cannot reshape explicitly along irregular dimension " +
-          std::to_string(i) + ". Please use -1 as a placeholder.");
-    }
-  }
+  auto self_opt_sizes = get_opt_sizes(self);
+  TORCH_CHECK(*self_opt_sizes[0] == size[0], "First dimension must be unchanged.");
   int64_t nested_dim = self_data->nested_dim();
   std::vector<int64_t> target_shape;
   for (int64_t i = nested_dim; i < int64_t(size.size()); i++) {
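With the -1 placeholder requirement dropped, view and reshape now simply insist that the first (nested) dimension be passed through unchanged. A hedged usage sketch (the NestedTensor shapes and the exact call are assumptions, not taken from the commit):

import torch
import nestedtensor

# Two constituents of 6 elements each: the leading dimension of the target
# shape must stay 2, and the remaining dims reshape each constituent.
nt = nestedtensor.nested_tensor([torch.randn(6), torch.randn(6)])
out = nt.reshape(2, 2, 3)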

nestedtensor/nested/fuser.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+import torch.fx as fx
+from typing import Type, Dict, Any, Tuple, Iterable
+import torch
+import copy
+from torch.fx import symbolic_trace
+import time
+
+def _parent_name(target : str) -> Tuple[str, str]:
+    """
+    Splits a qualname into parent path and last atom.
+    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
+    """
+    *parent, name = target.rsplit('.', 1)
+    return parent[0] if parent else '', name
+
+# Works for length 2 patterns with 2 modules
+def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
+    if len(node.args) == 0:
+        return False
+    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
+    for expected_type, current_node in zip(pattern, nodes):
+        if not isinstance(current_node, fx.Node):
+            return False
+        if current_node.op != 'call_module':
+            return False
+        if not isinstance(current_node.target, str):
+            return False
+        if current_node.target not in modules:
+            return False
+        if type(modules[current_node.target]) is not expected_type:
+            return False
+    return True
+
+
+def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
+    assert(isinstance(node.target, str))
+    parent_name, name = _parent_name(node.target)
+    setattr(modules[parent_name], name, new_module)
+
+def computeUpdatedConvWeightAndBias(
+        bn_rv,
+        bn_eps,
+        bn_w,
+        bn_b,
+        bn_rm,
+        conv_w,
+        conv_b=None):
+    orig_dtype = bn_rv.dtype
+    bn_var_rsqrt = (bn_w / torch.sqrt(bn_rv.to(torch.double) + bn_eps))
+    new_w = (conv_w * (bn_var_rsqrt).reshape(-1, 1, 1, 1)).to(orig_dtype)
+    if conv_b is None:
+        return new_w
+    new_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
+    return new_w, new_b
+
+def fuse_conv_bn_eval(conv, bn):
+    assert(not (conv.training or bn.training)), "Fusion only for eval!"
+    fused_conv = copy.deepcopy(conv)
+    fused_conv.bias = None
+
+    fused_conv.weight = \
+        torch.nn.Parameter(computeUpdatedConvWeightAndBias(bn.running_var, bn.eps, bn.weight, bn.bias, bn.running_mean, fused_conv.weight))
+
+    return fused_conv
+
+def fuse_conv_bn(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
+    """
+    Fuses convolution/BN layers for inference purposes. Will deepcopy your
+    model by default, but can modify the model inplace as well.
+    """
+    patterns = [(torch.nn.Conv2d, torch.nn.BatchNorm2d)]
+    if not inplace:
+        model = copy.deepcopy(model)
+    fx_model = fx.symbolic_trace(model)
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    for pattern in patterns:
+        for node in new_graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
+                    continue
+                conv = modules[node.args[0].target]
+                bn = modules[node.target]
+                fused_conv = fuse_conv_bn_eval(conv, bn)
+                replace_node_module(node.args[0], modules, fused_conv)
+                node.replace_all_uses_with(node.args[0])
+                new_graph.erase_node(node)
+    return fx.GraphModule(fx_model, new_graph)
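fuse_conv_bn is also re-exported from the package root (see the __init__.py change above). A hedged usage sketch with a toy model (the model is an assumption, not part of the commit; with a bias-free convolution and freshly initialized BatchNorm statistics the fused and original modules agree):

import torch
import nestedtensor

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()                                   # fusion asserts eval mode
fused = nestedtensor.fuse_conv_bn(model)   # deep-copies, traces, folds BN into the conv weight
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(model(x), fused(x), atol=1e-5)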

nestedtensor/version.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-__version__ = '0.1.4+33fb247'
-git_version = '33fb2477c856f8185f1e9c1e9a6ca28065e43cf9'
+__version__ = '0.1.4+66764fd'
+git_version = '66764fd10e9b6f9c0710840d0cb17369b9d994be'
 from nestedtensor import _C
 if hasattr(_C, 'CUDA_VERSION'):
     cuda = _C.CUDA_VERSION

test/test_nested_tensor_integration.py

Lines changed: 32 additions & 0 deletions
@@ -3,6 +3,12 @@
 import unittest
 from utils_test_case import TestCase
 
+try:
+    import classy_vision
+    TEST_CLASSY_VISION=True
+except ModuleNotFoundError:
+    TEST_CLASSY_VISION=False
+
 
 def ntnt(x): return nestedtensor.nested_tensor(x, requires_grad=True)
 def ntnt_nograd(x): return nestedtensor.nested_tensor(x, requires_grad=False)
@@ -180,6 +186,32 @@ def test_transformer_forward(self):
         for t0, t1 in zip(res_nt.unbind(), [res_0, res_1]):
             self.assertEqual(t0, t1)
 
+    @unittest.skipIf(not TEST_CLASSY_VISION, "No classy vision")
+    def test_fusion_resnext101_32x4d(self):
+        @torch.inference_mode()
+        def _test(dtype, use_channels_last):
+            from classy_vision.models import build_model
+            from torch.fx import symbolic_trace
+            model = build_model({"name": "resnext101_32x4d"}).eval().cuda()
+            model._initialize_weights(False)
+            fused = symbolic_trace(model)
+            fused = nestedtensor.fuse_conv_bn(fused)
+            model = model.to(dtype)
+            fused = fused.to(dtype)
+            data = torch.randn(2, 3, 50, 50, device=torch.device('cuda'), dtype=dtype)
+            if use_channels_last:
+                data = data.contiguous(memory_format=torch.channels_last)
+            ref_output = model(data)
+            new_output = fused(data)
+            if dtype == torch.float16:
+                self.assertEqual(ref_output, new_output, prec=2e-3)
+            else:
+                self.assertEqual(ref_output, new_output)
+        _test(torch.float16, False)
+        _test(torch.float32, False)
+        # _test(torch.float16, True)
+        _test(torch.float32, True)
+
 
 if __name__ == "__main__":
     unittest.main()

0 commit comments
