This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 3afc66d

cpuhrsch authored and facebook-github-bot committed
Import nestedtensor 20210618
Summary: Import of latest changes

Reviewed By: zou3519

Differential Revision: D29238347

fbshipit-source-id: 2abfbc7ef3a8359354450e144ba55d8d5c7506f3
1 parent 825729c commit 3afc66d

33 files changed: 1443 additions & 1506 deletions

benchmarks/classy.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import torch
import numpy as np
import time
import random
import nestedtensor
from classy_vision.models import build_model


@torch.inference_mode()
def benchmark_torch_function(iters, f, *args, **kwargs):
    f(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        t0 = time.time()
    for _ in range(iters):
        f(*args, **kwargs)
    if torch.cuda.is_available():
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event) / 1e3
    else:
        return (time.time() - t0)


@torch.inference_mode()
def run_benchmark(iters, shapes, model, model_name, bsz):
    ts = []
    for s in shapes:
        inp = torch.randn(*s, dtype=torch.half).cuda()
        ts.append(inp)
    ts_nt = nestedtensor.nested_tensor([t.squeeze(0) for t in ts], device=torch.device('cuda'), dtype=torch.half)

    def _loop():
        model_outputs = []
        for inp in ts:
            model_outputs.append(model(inp))
        return model_outputs

    # Test
    outputs_nt = model(ts_nt)
    model_outputs = _loop()
    for mo, ntmo in zip(model_outputs, outputs_nt.unbind()):
        # Using float16 tolerances from torch/testing/_core.py
        assert torch.allclose(mo.squeeze(0), ntmo, rtol=1e-3, atol=1e-3)

    loop_time = benchmark_torch_function(iters, _loop)
    nt_time = benchmark_torch_function(iters, lambda: model(ts_nt))

    shapes_2_array = np.array([s[2] for s in shapes])
    shapes_3_array = np.array([s[3] for s in shapes])
    print(f"model_name: {model_name.rjust(18)},", end='')
    print(f" bsz: {bsz},", end='')
    print(f" mean±std shapes[2]: {shapes_2_array.mean():.2f}±{shapes_2_array.std():.2f},", end='')
    print(f" mean±std shapes[3]: {shapes_3_array.mean():.2f}±{shapes_3_array.std():.2f},", end='')
    print(f" loop: {loop_time / iters:.2f}s, nt: {nt_time / iters:.2f}s, speedup: {loop_time / nt_time:.2f}x")


if __name__ == "__main__":
    def _benchmark(model_name, bsz):
        model = build_model({"name": model_name})
        model = model.cuda().half().eval()
        random.seed(123)
        shapes = [(1, 3, random.randint(100, 600), random.randint(100, 600)) for _ in range(bsz)]
        run_benchmark(1, shapes, model, model_name, bsz)

    _benchmark("resnext101_32x4d", 64)
    _benchmark("resnext101_32x4d", 128)
    _benchmark("resnext101_32x4d", 256)
    _benchmark("regnet_y_128gf", 64)
    _benchmark("regnet_y_128gf", 128)
    # Runs out of memory
    # _benchmark("regnet_y_128gf", 256)
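The pattern this benchmark times can be reduced to a few lines: run each variable-size image through the model one at a time, then run one forward pass over the whole ragged batch as a nested tensor and compare. Below is a minimal sketch of that comparison; the torchvision resnet18 stand-in model and the image sizes are hypothetical, and the sketch assumes the stand-in's operators are covered by nestedtensor, as the classy_vision models above are.

import torch
import nestedtensor
import torchvision

# Hypothetical stand-in model and image sizes; the benchmark above uses classy_vision models.
model = torchvision.models.resnet18().cuda().half().eval()
images = [torch.randn(3, 224, 160, device='cuda', dtype=torch.half),
          torch.randn(3, 300, 256, device='cuda', dtype=torch.half)]

with torch.inference_mode():
    # Baseline: one model call per variable-size image.
    looped = [model(img.unsqueeze(0)).squeeze(0) for img in images]
    # Nested tensor path: a single call over the ragged batch.
    nt = nestedtensor.nested_tensor(images, device=torch.device('cuda'), dtype=torch.half)
    batched = model(nt)
    # Same float16 tolerances as the assertion in run_benchmark.
    for a, b in zip(looped, batched.unbind()):
        assert torch.allclose(a, b, rtol=1e-3, atol=1e-3)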

benchmarks/conv2d.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
import torch
import time
import nestedtensor


@torch.inference_mode()
def benchmark_torch_function(iters, f, *args):
    f(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        t0 = time.time()
    for _ in range(iters):
        f(*args)
    if torch.cuda.is_available():
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event)
    else:
        return (time.time() - t0) * 1e3


# def run(bdim, embedding_dim, out_dim, min_t, max_t, iters, device):
def run(bdim, nchannel, min_t, max_t, iters, device):
    import random
    random.seed(1010)

    # The following is meant to emulate the lengths of randomly sampled tokenized sentences
    lengths1 = [random.randint(min_t, max_t) for _ in range(bdim)]
    lengths2 = [random.randint(min_t, max_t) for _ in range(bdim)]

    # List of sentence embeddings
    tensors = [torch.rand(nchannel, l1, l2).to(device=device, dtype=torch.float) for (l1, l2) in zip(lengths1, lengths2)]
    # Create packed NestedTensor
    nt = nestedtensor.nested_tensor(tensors, device=device, dtype=torch.float)

    lin = torch.nn.Conv2d(nchannel, nchannel, (1, 1), bias=False).to(device)

    def _loop(tensors):
        result = []
        for t in tensors:
            result.append(lin(t.unsqueeze(0)).squeeze(0))
        return result

    nt_time = benchmark_torch_function(iters, lin, nt)
    t_time = benchmark_torch_function(iters, _loop, tensors)

    # print(f"batch size: {bdim:4.0f}, embedding dim: {embedding_dim}, out_dim: {out_dim}, T mean:{lengths_mean:5.0f}, T std: {lengths_std:4.0f}", end='')
    print(f"batch size: {bdim:4.0f}, nchannel: {nchannel:4.0f}", end='')
    # print(f", padding: {percentage_padded:3.0f}%, NT: {nt_time/iters:4.0f}ms, T: {t_time/iters:4.0f}ms, Speedup: {t_time/nt_time:3.2f}x")
    print(f", NT: {nt_time/iters:4.0f}ms, T: {t_time/iters:4.0f}ms, Speedup: {t_time/nt_time:3.2f}x")


if torch.cuda.is_available():
    print("CUDA device: ", torch.cuda.get_device_name(0))
    iters = 10
    for nchannel in [3, 128, 256, 512]:
        for min_t, max_t in [(16, 128), (32, 128), (64, 128), (128, 128)]:
            run(256, nchannel, min_t, max_t, iters, torch.device('cuda'))
            break
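One reason a 1x1 convolution maps well onto a packed nested tensor: it is a per-position channel mixing, so the same weight matrix can in principle be applied across the flattened positions of every entry at once instead of once per entry. A small self-contained check of that equivalence on a single dense tensor (shapes are illustrative, not taken from the benchmark):

import torch

conv = torch.nn.Conv2d(8, 8, (1, 1), bias=False)
x = torch.randn(8, 13, 17)                        # one C x H x W entry
y_conv = conv(x.unsqueeze(0)).squeeze(0)          # regular Conv2d path
w = conv.weight.view(8, 8)                        # (out_channels, in_channels)
y_mm = (w @ x.reshape(8, -1)).reshape(8, 13, 17)  # channel mixing over flattened positions
assert torch.allclose(y_conv, y_mm, atol=1e-5)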

nestedtensor/csrc/BinaryOps.cpp

Lines changed: 167 additions & 10 deletions
@@ -1,4 +1,9 @@
 #include <nestedtensor/csrc/BinaryOps.h>
+#ifdef WITH_CUDA
+#include <c10/cuda/CUDAStream.h>
+#include <nestedtensor/csrc/cuda/add.h>
+#include <c10/util/Half.h>
+#endif
 
 namespace at {
 
@@ -31,11 +36,56 @@ Tensor NestedTensor_add_Tensor(
     }
   }
   if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
-    if (!get_is_contiguous(self)) {
-      self = NestedTensor_contiguous(self);
-    }
+    self = NestedTensor_contiguous(self);
     int64_t self_dim = get_dim(self);
     auto self_opt_sizes = get_opt_sizes(self);
+#ifdef WITH_CUDA
+    if (self_dim == 4 && other.dim() == 4 &&
+        self_opt_sizes[0] &&
+        self_opt_sizes[1] &&
+        (*self_opt_sizes[1]) == other.size(1) &&
+        other.size(0) == 1 &&
+        other.size(2) == 1 &&
+        other.size(3) == 1 &&
+        self.dtype() == c10::ScalarType::Half &&
+        other.dtype() == c10::ScalarType::Half) {
+      other = other.contiguous();
+      at::Tensor self_buffer = get_buffer(self);
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(self).sizes().to(torch::kInt32);
+      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
+      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
+      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
+      std::vector<int> numbers;
+      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
+        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
+          numbers.push_back(nt_sizes_all[i].item<int>());
+        }
+      }
+      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
+      Tensor nt_sizes_cumsum =
+          at::native::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.")
+      Tensor nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      at::Tensor result_buffer = self_buffer.clone();
+
+      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
+      c10::Half* other_ptr = other.data_ptr<c10::Half>();
+      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
+      nested_tensor::cuda::add_scalar_kernelLauncher(
+          self_ptr,
+          other_ptr,
+          result_ptr,
+          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+          (int)(*self_opt_sizes[0]),
+          nt_sizes.data_ptr<int>(),
+          defaultStream);
+      return wrap_buffer(std::move(result_buffer), get_efficient_nested_size(self),
+                         get_efficient_nested_stride(self));
+    }
+#endif
     if (self_opt_sizes[self_dim - 1] && other.dim() == 1 &&
         (*(self_opt_sizes[self_dim - 1])) == other.size(0)) {
       Tensor self_buffer = get_buffer(self);
@@ -50,7 +100,8 @@ Tensor NestedTensor_add_Tensor(
   }
   std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
-      [&alpha](Tensor s, Tensor o) { return at::add(s, o, alpha); },
+      [&alpha](Tensor s, Tensor o) {
+        return at::add(s, o, alpha); },
       self,
       other);
 }
@@ -180,11 +231,64 @@ Tensor& NestedTensor_floor_divide_out(
 }
 
 Tensor NestedTensor_mul_Tensor(const Tensor& self_, const Tensor& other_) {
-  Tensor self;
-  Tensor other;
+  Tensor self = self_;
+  Tensor other = other_;
+  if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
+    self = NestedTensor_contiguous(self);
+    int64_t self_dim = get_dim(self);
+    auto self_opt_sizes = get_opt_sizes(self);
+#ifdef WITH_CUDA
+    if (self_dim == 4 && other.dim() == 4 &&
+        self_opt_sizes[0] &&
+        self_opt_sizes[1] &&
+        (*self_opt_sizes[1]) == other.size(1) &&
+        other.size(0) == 1 &&
+        other.size(2) == 1 &&
+        other.size(3) == 1 &&
+        self.dtype() == c10::ScalarType::Half &&
+        other.dtype() == c10::ScalarType::Half) {
+      other = other.contiguous();
+      at::Tensor self_buffer = get_buffer(self);
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(self).sizes().to(torch::kInt32);
+      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
+      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
+      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
+      std::vector<int> numbers;
+      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
+        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
+          numbers.push_back(nt_sizes_all[i].item<int>());
+        }
+      }
+      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
+      Tensor nt_sizes_cumsum =
+          at::native::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.")
+      Tensor nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      at::Tensor result_buffer = self_buffer.clone();
+
+      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
+      c10::Half* other_ptr = other.data_ptr<c10::Half>();
+      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
+      nested_tensor::cuda::mul_scalar_kernelLauncher(
+          self_ptr,
+          other_ptr,
+          result_ptr,
+          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+          (int)(*self_opt_sizes[0]),
+          nt_sizes.data_ptr<int>(),
+          defaultStream);
+      return wrap_buffer(std::move(result_buffer), get_efficient_nested_size(self),
+                         get_efficient_nested_stride(self));
+    }
+#endif
+  }
   std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
-      [](Tensor s, Tensor o) { return at::mul(s, o); }, self, other);
+      [](Tensor s, Tensor o) {
+        return at::mul(s, o); }, self, other);
 }
 
 Tensor& NestedTensor_mul__Tensor(Tensor& self_, const Tensor& other_) {
@@ -246,11 +350,64 @@ Tensor NestedTensor_sub_Tensor(
     const Tensor& self_,
     const Tensor& other_,
     const Scalar& alpha) {
-  Tensor self;
-  Tensor other;
+  Tensor self = self_;
+  Tensor other = other_;
+  if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
+    self = NestedTensor_contiguous(self);
+    int64_t self_dim = get_dim(self);
+    auto self_opt_sizes = get_opt_sizes(self);
+#ifdef WITH_CUDA
+    if (self_dim == 4 && other.dim() == 4 &&
+        self_opt_sizes[0] &&
+        self_opt_sizes[1] &&
+        (*self_opt_sizes[1]) == other.size(1) &&
+        other.size(0) == 1 &&
+        other.size(2) == 1 &&
+        other.size(3) == 1 &&
+        self.dtype() == c10::ScalarType::Half &&
+        other.dtype() == c10::ScalarType::Half) {
+      other = other.contiguous();
+      at::Tensor self_buffer = get_buffer(self);
+      Tensor nt_sizes_ =
+          get_efficient_nested_size(self).sizes().to(torch::kInt32);
+      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
+      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
+      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
+      std::vector<int> numbers;
+      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
+        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
+          numbers.push_back(nt_sizes_all[i].item<int>());
+        }
+      }
+      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
+      Tensor nt_sizes_cumsum =
+          at::native::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
+      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.")
+      Tensor nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
+      nt_sizes = nt_sizes.to(torch::kCUDA);
+      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
+      at::Tensor result_buffer = self_buffer.clone();
+
+      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
+      c10::Half* other_ptr = other.data_ptr<c10::Half>();
+      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
+      nested_tensor::cuda::sub_scalar_kernelLauncher(
+          self_ptr,
+          other_ptr,
+          result_ptr,
+          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]),
+          (int)(*self_opt_sizes[0]),
+          nt_sizes.data_ptr<int>(),
+          defaultStream);
+      return wrap_buffer(std::move(result_buffer), get_efficient_nested_size(self),
+                         get_efficient_nested_stride(self));
+    }
+#endif
+  }
   std::tie(self, other) = _expand_other_as(self_, other_);
   return map_nested_tensor(
-      [&alpha](Tensor s, Tensor o) { return at::sub(s, o, alpha); },
+      [&alpha](Tensor s, Tensor o) {
+        return at::sub(s, o, alpha); },
      self,
      other);
 }
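From Python, the new fast path in these three ops targets one specific broadcasting pattern: a contiguous half-precision, 4-dimensional nested tensor on CUDA combined with a dense (1, C, 1, 1) tensor whose channel count matches the nested tensor's. A hedged sketch of how that pattern might be exercised (the shapes below are hypothetical):

import torch
import nestedtensor

# A ragged batch of half-precision CHW images, giving a 4-dim nested tensor on CUDA.
nt = nestedtensor.nested_tensor(
    [torch.randn(3, 128, 96), torch.randn(3, 64, 80)],
    device=torch.device('cuda'), dtype=torch.half)
# Dense per-channel operand of shape (1, C, 1, 1), also half precision.
scale = torch.randn(1, 3, 1, 1, device='cuda', dtype=torch.half)

out_add = nt + scale  # NestedTensor_add_Tensor: add_scalar_kernelLauncher over the packed buffer
out_mul = nt * scale  # NestedTensor_mul_Tensor: same shape checks, mul_scalar_kernelLauncher
out_sub = nt - scale  # NestedTensor_sub_Tensor: same shape checks, sub_scalar_kernelLauncher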

nestedtensor/csrc/activation.cpp

Lines changed: 11 additions & 1 deletion
@@ -27,14 +27,24 @@ Tensor NestedTensor_relu(const Tensor& self) {
 #ifdef TRACEPACKED
     std::cout << "calling packed relu" << std::endl;
 #endif
-    return wrap_buffer(at::relu(get_buffer(self)), impl->nested_size());
+    return wrap_buffer(at::relu(get_buffer(self)),
+                       get_efficient_nested_size(self),
+                       get_efficient_nested_stride(self));
   }
   return map_nested_tensor(
       [](at::Tensor tensor) { return at::relu(tensor); }, self);
 }
 
 // Registered below autograd
 Tensor& NestedTensor_relu_(Tensor& self) {
+  if (get_is_contiguous(self)) {
+#ifdef TRACEPACKED
+    std::cout << "calling packed relu_" << std::endl;
+#endif
+    Tensor buffer = get_buffer(self);
+    at::relu_(buffer);
+    return self;
+  }
   apply_nested_tensor([](at::Tensor& tensor) { at::relu_(tensor); }, self);
   return self;
 }
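Both relu paths above act on the packed buffer when the nested tensor is contiguous, so from Python the out-of-place and in-place variants should behave like their dense counterparts. A minimal sketch, assuming a default-constructed (contiguous) nested tensor and hypothetical shapes:

import torch
import nestedtensor

nt = nestedtensor.nested_tensor([torch.randn(3, 5), torch.randn(4, 2)])
out = torch.relu(nt)   # packed relu: one at::relu over the underlying buffer
torch.relu_(nt)        # packed relu_: mutates the buffer in place
assert all(torch.equal(a, b) for a, b in zip(out.unbind(), nt.unbind()))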
