
Commit 4081fdd

Revert "port masked_select from TH to ATen and optimize perf on CPU (pytorch#33269)" (pytorch#41829)
This reverts commit fe66bdb. This also makes a change to THTensorEvenMoreMath because sumall was removed; see THTensor_wrap.
1 parent cefb9e0 commit 4081fdd
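
For context on the user-facing surface this revert touches: masked_select returns a 1-D tensor of the elements where the mask is true, and a torch.uint8 mask is still accepted but deprecated in favor of torch.bool. A minimal sketch (hypothetical values, assuming a PyTorch build from around this commit):

import torch

src = torch.arange(6.0).reshape(2, 3)       # hypothetical input
bool_mask = src > 2                          # torch.bool mask, the supported path
print(torch.masked_select(src, bool_mask))   # tensor([3., 4., 5.]) -- always 1-D

byte_mask = bool_mask.to(torch.uint8)        # legacy torch.uint8 mask
# Still works, but warns that uint8 masks are deprecated and torch.bool should be used.
print(torch.masked_select(src, byte_mask))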

File tree: 8 files changed, +106 -162 lines changed


aten/src/ATen/Declarations.cwrap

Lines changed: 32 additions & 0 deletions
@@ -56,6 +56,38 @@
     - THBoolTensor* mask
     - THTensor* source
 ]]
+[[
+  name: _th_masked_select
+  cname: maskedSelect
+  cpu_bool: True
+  cpu_bfloat16: True
+  variants:
+    - function
+  backends:
+    - CPU
+  return: argument 0
+  arguments:
+    - arg: THTensor* result
+      output: True
+    - THTensor* self
+    - THByteTensor* mask
+]]
+[[
+  name: _th_masked_select_bool
+  cname: maskedSelectBool
+  cpu_bool: True
+  cpu_bfloat16: True
+  variants:
+    - function
+  backends:
+    - CPU
+  return: argument 0
+  arguments:
+    - arg: THTensor* result
+      output: True
+    - THTensor* self
+    - THBoolTensor* mask
+]]
 [[
   name: _th_nonzero
   cname: nonzero

aten/src/ATen/native/LegacyDefinitions.cpp

Lines changed: 26 additions & 0 deletions
@@ -22,6 +22,32 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
   }
 }
 
+Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
+  namedinference::compute_broadcast_outnames(self, mask);
+
+  Tensor b_self, b_mask;
+  std::tie(b_self, b_mask) = expand_outplace(self, mask, "masked_select");
+  if (b_mask.dtype() == at::ScalarType::Byte) {
+    TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
+        "please use a mask with dtype torch.bool instead.");
+    return legacy::cpu::_th_masked_select(b_self, b_mask);
+  } else {
+    return legacy::cpu::_th_masked_select_bool(b_self, b_mask);
+  }
+}
+
+Tensor & masked_select_out_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
+  namedinference::compute_broadcast_outnames(self, mask);
+
+  Tensor b_self, b_mask;
+  std::tie(b_self, b_mask) = expand_outplace(self, mask, "masked_select_out");
+  if (b_mask.dtype() == at::ScalarType::Bool) {
+    return legacy::cpu::_th_masked_select_bool_out(result, b_self, b_mask);
+  } else {
+    return legacy::cpu::_th_masked_select_out(result, b_self, b_mask);
+  }
+}
+
 Tensor argsort(const Tensor & self, int64_t dim, bool descending) {
   return std::get<1>(at::sort(self, dim, descending));
 }
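
As in the expand_outplace calls above, masked_select broadcasts self and mask against each other before selecting. A short sketch of that behavior (hypothetical shapes):

import torch

self_t = torch.arange(6.0).reshape(2, 3)    # shape (2, 3)
mask = torch.tensor([True, False, True])    # shape (3,), broadcasts over dim 0
print(torch.masked_select(self_t, mask))    # tensor([0., 2., 3., 5.])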

aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 0 additions & 78 deletions
@@ -71,8 +71,6 @@ DEFINE_DISPATCH(index_put_stub);
 DEFINE_DISPATCH(index_put_accum_stub);
 DEFINE_DISPATCH(masked_fill_stub);
 REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);
-DEFINE_DISPATCH(masked_select_serial_stub);
-DEFINE_DISPATCH(masked_select_stub);
 
 DEFINE_DISPATCH(gather_stub);
 DEFINE_DISPATCH(scatter_stub);
@@ -629,82 +627,6 @@ Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & sour
   return result;
 }
 
-static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
-  NoNamesGuard guard;
-
-  TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
-              "masked_select: expected BoolTensor or ByteTensor for mask");
-  TORCH_CHECK(self.scalar_type() == result.scalar_type(),
-              "masked_select(): self and result must have the same scalar type");
-
-  if (mask.dtype() == at::ScalarType::Byte) {
-    TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
-               "please use a mask with dtype torch.bool instead.");
-  }
-
-  Tensor _mask, _self;
-  std::tie(_mask, _self) = expand_outplace(mask, self);
-
-  auto shape = _self.sizes();
-  int64_t numel = _mask.sum().item().toLong();
-  result.resize_({numel});
-  if (numel == 0) {
-    return result;
-  }
-
-  // Create strided view of result before feeding into TensorIterator
-  auto strides = DimVector(shape.size(), 0);
-  auto result_strided = result.as_strided(shape, strides);
-
-  // serial kernel
-  bool use_serial_kernel = self.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
-  if (use_serial_kernel) {
-    auto iter = TensorIteratorConfig()
-      .check_all_same_dtype(false)
-      .resize_outputs(false)
-      .add_output(result_strided)
-      .add_input(_self)
-      .add_input(_mask)
-      .build();
-
-    masked_select_serial_stub(iter.device_type(), iter);
-    return result;
-  }
-
-  // Use a prefix sum to record the output locations of the masked elements,
-  // so as to parallel with TensorIterator.
-  auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(_mask);
-  auto mask_prefix_sum = at::empty(shape, self.options().dtype(at::kLong));
-  auto mask_long_data = mask_long.data_ptr<int64_t>();
-  auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>();
-  // TODO: Here can only use std::partial_sum for C++14,
-  // use std::exclusive_scan when PyTorch upgrades to C++17, which have better peformance.
-  // std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0);
-  std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);
-
-  auto iter = TensorIteratorConfig()
-    .check_all_same_dtype(false)
-    .resize_outputs(false)
-    .add_output(result_strided)
-    .add_input(_self)
-    .add_input(_mask)
-    .add_input(mask_prefix_sum)
-    .build();
-
-  masked_select_stub(iter.device_type(), iter);
-  return result;
-}
-
-Tensor & masked_select_out_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
-  namedinference::compute_broadcast_outnames(self, mask);
-  return masked_select_out_impl_cpu(result, self, mask);
-}
-
-Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
-  Tensor result = at::empty({0}, self.options());
-  return masked_select_out_cpu(result, self, mask);
-}
-
 Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
   // special case scalar input and/or index
   if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
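
The removed parallel path computes an inclusive prefix sum of the mask so every selected element knows its output slot up front, which is what lets the copy run in parallel over TensorIterator chunks. A rough Python sketch of that idea (not the ATen code itself; values are illustrative):

import torch

src = torch.arange(8)                                    # hypothetical input
mask = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1], dtype=torch.bool)

# Inclusive scan of the mask, analogous to std::partial_sum in the removed code:
prefix = torch.cumsum(mask.to(torch.int64), dim=0)

out = torch.empty(int(prefix[-1]), dtype=src.dtype)
out[prefix[mask] - 1] = src[mask]                        # (offset - 1), as in the removed kernel
print(out)                                               # tensor([1, 2, 4, 7])
print(torch.equal(out, torch.masked_select(src, mask)))  # True

Because each masked position already carries its destination index, disjoint chunks of the input can be processed independently, unlike the serial kernel's running counter.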

aten/src/ATen/native/TensorAdvancedIndexing.h

Lines changed: 0 additions & 3 deletions
@@ -15,7 +15,6 @@ using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRe
 using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
 using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);
 using masked_fill_fn = void(*)(TensorIterator &, Scalar scalar);
-using masked_select_fn = void(*)(TensorIterator &);
 
 using gather_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
 using scatter_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
@@ -26,8 +25,6 @@ DECLARE_DISPATCH(index_fn, index_stub);
 DECLARE_DISPATCH(index_put_fn, index_put_stub);
 DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);
 DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
-DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
-DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
 
 DECLARE_DISPATCH(gather_fn, gather_stub);
 DECLARE_DISPATCH(scatter_fn, scatter_stub);

aten/src/ATen/native/cpu/IndexKernel.cpp

Lines changed: 1 addition & 80 deletions
@@ -163,90 +163,11 @@ void masked_fill_kernel(TensorIterator& iter, Scalar value) {
   });
 }
 
-template <typename scalar_t, typename mask_t, typename func_t>
-void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) {
-  auto is_mask_bool = std::is_same<mask_t, bool>::value;
-  int64_t offset = 0;
-  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
-    char* dst = data[0];
-    char* src = data[1];
-    char* mask = data[2];
-    for (int64_t i = 0; i < n; i++) {
-      mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
-      if (!is_mask_bool) {
-        TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
-      }
-      if (mask_value) {
-        int64_t offset_bytes = offset * sizeof(scalar_t);
-        f(dst, src + strides[1] * i, offset_bytes);
-        offset++;
-      }
-    }
-  };
-  iter.serial_for_each(loop, {0, iter.numel()});
-}
-
-void masked_select_serial_kernel(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
-    iter.dtype(), "masked_select", [&] {
-      auto mask_dtype = iter.input_dtype(1);
-      if (mask_dtype == at::ScalarType::Bool) {
-        cpu_masked_select_serial_kernel<scalar_t, bool>(iter, [](char* dst, char* src, int64_t offset) {
-          *(scalar_t*)(dst + offset) = *(scalar_t*)src;
-        });
-      } else {
-        cpu_masked_select_serial_kernel<scalar_t, unsigned char>(iter, [](char* dst, char* src, int64_t offset) {
-          *(scalar_t*)(dst + offset) = *(scalar_t*)src;
-        });
-      }
-    });
-}
-
-template <typename scalar_t, typename mask_t, typename func_t>
-void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) {
-  auto is_mask_bool = std::is_same<mask_t, bool>::value;
-  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
-    char* dst = data[0];
-    char* src = data[1];
-    char* mask = data[2];
-    char* mask_prefix_sum = data[3];
-    for (int64_t i = 0; i < n; i++) {
-      mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
-      if (!is_mask_bool) {
-        TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
-      }
-      if (mask_value) {
-        int64_t offset = *(int64_t*)(mask_prefix_sum + strides[3] * i);
-        int64_t offset_bytes = (offset - 1) * sizeof(scalar_t);
-        f(dst, src + strides[1] * i, offset_bytes);
-      }
-    }
-  };
-  iter.for_each(loop);
-}
-
-void masked_select_kernel(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
-    iter.dtype(), "masked_select", [&] {
-      auto mask_dtype = iter.input_dtype(1);
-      if (mask_dtype == at::ScalarType::Bool) {
-        cpu_masked_select_kernel<scalar_t, bool>(iter, [](char* dst, char* src, int64_t offset) {
-          *(scalar_t*)(dst + offset) = *(scalar_t*)src;
-        });
-      } else {
-        cpu_masked_select_kernel<scalar_t, unsigned char>(iter, [](char* dst, char* src, int64_t offset) {
-          *(scalar_t*)(dst + offset) = *(scalar_t*)src;
-        });
-      }
-    });
-}
-
 } // anonymous namespace
 
+
 REGISTER_DISPATCH(index_stub, &index_kernel);
 REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
 REGISTER_DISPATCH(masked_fill_stub, &masked_fill_kernel);
-REGISTER_DISPATCH(masked_select_serial_stub, &masked_select_serial_kernel);
-REGISTER_DISPATCH(masked_select_stub, &masked_select_kernel);
 
 }} // namespace at::native
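
By contrast, the removed serial kernel keeps a single running output offset while walking the iterator once; a sketch of that single-threaded scheme (plain Python, not the ATen kernel):

import torch

src = torch.tensor([10, 20, 30, 40])
mask = torch.tensor([True, False, True, True])

out = torch.empty(int(mask.sum()), dtype=src.dtype)
offset = 0
for value, m in zip(src.tolist(), mask.tolist()):
    if m:                      # copy into the next free slot and advance the offset
        out[offset] = value
        offset += 1
print(out)                     # tensor([10, 30, 40])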

aten/src/TH/generic/THTensorEvenMoreMath.cpp

Lines changed: 44 additions & 0 deletions
@@ -99,6 +99,50 @@ accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
 
 #if !defined(TH_REAL_IS_HALF) /* non half part */
 
+void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask)
+{
+  at::NoNamesGuard guard;
+  ptrdiff_t numel = THTensor_wrap(mask).sum().item<int64_t>();
+  scalar_t *tensor_data;
+
+#ifdef DEBUG
+  THAssert(numel <= LONG_MAX);
+#endif
+  THTensor_(resize1d)(tensor,numel);
+  tensor_data = tensor->data<scalar_t>();
+  TH_TENSOR_APPLY2(scalar_t, src, unsigned char, mask,
+                   if (*mask_data > 1)
+                   {
+                     THFree(mask_counter);
+                     THFree(src_counter);
+                     THError("Mask tensor can take 0 and 1 values only");
+                   }
+                   else if (*mask_data == 1)
+                   {
+                     *tensor_data = *src_data;
+                     tensor_data++;
+                   });
+}
+
+void THTensor_(maskedSelectBool)(THTensor *tensor, THTensor *src, THBoolTensor *mask)
+{
+  at::NoNamesGuard guard;
+  ptrdiff_t numel = THTensor_wrap(mask).sum().item<int64_t>();
+  scalar_t *tensor_data;
+
+#ifdef DEBUG
+  THAssert(numel <= LONG_MAX);
+#endif
+  THTensor_(resize1d)(tensor,numel);
+  tensor_data = tensor->data<scalar_t>();
+  TH_TENSOR_APPLY2(scalar_t, src, bool, mask,
+                   if (*mask_data)
+                   {
+                     *tensor_data = *src_data;
+                     tensor_data++;
+                   });
+}
+
 void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
 {
   THTensor *srct = THTensor_(newContiguous)(src);
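
The restored TH kernels size the output from mask.sum() and, for the Byte variant, reject any mask value other than 0 or 1. A sketch of the resulting user-visible check, assuming the legacy Byte-mask path above is the one dispatched on CPU:

import torch

src = torch.arange(4.0)
bad_mask = torch.tensor([0, 1, 2, 1], dtype=torch.uint8)   # 2 is not a valid mask value
try:
    torch.masked_select(src, bad_mask)
except RuntimeError as err:
    print(err)   # expected: "Mask tensor can take 0 and 1 values only"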

aten/src/TH/generic/THTensorMath.h

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@ TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
 
 #if !defined(TH_REAL_IS_HALF)
 
+TH_API void THTensor_(maskedSelect)(THTensor *tensor, THTensor* src, THByteTensor *mask);
+TH_API void THTensor_(maskedSelectBool)(THTensor *tensor, THTensor* src, THBoolTensor *mask);
 TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
 TH_API void THTensor_(maskedCopyBool)(THTensor *tensor, THBoolTensor *mask, THTensor* src);
 
test/test_torch.py

Lines changed: 1 addition & 1 deletion
@@ -12644,7 +12644,7 @@ def test_masked_select(self, device, dtype):
         src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device)
         mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(maskType)
 
-        if dtype == torch.half and torch.device(device).type == 'cpu':
+        if (dtype.is_complex or dtype == torch.half) and torch.device(device).type == 'cpu':
             self.assertRaises(RuntimeError, lambda: src.masked_select(mask))
             continue
 
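
The test now also expects the CPU path to raise for complex inputs, since the TH-backed kernel restored by this revert has no complex support. A sketch of the asserted behavior, mirroring the updated test:

import torch

src = torch.zeros(10, dtype=torch.complex64)   # CPU tensor
mask = torch.ones(10, dtype=torch.bool)
try:
    src.masked_select(mask)
except RuntimeError:
    print("masked_select on complex CPU tensors raises, as the updated test expects")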
