Skip to content

Commit 8cc37bd

Browse files
authored
Merge pull request #144 from sterrettm2/reverse_argsort
Adds descending order sort to argsort
2 parents 0792fbd + d58de52 commit 8cc37bd

11 files changed

+115
-43
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ how fast this is relative to `std::sort`.
3636

3737
## Sort an array of built-in integers and floats
3838
```cpp
39-
void x86simdsort::qsort(T* arr, size_t size, bool hasnan);
40-
void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan);
41-
void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan);
39+
void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending);
40+
void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending);
41+
void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending);
4242
```
4343
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
4444
int32_t, double, uint64_t, int64_t]`
@@ -53,7 +53,7 @@ data types.
5353

5454
## Arg sort routines on arrays
5555
```cpp
56-
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan);
56+
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending);
5757
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
5858
```
5959
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,

benchmarks/bench-argsort.hpp

+17
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ static void simdargsort(benchmark::State &state, Args &&...args)
4545
}
4646
}
4747

48+
template <typename T, class... Args>
49+
static void simd_revargsort(benchmark::State &state, Args &&...args)
50+
{
51+
// get args
52+
auto args_tuple = std::make_tuple(std::move(args)...);
53+
size_t arrsize = std::get<0>(args_tuple);
54+
std::string arrtype = std::get<1>(args_tuple);
55+
// set up array
56+
std::vector<T> arr = get_array<T>(arrtype, arrsize);
57+
std::vector<size_t> inx;
58+
// benchmark
59+
for (auto _ : state) {
60+
inx = x86simdsort::argsort(arr.data(), arrsize, false, true);
61+
}
62+
}
63+
4864
template <typename T, class... Args>
4965
static void simd_ordern_argsort(benchmark::State &state, Args &&...args)
5066
{
@@ -68,6 +84,7 @@ static void simd_ordern_argsort(benchmark::State &state, Args &&...args)
6884

6985
#define BENCH_BOTH(type) \
7086
BENCH_SORT(simdargsort, type) \
87+
BENCH_SORT(simd_revargsort, type) \
7188
BENCH_SORT(simd_ordern_argsort, type) \
7289
BENCH_SORT(scalarargsort, type)
7390

lib/x86simdsort-avx2.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \
2525
} \
2626
template <> \
27-
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
27+
std::vector<size_t> argsort( \
28+
type *arr, size_t arrsize, bool hasnan, bool descending) \
2829
{ \
29-
return avx2_argsort(arr, arrsize, hasnan); \
30+
return avx2_argsort(arr, arrsize, hasnan, descending); \
3031
} \
3132
template <> \
3233
std::vector<size_t> argselect( \

lib/x86simdsort-internal.h

+12-6
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ namespace avx512 {
3030
bool descending = false);
3131
// argsort
3232
template <typename T>
33-
XSS_HIDE_SYMBOL std::vector<size_t>
34-
argsort(T *arr, size_t arrsize, bool hasnan = false);
33+
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
34+
size_t arrsize,
35+
bool hasnan = false,
36+
bool descending = false);
3537
// argselect
3638
template <typename T>
3739
XSS_HIDE_SYMBOL std::vector<size_t>
@@ -62,8 +64,10 @@ namespace avx2 {
6264
bool descending = false);
6365
// argsort
6466
template <typename T>
65-
XSS_HIDE_SYMBOL std::vector<size_t>
66-
argsort(T *arr, size_t arrsize, bool hasnan = false);
67+
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
68+
size_t arrsize,
69+
bool hasnan = false,
70+
bool descending = false);
6771
// argselect
6872
template <typename T>
6973
XSS_HIDE_SYMBOL std::vector<size_t>
@@ -94,8 +98,10 @@ namespace scalar {
9498
bool descending = false);
9599
// argsort
96100
template <typename T>
97-
XSS_HIDE_SYMBOL std::vector<size_t>
98-
argsort(T *arr, size_t arrsize, bool hasnan = false);
101+
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
102+
size_t arrsize,
103+
bool hasnan = false,
104+
bool descending = false);
99105
// argselect
100106
template <typename T>
101107
XSS_HIDE_SYMBOL std::vector<size_t>

lib/x86simdsort-scalar.h

+12-3
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,21 @@ namespace scalar {
7070
xss::utils::get_cmp_func<T>(hasnan, reversed));
7171
}
7272
template <typename T>
73-
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan)
73+
std::vector<size_t>
74+
argsort(T *arr, size_t arrsize, bool hasnan, bool reversed)
7475
{
7576
UNUSED(hasnan);
7677
std::vector<size_t> arg(arrsize);
7778
std::iota(arg.begin(), arg.end(), 0);
78-
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
79+
if (reversed) {
80+
std::sort(arg.begin(),
81+
arg.end(),
82+
compare_arg<T, std::greater<T>>(arr));
83+
}
84+
else {
85+
std::sort(
86+
arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
87+
}
7988
return arg;
8089
}
8190
template <typename T>
@@ -93,7 +102,7 @@ namespace scalar {
93102
template <typename T1, typename T2>
94103
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan)
95104
{
96-
std::vector<size_t> arg = argsort(key, arrsize, hasnan);
105+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
97106
utils::apply_permutation_in_place(key, arg);
98107
utils::apply_permutation_in_place(val, arg);
99108
}

lib/x86simdsort-skx.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \
2525
} \
2626
template <> \
27-
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
27+
std::vector<size_t> argsort( \
28+
type *arr, size_t arrsize, bool hasnan, bool descending) \
2829
{ \
29-
return avx512_argsort(arr, arrsize, hasnan); \
30+
return avx512_argsort(arr, arrsize, hasnan, descending); \
3031
} \
3132
template <> \
3233
std::vector<size_t> argselect( \

lib/x86simdsort.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,14 @@ namespace x86simdsort {
8686
}
8787

8888
#define DECLARE_INTERNAL_argsort(TYPE) \
89-
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t, bool) \
89+
static std::vector<size_t> (*internal_argsort##TYPE)( \
90+
TYPE *, size_t, bool, bool) \
9091
= NULL; \
9192
template <> \
92-
std::vector<size_t> argsort(TYPE *arr, size_t arrsize, bool hasnan) \
93+
std::vector<size_t> argsort( \
94+
TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
9395
{ \
94-
return (*internal_argsort##TYPE)(arr, arrsize, hasnan); \
96+
return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \
9597
}
9698

9799
#define DECLARE_INTERNAL_argselect(TYPE) \

lib/x86simdsort.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr,
3636
// argsort
3737
template <typename T>
3838
XSS_EXPORT_SYMBOL std::vector<size_t>
39-
argsort(T *arr, size_t arrsize, bool hasnan = false);
39+
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
4040

4141
// argselect
4242
template <typename T>

src/README.md

+10-10
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ Equivalent to `qsort` in
1313
`std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort).
1414

