Skip to content

Commit 93ef8c0

Browse files
committed
Added descending sort to argsort
1 parent 0792fbd commit 93ef8c0

11 files changed

+95
-37
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ how fast this is relative to `std::sort`.
3636

3737
## Sort an array of built-in integers and floats
3838
```cpp
39-
void x86simdsort::qsort(T* arr, size_t size, bool hasnan);
40-
void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan);
41-
void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan);
39+
void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending);
40+
void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending);
41+
void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending);
4242
```
4343
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
4444
int32_t, double, uint64_t, int64_t]`
@@ -53,7 +53,7 @@ data types.
5353

5454
## Arg sort routines on arrays
5555
```cpp
56-
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan);
56+
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending);
5757
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
5858
```
5959
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,

benchmarks/bench-argsort.hpp

+17
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@ static void simdargsort(benchmark::State &state, Args &&...args)
4545
}
4646
}
4747

48+
template <typename T, class... Args>
49+
static void simd_revargsort(benchmark::State &state, Args &&...args)
50+
{
51+
// get args
52+
auto args_tuple = std::make_tuple(std::move(args)...);
53+
size_t arrsize = std::get<0>(args_tuple);
54+
std::string arrtype = std::get<1>(args_tuple);
55+
// set up array
56+
std::vector<T> arr = get_array<T>(arrtype, arrsize);
57+
std::vector<size_t> inx;
58+
// benchmark
59+
for (auto _ : state) {
60+
inx = x86simdsort::argsort(arr.data(), arrsize, false, true);
61+
}
62+
}
63+
4864
template <typename T, class... Args>
4965
static void simd_ordern_argsort(benchmark::State &state, Args &&...args)
5066
{
@@ -68,6 +84,7 @@ static void simd_ordern_argsort(benchmark::State &state, Args &&...args)
6884

6985
#define BENCH_BOTH(type) \
7086
BENCH_SORT(simdargsort, type) \
87+
BENCH_SORT(simd_revargsort, type) \
7188
BENCH_SORT(simd_ordern_argsort, type) \
7289
BENCH_SORT(scalarargsort, type)
7390

lib/x86simdsort-avx2.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \
2525
} \
2626
template <> \
27-
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
27+
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan, bool descending) \
2828
{ \
29-
return avx2_argsort(arr, arrsize, hasnan); \
29+
return avx2_argsort(arr, arrsize, hasnan, descending); \
3030
} \
3131
template <> \
3232
std::vector<size_t> argselect( \

lib/x86simdsort-internal.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace avx512 {
3131
// argsort
3232
template <typename T>
3333
XSS_HIDE_SYMBOL std::vector<size_t>
34-
argsort(T *arr, size_t arrsize, bool hasnan = false);
34+
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
3535
// argselect
3636
template <typename T>
3737
XSS_HIDE_SYMBOL std::vector<size_t>
@@ -63,7 +63,7 @@ namespace avx2 {
6363
// argsort
6464
template <typename T>
6565
XSS_HIDE_SYMBOL std::vector<size_t>
66-
argsort(T *arr, size_t arrsize, bool hasnan = false);
66+
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
6767
// argselect
6868
template <typename T>
6969
XSS_HIDE_SYMBOL std::vector<size_t>
@@ -95,7 +95,7 @@ namespace scalar {
9595
// argsort
9696
template <typename T>
9797
XSS_HIDE_SYMBOL std::vector<size_t>
98-
argsort(T *arr, size_t arrsize, bool hasnan = false);
98+
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
9999
// argselect
100100
template <typename T>
101101
XSS_HIDE_SYMBOL std::vector<size_t>

lib/x86simdsort-scalar.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,16 @@ namespace scalar {
7070
xss::utils::get_cmp_func<T>(hasnan, reversed));
7171
}
7272
template <typename T>
73-
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan)
73+
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan, bool reversed)
7474
{
7575
UNUSED(hasnan);
7676
std::vector<size_t> arg(arrsize);
7777
std::iota(arg.begin(), arg.end(), 0);
78-
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
78+
if (reversed){
79+
std::sort(arg.begin(), arg.end(), compare_arg<T, std::greater<T>>(arr));
80+
}else{
81+
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
82+
}
7983
return arg;
8084
}
8185
template <typename T>
@@ -93,7 +97,7 @@ namespace scalar {
9397
template <typename T1, typename T2>
9498
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan)
9599
{
96-
std::vector<size_t> arg = argsort(key, arrsize, hasnan);
100+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
97101
utils::apply_permutation_in_place(key, arg);
98102
utils::apply_permutation_in_place(val, arg);
99103
}

lib/x86simdsort-skx.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \
2525
} \
2626
template <> \
27-
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
27+
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan, bool descending) \
2828
{ \
29-
return avx512_argsort(arr, arrsize, hasnan); \
29+
return avx512_argsort(arr, arrsize, hasnan, descending); \
3030
} \
3131
template <> \
3232
std::vector<size_t> argselect( \

lib/x86simdsort.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ namespace x86simdsort {
8686
}
8787

8888
#define DECLARE_INTERNAL_argsort(TYPE) \
89-
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t, bool) \
89+
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t, bool, bool) \
9090
= NULL; \
9191
template <> \
92-
std::vector<size_t> argsort(TYPE *arr, size_t arrsize, bool hasnan) \
92+
std::vector<size_t> argsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
9393
{ \
94-
return (*internal_argsort##TYPE)(arr, arrsize, hasnan); \
94+
return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \
9595
}
9696

9797
#define DECLARE_INTERNAL_argselect(TYPE) \

lib/x86simdsort.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr,
3636
// argsort
3737
template <typename T>
3838
XSS_EXPORT_SYMBOL std::vector<size_t>
39-
argsort(T *arr, size_t arrsize, bool hasnan = false);
39+
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
4040

4141
// argselect
4242
template <typename T>

src/README.md

+10-10
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ Equivalent to `qsort` in
1313
`std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort).
1414

