This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 423a9e9

cpuhrsch authored and facebook-github-bot committed
Import nestedtensor 20210712
Summary: Import of recent nestedtensor repo

Reviewed By: ngimel

Differential Revision: D29672660

fbshipit-source-id: c52b031a6026c41596f03eceedec1214a950ae62
1 parent b4b7144 commit 423a9e9

35 files changed: +1717 -278 lines

benchmarks/classy.py

Lines changed: 21 additions & 11 deletions
@@ -21,7 +21,7 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     if torch.cuda.is_available():
         end_event.record()
         torch.cuda.synchronize()
-        return start_event.elapsed_time(end_event) / 1e3
+        return start_event.elapsed_time(end_event)
     else:
         return (time.time() - t0)
 
@@ -33,43 +33,53 @@ def run_benchmark(iters, shapes, model, model_name, bsz):
         inp = torch.randn(*s, dtype=torch.half).cuda()
         ts.append(inp)
     ts_nt = nestedtensor.nested_tensor([t.squeeze(0) for t in ts], device=torch.device('cuda'), dtype=torch.half)
+    ts_padded = ts_nt.to_padded_tensor()
+    ts_nt = nestedtensor.nested_tensor([t.squeeze(0) for t in ts], device=torch.device('cuda'), dtype=torch.half, channels_last=True)
 
     def _loop():
         model_outputs = []
         for inp in ts:
             model_outputs.append(model(inp))
         return model_outputs
 
+    def _padded():
+        return model(ts_padded)
+
     # Test
     outputs_nt = model(ts_nt)
+    # import time; time.sleep(1)
+    # outputs_nt = model(ts_nt)
+    # import sys; sys.exit(1)
     model_outputs = _loop()
     for mo, ntmo in zip(model_outputs, outputs_nt.unbind()):
         # Using float16 tolerances from torch/testing/_core.yp
         assert torch.allclose(mo.squeeze(0), ntmo, rtol=1e-3, atol=1e-3)
 
     loop_time = benchmark_torch_function(iters, _loop)
+    padded_time = benchmark_torch_function(iters, _padded)
     nt_time = benchmark_torch_function(iters, lambda: model(ts_nt))
 
     shapes_2_array = np.array([s[2] for s in shapes])
     shapes_3_array = np.array([s[3] for s in shapes])
     print(f"model_name: {model_name.rjust(18)},", end='')
-    print(f" bsz: {bsz},", end='')
+    print(f" bsz: {bsz:3.0f},", end='')
     print(f" mean±std shapes[2]: {shapes_2_array.mean():.2f}±{shapes_2_array.std():.2f},", end='')
     print(f" mean±std shapes[3]: {shapes_3_array.mean():.2f}±{shapes_3_array.std():.2f},", end='')
-    print(f" loop: {loop_time / iters:.2f}s, nt: {nt_time / iters:.2f}s, speedup: {loop_time / nt_time:.2f}x")
+    print(f" padded_size: {tuple(ts_padded.size())},", end='')
+    print(f" loop: {loop_time / iters:7.2f}ms, nt: {nt_time / iters:7.2f}ms, padded: {padded_time / iters:7.2f}ms, speedup: {loop_time / nt_time:.2f}x")
 
 if __name__ == "__main__":
+    iters = 10
+
     def _benchmark(model_name, bsz):
         model = build_model({"name": model_name})
         model = model.cuda().half().eval()
         random.seed(123)
         shapes = [(1, 3, random.randint(100, 600), random.randint(100, 600)) for _ in range(bsz)]
-        run_benchmark(1, shapes, model, model_name, bsz)
+        run_benchmark(iters, shapes, model, model_name, bsz)
+
+    for bsz in [16, 32, 64, 128]:
+        _benchmark("resnext101_32x4d", bsz)
 
-    _benchmark("resnext101_32x4d", 64)
-    _benchmark("resnext101_32x4d", 128)
-    _benchmark("resnext101_32x4d", 256)
-    _benchmark("regnet_y_128gf", 64)
-    _benchmark("regnet_y_128gf", 128)
-    # Runs out of memory
-    # _benchmark("regnet_y_128gf", 256)
+    for bsz in [16, 32]:
+        _benchmark("regnet_y_128gf", bsz)

benchmarks/gat.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+import torch
+import torch.nn.functional as F
+from torch_geometric.nn import GATConv
+import random
+import time
+import nestedtensor
+from nestedtensor import nested_tensor as ntnt
+
+@torch.inference_mode()
+def benchmark_torch_function(iters, f, *args, **kwargs):
+    f(*args, **kwargs)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+    else:
+        t0 = time.time()
+    for _ in range(iters):
+        f(*args, **kwargs)
+    if torch.cuda.is_available():
+        end_event.record()
+        torch.cuda.synchronize()
+        return start_event.elapsed_time(end_event)
+    else:
+        return (time.time() - t0)
+
+
+num_features = 1433
+num_classes = 7
+
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = GATConv(num_features, 8, heads=8,
+                             dropout=0.6)
+
+        self.conv2 = GATConv(64, num_classes, heads=1, concat=True,
+                             dropout=0.6)
+
+    def forward(self, x, edge_index):
+        x = F.dropout(x, p=0.6, training=self.training)
+        x = F.elu(self.conv1(x, edge_index))
+        x = F.dropout(x, p=0.6, training=self.training)
+        x = self.conv2(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+
+class NTNet(torch.nn.Module):
+    def __init__(self):
+        super(NTNet, self).__init__()
+        self.conv1 = GATConv(num_features, 8, heads=8,
+                             dropout=0.6)
+
+        self.conv2 = GATConv(64, num_classes, heads=1, concat=True,
+                             dropout=0.6)
+
+    def forward(self, x, edge_index):
+        x = F.dropout(x, p=0.6, training=self.training)
+        x = ntnt([self.conv1(xi, edge_index_i) for (xi, edge_index_i) in zip(x.unbind(), edge_index.unbind())], dtype=x.dtype, device=x.device)
+        x = F.elu(x)
+        x = F.dropout(x, p=0.6, training=self.training)
+        x = ntnt([self.conv2(xi, edge_index_i) for (xi, edge_index_i) in zip(x.unbind(), edge_index.unbind())], dtype=x.dtype, device=x.device)
+        return F.log_softmax(x, dim=1)
+
+
+def create_models(device):
+    model = Net().to(device).eval()
+    nt_model = NTNet().to(device).eval()
+    return model, nt_model
+
+def create_tensors():
+    random.seed(1010)
+    nnodes_list = []
+    nedges_list = []
+    for i in range(50):
+        nnodes_list.append(random.randint(100, 4000))
+        nedges_list.append(random.randint(8000, 15000))
+
+    tensors_x = []
+    tensors_edge_index = []
+    for nnodes, nedges in zip(nnodes_list, nedges_list):
+        x = torch.normal(-10, 4, (nnodes, 1433))
+        x[x < 0] = 0.
+        x[x > 1] = 1.
+        edge_index = torch.randint(0, nnodes, (2, nedges), dtype=torch.int64)
+        tensors_x.append(x)
+        tensors_edge_index.append(edge_index)
+    return tensors_x, tensors_edge_index
+
+
+@torch.inference_mode()
+def loop(model, tensors_x, tensors_edge_index):
+    for x, edge_index in zip(tensors_x, tensors_edge_index):
+        model(x, edge_index)
+
+
+@torch.inference_mode()
+def nt(nt_model, nt_x, nt_edge_index):
+    nt_model(nt_x, nt_edge_index)
+
+if __name__ == "__main__":
+    device = torch.device('cuda')
+    model, nt_model = create_models(device)
+    tensors_x, tensors_edge_index = create_tensors()
+    print(benchmark_torch_function(10, loop, model, tensors_x, tensors_edge_index))
+    nt_x = ntnt(tensors_x, device=device)
+    nt_edge_index = ntnt(tensors_edge_index, device=device, dtype=torch.int64)
+    print(benchmark_torch_function(10, nt, nt_model, nt_x, nt_edge_index))

nestedtensor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 
 from .nested.nested import NestedTensor
 from .nested.nested import to_nested_tensor
+from .nested.nested import transpose_nchw_nhwc
+from .nested.nested import transpose_nhwc_nchw
 
 from . import nested
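
The two newly exported helpers are not defined in this excerpt, so their exact signatures are an assumption. The sketch below guesses that each takes a NestedTensor of image-like constituents and returns the layout-transposed counterpart (NCHW to NHWC and back), analogous to Tensor.contiguous(memory_format=...); treat the call shapes as hypothetical.

import torch
import nestedtensor
from nestedtensor import transpose_nchw_nhwc, transpose_nhwc_nchw

images = [torch.randn(3, 40, 60), torch.randn(3, 32, 48)]
nt_nchw = nestedtensor.nested_tensor(images)

# Hypothetical usage -- signatures assumed, not shown in this diff.
nt_nhwc = transpose_nchw_nhwc(nt_nchw)  # assumed: NCHW -> NHWC layout
nt_back = transpose_nhwc_nchw(nt_nhwc)  # assumed: NHWC -> NCHW layout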

nestedtensor/csrc/BinaryOps.cpp

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,14 @@ Tensor NestedTensor_add_Tensor(
       get_efficient_nested_size(other);
   if (efficient_size_matches(
           self_efficient_nested_size, other_efficient_nested_size)) {
+    if (get_is_contiguous(self, c10::MemoryFormat::ChannelsLast) &&
+        get_is_contiguous(other, c10::MemoryFormat::ChannelsLast)) {
+      return wrap_buffer(
+          at::add(
+              get_buffer(self).view({-1}), get_buffer(other).view({-1})),
+          self_efficient_nested_size,
+          get_efficient_nested_stride(self));
+    }
     if (!get_is_contiguous(self)) {
       self = NestedTensor_contiguous(self);
     }
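
From Python the new branch should be reachable whenever both operands are ChannelsLast-contiguous nested tensors with identical nested sizes; in that case the add collapses to a single at::add over the flattened packed buffers. A sketch under that assumption, reusing the channels_last=True constructor from benchmarks/classy.py:

import torch
import nestedtensor

tensors = [torch.randn(3, 20, 30, dtype=torch.half, device='cuda'),
           torch.randn(3, 16, 24, dtype=torch.half, device='cuda')]
make_nt = lambda: nestedtensor.nested_tensor(
    tensors, device=torch.device('cuda'), dtype=torch.half, channels_last=True)
a, b = make_nt(), make_nt()

# Matching nested sizes + ChannelsLast contiguity on both sides should take
# the packed-buffer fast path; the result is an ordinary elementwise sum.
c = a + b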

nestedtensor/csrc/SoftMax.cpp

Lines changed: 19 additions & 0 deletions
@@ -27,8 +27,27 @@ Tensor NestedTensor_softmax(
       input);
 }
 
+Tensor NestedTensor_log_softmax(
+    const Tensor& input,
+    const int64_t dim_,
+    c10::optional<ScalarType> dtype) {
+  int64_t dim = maybe_wrap_dim(dim_, get_dim(input));
+  auto input_data = get_nested_tensor_impl(input);
+  int64_t nested_dim = input_data->nested_dim();
+  TORCH_CHECK(
+      dim >= nested_dim,
+      "Cannot apply log_softmax across nested dimensions ",
+      std::to_string(dim));
+  return map_nested_tensor(
+      [dim, nested_dim, dtype](const at::Tensor t) {
+        return at::log_softmax(t, dim - nested_dim, dtype);
+      },
+      input);
+}
+
 TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
   nt_impl(m, "softmax.int", NestedTensor_softmax);
+  nt_impl(m, "log_softmax.int", NestedTensor_log_softmax);
 }
 
 } // namespace at
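
With log_softmax.int registered, torch.nn.functional.log_softmax on a NestedTensor should dispatch to this kernel as long as the reduction dimension is not a nested dimension (the TORCH_CHECK above); benchmarks/gat.py relies on this. A minimal sketch:

import torch
import torch.nn.functional as F
import nestedtensor

# Per-graph logits with different node counts; dim 0 is the nested dimension.
logits = nestedtensor.nested_tensor([torch.randn(120, 7), torch.randn(80, 7)])

# dim=2 of the nested tensor maps to dim 1 of each constituent
# (dim - nested_dim in the kernel), i.e. the class dimension.
log_probs = F.log_softmax(logits, dim=2)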

nestedtensor/csrc/activation.cpp

Lines changed: 13 additions & 1 deletion
@@ -19,6 +19,17 @@ Tensor NestedTensor_gelu(const Tensor& self) {
       [](at::Tensor tensor) { return at::gelu(tensor); }, self);
 }
 
+Tensor NestedTensor_elu(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
+  if (is_nested_tensor_impl(self) && get_is_contiguous(self)) {
+    return wrap_buffer(
+        at::elu(get_buffer(self), alpha, scale, input_scale),
+        get_efficient_nested_size(self),
+        get_efficient_nested_stride(self));
+  }
+  return map_nested_tensor(
+      [&alpha, &scale, &input_scale](at::Tensor tensor) { return at::elu(tensor, alpha, scale, input_scale); }, self);
+}
+
 // Registered below autograd
 Tensor NestedTensor_relu(const Tensor& self) {
   auto impl = get_nested_tensor_impl(self);
@@ -37,7 +48,7 @@ Tensor NestedTensor_relu(const Tensor& self) {
 
 // Registered below autograd
 Tensor& NestedTensor_relu_(Tensor& self) {
-  if (get_is_contiguous(self)) {
+  if (get_is_contiguous(self) || get_is_contiguous(self, c10::MemoryFormat::ChannelsLast)) {
 #ifdef TRACEPACKED
     std::cout << "calling packed relu_" << std::endl;
 #endif
@@ -51,6 +62,7 @@ Tensor& NestedTensor_relu_(Tensor& self) {
 
 TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
   nt_impl(m, "gelu", NestedTensor_gelu);
+  nt_impl(m, "elu", NestedTensor_elu);
 }
 
 TORCH_LIBRARY_IMPL(aten, NestedTensor, m) {
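
The elu registration lets torch.nn.functional.elu (used by benchmarks/gat.py) run on NestedTensors: contiguous inputs take the packed-buffer branch, anything else falls back to a per-constituent map. A short sketch:

import torch
import torch.nn.functional as F
import nestedtensor

nt = nestedtensor.nested_tensor([torch.randn(5, 8), torch.randn(3, 8)])

# For a contiguous nested tensor this applies at::elu once over the packed buffer.
out = F.elu(nt, alpha=1.0)

# The result matches applying elu to each constituent separately.
for packed, plain in zip(out.unbind(), nt.unbind()):
    assert torch.allclose(packed, F.elu(plain, alpha=1.0))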

nestedtensor/csrc/autograd_functions.cpp

Lines changed: 43 additions & 20 deletions
@@ -105,7 +105,6 @@ Tensor NestedTensor_batch_norm(
     check_dims_match_num_input_features("bias", n_input, get_numel(*bias));
   }
 
-  auto scalar_shape = make_scalar_shape(get_dim(input), n_input);
   at::Tensor mean = *running_mean;
   at::Tensor var = *running_var;
 #ifdef WITH_CUDA
@@ -120,46 +119,64 @@ Tensor NestedTensor_batch_norm(
       (mean.dtype() == torch::kHalf) &&
       (var.dtype() == torch::kHalf) &&
       (bias->dtype() == torch::kHalf) &&
-      (weight->dtype() == torch::kHalf)
+      (weight->dtype() == torch::kHalf) &&
+      get_is_cuda(input)
      )
   {
-
     // Custom CUDA Half implementation.
     mean = mean.contiguous();
     Tensor bias_cont = (*bias).contiguous();
     Tensor weight_cont = (*weight).contiguous();
     Tensor running_var_cont = (*running_var).contiguous();
+
+    c10::Half* mean_ptr = mean.data_ptr<c10::Half>();
+    c10::Half* bias_ptr = bias_cont.data_ptr<c10::Half>();
+    c10::Half* weight_ptr = weight_cont.data_ptr<c10::Half>();
+    c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();
+
+    if (get_is_contiguous(input, c10::MemoryFormat::ChannelsLast)) {
+      Tensor input_buffer = get_buffer(input);
+      int64_t num_channel = weight_cont.size(0);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      nested_tensor::cuda::batchnorm_inference_channels_last_kernelLauncher(
+          input_buffer.data_ptr<c10::Half>(),
+          mean_ptr,
+          running_var_ptr,
+          c10::Half((float)(eps)),
+          weight_ptr,
+          bias_ptr,
+          input_buffer.data_ptr<c10::Half>(),
+          num_channel,
+          input_buffer.numel(),
+          defaultStream);
+      input_buffer = input_buffer.view(-1);
+      return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(input), get_efficient_nested_stride(input));
+    }
 
     Tensor output = input;
     output = NestedTensor_contiguous(output);
     Tensor input_buffer = get_buffer(output);
-    Tensor output_buffer = input_buffer.clone();
+    // Tensor output_buffer = input_buffer.clone();
 
     auto self_opt_sizes = get_opt_sizes(input);
 
     Tensor nt_sizes_ =
-        get_efficient_nested_size(input).sizes().to(torch::kInt32);
+        get_efficient_nested_size(input).sizes(); // .to(torch::kInt32);
     Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
     Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
     Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
-    int* nt_sizes_all_ptr = nt_sizes_all.data_ptr<int>();
-    std::vector<int> numbers;
-    numbers.reserve(1 + (nt_sizes_all.size(0) * *self_opt_sizes[1]));
-    numbers.push_back(0);
+    int64_t* nt_sizes_all_ptr = nt_sizes_all.data_ptr<int64_t>();
+    at::Tensor numbers_t = at::empty({1 + (nt_sizes_all.size(0) * *self_opt_sizes[1])}, torch::kInt64);
+    int64_t* numbers_t_ptr = numbers_t.data_ptr<int64_t>();
+    numbers_t_ptr[0] = 0;
     int64_t index = 1;
     for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
       for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
-        numbers.push_back(numbers[index - 1] + nt_sizes_all_ptr[i]);
+        numbers_t_ptr[index] = (numbers_t_ptr[index - 1] + nt_sizes_all_ptr[i]);
         index++;
       }
     }
-    at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
-    Tensor nt_sizes = numbers_t.to(torch::kCUDA);
-
-    c10::Half* mean_ptr = mean.data_ptr<c10::Half>();
-    c10::Half* running_var_ptr = running_var_cont.data_ptr<c10::Half>();
-    c10::Half* bias_ptr = bias_cont.data_ptr<c10::Half>();
-    c10::Half* weight_ptr = weight_cont.data_ptr<c10::Half>();
+    Tensor nt_sizes = numbers_t.to(at::Device(kCUDA), torch::kInt32, true, true);
 
     at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
     nested_tensor::cuda::batchnorm_inference_kernelLauncher(
@@ -169,15 +186,21 @@ Tensor NestedTensor_batch_norm(
         c10::Half((float)(eps)),
         weight_ptr,
         bias_ptr,
-        output_buffer.data_ptr<c10::Half>(),
-        (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+        input_buffer.data_ptr<c10::Half>(),
+        // output_buffer.data_ptr<c10::Half>(),
         (int)(*self_opt_sizes[0]),
+        (int)(weight_cont.size(0)),
+        (int)(*self_opt_sizes[0] *
+              *self_opt_sizes[1] *
+              *self_opt_sizes[2] *
+              *self_opt_sizes[3]),
         nt_sizes.data_ptr<int>(),
         defaultStream
     );
-    return wrap_buffer(std::move(output_buffer), get_efficient_nested_size(output), get_efficient_nested_stride(output));
+    return wrap_buffer(std::move(input_buffer), get_efficient_nested_size(output), get_efficient_nested_stride(output));
   }
 #endif
+  auto scalar_shape = make_scalar_shape(get_dim(input), n_input);
 
   at::Tensor invstd = 1 / at::sqrt(*running_var + eps);
 
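The fast path above is inference-only, fp16, CUDA, and now also handles ChannelsLast packed input by normalizing the packed buffer directly. A sketch of the Python-level call that should reach it, assuming nn.BatchNorm2d in eval mode dispatches to NestedTensor_batch_norm as it does for the ClassyVision models in benchmarks/classy.py:

import torch
import nestedtensor

bn = torch.nn.BatchNorm2d(3).cuda().half().eval()  # running stats, fp16, inference

images = [torch.randn(3, 64, 80, dtype=torch.half, device='cuda'),
          torch.randn(3, 48, 56, dtype=torch.half, device='cuda')]
nt = nestedtensor.nested_tensor(
    images, device=torch.device('cuda'), dtype=torch.half, channels_last=True)

with torch.inference_mode():
    # fp16 + CUDA + ChannelsLast contiguity should hit
    # batchnorm_inference_channels_last_kernelLauncher on the packed buffer.
    out = bn(nt)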