1515
```cpp
16-
void avx512_qsort<T>(T* arr, size_t arrsize, bool hasnan = false);
17-
void avx2_qsort<T>(T* arr, size_t arrsize, bool hasnan = false);
16+
void avx512_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
17+
void avx2_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
1818
```
1919
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
2020
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
@@ -30,8 +30,8 @@ Equivalent to `std::nth_element` in
3030

3131

3232
```cpp
33-
void avx512_qselect<T>(T* arr, size_t arrsize, bool hasnan = false);
34-
void avx2_qselect<T>(T* arr, size_t arrsize, bool hasnan = false);
33+
void avx512_qselect<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
34+
void avx2_qselect<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
3535
```
3636
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
3737
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
@@ -46,8 +46,8 @@ Equivalent to `std::partial_sort` in
4646

4747

4848
```cpp
49-
void avx512_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false)
50-
void avx2_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false)
49+
void avx512_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false)
50+
void avx2_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false)
5151
```
5252
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
5353
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
@@ -61,8 +61,8 @@ Equivalent to `np.argsort` in
6161
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html).
6262

6363
```cpp
64-
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize);
65-
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize);
64+
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
65+
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
6666
```
6767
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
6868
`double`.
@@ -74,8 +74,8 @@ Equivalent to `np.argselect` in
7474
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html).
7575

7676
```cpp
77-
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize);
78-
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize);
77+
std::vector<size_t> arg = avx512_argselect<T>(T* arr, size_t k, size_t arrsize);
78+
void avx512_argselect<T>(T* arr, size_t *arg, size_t k, size_t arrsize);
7979
```
8080
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
8181
`double`.

src/xss-common-argsort.h

+26-10
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,11 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
541541