1515
```cpp
16-
void avx512_qsort<T>(T* arr, size_t arrsize, bool hasnan = false);
17-
void avx2_qsort<T>(T* arr, size_t arrsize, bool hasnan = false);
16+
void avx512_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
17+
void avx2_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
1818
```
1919
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
2020
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
@@ -30,8 +30,8 @@ Equivalent to `std::nth_element` in
3030

3131

3232
```cpp
33-
void avx512_qselect<T>(T* arr, size_t arrsize, bool hasnan = false);
34-
void avx2_qselect<T>(T* arr, size_t arrsize, bool hasnan = false);
33+
void avx512_qselect<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
34+
void avx2_qselect<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
3535
```
3636
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
3737
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
@@ -46,8 +46,8 @@ Equivalent to `std::partial_sort` in
4646

4747

4848
```cpp
49-
void avx512_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false)
50-
void avx2_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false)
49+
void avx512_partial_qsort<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
50+
void avx2_partial_qsort<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
5151
```
5252
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
5353
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
@@ -61,8 +61,8 @@ Equivalent to `np.argsort` in
6161
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html).
6262

6363
```cpp
64-
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize);
65-
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize);
64+
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
65+
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
6666
```
6767
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
6868
`double`.
@@ -74,8 +74,8 @@ Equivalent to `np.argpartition` in
7474
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html).
7575

7676
```cpp
77-
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize);
78-
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize);
77+
std::vector<size_t> arg = avx512_argselect<T>(T* arr, size_t k, size_t arrsize);
78+
void avx512_argselect<T>(T* arr, size_t *arg, size_t k, size_t arrsize);
7979
```
8080
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
8181
`double`.

src/xss-common-argsort.h

+25-7
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
542542
/* argsort methods for 32-bit and 64-bit dtypes */
543543
template <typename T>
544544
X86_SIMD_SORT_INLINE void
545-
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
545+
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, bool descending = false)
546546
{
547547
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
548548
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
@@ -558,29 +558,38 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
558558
if constexpr (std::is_floating_point_v<T>) {
559559
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
560560
std_argsort_withnan(arr, arg, 0, arrsize);
561+
562+
if (descending){
563+
std::reverse(arg, arg + arrsize);
564+
}
565+
561566
return;
562567
}
563568
}
564569
UNUSED(hasnan);
565570
argsort_64bit_<vectype, argtype>(
566571
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
572+
573+
if (descending){
574+
std::reverse(arg, arg + arrsize);
575+
}
567576
}
568577
}
569578