542542
/* argsort methods for 32-bit and 64-bit dtypes */
543543
template <typename T>
544-
X86_SIMD_SORT_INLINE void
545-
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
544+
X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
545+
arrsize_t *arg,
546+
arrsize_t arrsize,
547+
bool hasnan = false,
548+
bool descending = false)
546549
{
547550
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
548551
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
@@ -558,29 +561,37 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
558561
if constexpr (std::is_floating_point_v<T>) {
559562
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
560563
std_argsort_withnan(arr, arg, 0, arrsize);
564+
565+
if (descending) { std::reverse(arg, arg + arrsize); }
566+
561567
return;
562568
}
563569
}
564570
UNUSED(hasnan);
565571
argsort_64bit_<vectype, argtype>(
566572
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
573+
574+
if (descending) { std::reverse(arg, arg + arrsize); }
567575
}
568576
}
569577

570578
template <typename T>
571-
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
572-
avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
579+
X86_SIMD_SORT_INLINE std::vector<arrsize_t> avx512_argsort(
580+
T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false)
573581
{
574582
std::vector<arrsize_t> indices(arrsize);
575583
std::iota(indices.begin(), indices.end(), 0);
576-
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan);
584+
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan, descending);
577585
return indices;
578586
}
579587

580588
/* argsort methods for 32-bit and 64-bit dtypes */
581589
template <typename T>
582-
X86_SIMD_SORT_INLINE void
583-
avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
590+
X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
591+
arrsize_t *arg,
592+
arrsize_t arrsize,
593+
bool hasnan = false,
594+
bool descending = false)
584595
{
585596
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
586597
avx2_half_vector<T>,
@@ -594,22 +605,27 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
594605
if constexpr (std::is_floating_point_v<T>) {
595606
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596607
std_argsort_withnan(arr, arg, 0, arrsize);
608+
609+
if (descending) { std::reverse(arg, arg + arrsize); }
610+
597611
return;
598612
}
599613
}
600614
UNUSED(hasnan);
601615
argsort_64bit_<vectype, argtype>(
602616
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
617+
618+
if (descending) { std::reverse(arg, arg + arrsize); }
603619
}
604620
}
605621

606622
template <typename T>
607-
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
608-
avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
623+
X86_SIMD_SORT_INLINE std::vector<arrsize_t> avx2_argsort(
624+
T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false)
609625
{
610626
std::vector<arrsize_t> indices(arrsize);
611627
std::iota(indices.begin(), indices.end(), 0);
612-
avx2_argsort<T>(arr, indices.data(), arrsize, hasnan);
628+
avx2_argsort<T>(arr, indices.data(), arrsize, hasnan, descending);
613629
return indices;
614630
}
615631

tests/test-qsort.cpp

+22-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending)
7171
}
7272
}
7373

74-
TYPED_TEST_P(simdsort, test_argsort)
74+
TYPED_TEST_P(simdsort, test_argsort_ascending)
7575
{
7676
for (auto type : this->arrtype) {
7777
bool hasnan = (type == "rand_with_nan") ? true : false;
@@ -89,6 +89,25 @@ TYPED_TEST_P(simdsort, test_argsort)
8989
}
9090
}
9191

92+
TYPED_TEST_P(simdsort, test_argsort_descending)
93+
{
94+
for (auto type : this->arrtype) {
95+
bool hasnan = (type == "rand_with_nan") ? true : false;
96+
for (auto size : this->arrsize) {
97+
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
98+
std::vector<TypeParam> sortedarr = arr;
99+
std::sort(sortedarr.begin(),
100+
sortedarr.end(),
101+
compare<TypeParam, std::greater<TypeParam>>());
102+
auto arg = x86simdsort::argsort(
103+
arr.data(), arr.size(), hasnan, true);
104+
IS_ARG_SORTED(sortedarr, arr, arg, type);
105+
arr.clear();
106+
arg.clear();
107+
}
108+
}
109+
}
110+
92111
TYPED_TEST_P(simdsort, test_qselect_ascending)
93112
{
94113
for (auto type : this->arrtype) {
@@ -241,7 +260,8 @@ TYPED_TEST_P(simdsort, test_comparator)
241260
REGISTER_TYPED_TEST_SUITE_P(simdsort,
242261
test_qsort_ascending,
243262
test_qsort_descending,
244-
test_argsort,
263+
test_argsort_ascending,
264+
test_argsort_descending,
245265
test_argselect,
246266
test_qselect_ascending,
247267
test_qselect_descending,

0 commit comments

Comments
 (0)