570579
template <typename T>
571580
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
572-
avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
581+
avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false)
573582
{
574583
std::vector<arrsize_t> indices(arrsize);
575584
std::iota(indices.begin(), indices.end(), 0);
576-
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan);
585+
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan, descending);
577586
return indices;
578587
}
579588

580589
/* argsort methods for 32-bit and 64-bit dtypes */
581590
template <typename T>
582591
X86_SIMD_SORT_INLINE void
583-
avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
592+
avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, bool descending = false)
584593
{
585594
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
586595
avx2_half_vector<T>,
@@ -594,22 +603,31 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
594603
if constexpr (std::is_floating_point_v<T>) {
595604
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
596605
std_argsort_withnan(arr, arg, 0, arrsize);
606+
607+
if (descending){
608+
std::reverse(arg, arg + arrsize);
609+
}
610+
597611
return;
598612
}
599613
}
600614
UNUSED(hasnan);
601615
argsort_64bit_<vectype, argtype>(
602616
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
617+
618+
if (descending){
619+
std::reverse(arg, arg + arrsize);
620+
}
603621
}
604622
}
605623

606624
template <typename T>
607625
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
608-
avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
626+
avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false)
609627
{
610628
std::vector<arrsize_t> indices(arrsize);
611629
std::iota(indices.begin(), indices.end(), 0);
612-
avx2_argsort<T>(arr, indices.data(), arrsize, hasnan);
630+
avx2_argsort<T>(arr, indices.data(), arrsize, hasnan, descending);
613631
return indices;
614632
}
615633

@@ -631,7 +649,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
631649
ymm_vector<arrsize_t>,
632650
zmm_vector<arrsize_t>>::type;
633651

634-
if (arrsize > 1) {
652+
if (arrsize > 1) {
635653
if constexpr (std::is_floating_point_v<T>) {
636654
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
637655
std_argselect_withnan(arr, arg, k, 0, arrsize);

tests/test-qsort.cpp

+21-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending)
7171
}
7272
}
7373

74-
TYPED_TEST_P(simdsort, test_argsort)
74+
TYPED_TEST_P(simdsort, test_argsort_ascending)
7575
{
7676
for (auto type : this->arrtype) {
7777
bool hasnan = (type == "rand_with_nan") ? true : false;
@@ -89,6 +89,24 @@ TYPED_TEST_P(simdsort, test_argsort)
8989
}
9090
}
9191

92+
TYPED_TEST_P(simdsort, test_argsort_descending)
93+
{
94+
for (auto type : this->arrtype) {
95+
bool hasnan = (type == "rand_with_nan") ? true : false;
96+
for (auto size : this->arrsize) {
97+
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
98+
std::vector<TypeParam> sortedarr = arr;
99+
std::sort(sortedarr.begin(),
100+
sortedarr.end(),
101+
compare<TypeParam, std::greater<TypeParam>>());
102+
auto arg = x86simdsort::argsort(arr.data(), arr.size(), hasnan, true);
103+
IS_ARG_SORTED(sortedarr, arr, arg, type);
104+
arr.clear();
105+
arg.clear();
106+
}
107+
}
108+
}
109+
92110
TYPED_TEST_P(simdsort, test_qselect_ascending)
93111
{
94112
for (auto type : this->arrtype) {
@@ -241,7 +259,8 @@ TYPED_TEST_P(simdsort, test_comparator)
241259
REGISTER_TYPED_TEST_SUITE_P(simdsort,
242260
test_qsort_ascending,
243261
test_qsort_descending,
244-
test_argsort,
262+
test_argsort_ascending,
263+
test_argsort_descending,
245264
test_argselect,
246265
test_qselect_ascending,
247266
test_qselect_descending,

0 commit comments

Comments
 (0)