Skip to content

Commit 2315766

Browse files
authored
Merge pull request #151 from sterrettm2/kv-select
Adds support for kv-select, kv-partial sort, and descending order for all key-value functions.
2 parents 6621ac3 + f436aae commit 2315766

14 files changed

+827
-111
lines changed

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ int32_t, double, uint64_t, int64_t]`
4545
4646
## Key-value sort routines on pairs of arrays
4747
```cpp
48-
void x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan);
48+
void x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan, bool descending);
49+
void x86simdsort::keyvalue_select(T1* key, T2* val, size_t k, size_t size, bool hasnan, bool descending);
50+
void x86simdsort::keyvalue_partial_sort(T1* key, T2* val, size_t k, size_t size, bool hasnan, bool descending);
4951
```
5052
Supported datatypes: `T1`, `T2` $\in$ `[float, uint32_t, int32_t, double,
5153
uint64_t, int64_t]`. Note that keyvalue sort is not yet supported for 16-bit

benchmarks/bench-keyvalue.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ static void scalarkvsort(benchmark::State &state, Args &&...args)
1313
std::vector<T> key_bkp = key;
1414
// benchmark
1515
for (auto _ : state) {
16-
xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false);
16+
xss::scalar::keyvalue_qsort(
17+
key.data(), val.data(), arrsize, false, false);
1718
state.PauseTiming();
1819
key = key_bkp;
1920
state.ResumeTiming();

lib/x86simdsort-avx2.cpp

+32-22
Original file line numberDiff line numberDiff line change
@@ -34,38 +34,48 @@
3434
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3535
}
3636

37-
#define DEFINE_KEYVALUE_METHODS(type) \
38-
template <> \
39-
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
40-
{ \
41-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
42-
} \
43-
template <> \
44-
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
45-
{ \
46-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
47-
} \
48-
template <> \
49-
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
50-
{ \
51-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
52-
} \
37+
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
5338
template <> \
54-
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
39+
void keyvalue_qsort(type1 *key, \
40+
type2 *val, \
41+
size_t arrsize, \
42+
bool hasnan, \
43+
bool descending) \
5544
{ \
56-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
45+
x86simdsortStatic::keyvalue_qsort( \
46+
key, val, arrsize, hasnan, descending); \
5747
} \
5848
template <> \
59-
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
49+
void keyvalue_select(type1 *key, \
50+
type2 *val, \
51+
size_t k, \
52+
size_t arrsize, \
53+
bool hasnan, \
54+
bool descending) \
6055
{ \
61-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
56+
x86simdsortStatic::keyvalue_select( \
57+
key, val, k, arrsize, hasnan, descending); \
6258
} \
6359
template <> \
64-
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
60+
void keyvalue_partial_sort(type1 *key, \
61+
type2 *val, \
62+
size_t k, \
63+
size_t arrsize, \
64+
bool hasnan, \
65+
bool descending) \
6566
{ \
66-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
67+
x86simdsortStatic::keyvalue_partial_sort( \
68+
key, val, k, arrsize, hasnan, descending); \
6769
}
6870

71+
// Expands DEFINE_KEYVALUE_METHODS_BASE once per value type (uint64_t, int64_t,
// double, uint32_t, int32_t, float) for the given key `type`, so a single
// invocation emits the key-value specializations for every key/value pairing.
#define DEFINE_KEYVALUE_METHODS(type) \
72+
DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
73+
DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
74+
DEFINE_KEYVALUE_METHODS_BASE(type, double) \
75+
DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
76+
DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
77+
DEFINE_KEYVALUE_METHODS_BASE(type, float)
78+
6979
namespace xss {
7080
namespace avx2 {
7181
DEFINE_ALL_METHODS(uint32_t)

lib/x86simdsort-internal.h

+63-6
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,41 @@ namespace avx512 {
1212
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
1313
// key-value quicksort
1414
template <typename T1, typename T2>
15-
XSS_HIDE_SYMBOL void
16-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
15+
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
16+
T2 *val,
17+
size_t arrsize,
18+
bool hasnan = false,
19+
bool descending = false);
1720
// quickselect
1821
template <typename T>
1922
XSS_HIDE_SYMBOL void qselect(T *arr,
2023
size_t k,
2124
size_t arrsize,
2225
bool hasnan = false,
2326
bool descending = false);
27+
// key-value select
28+
template <typename T1, typename T2>
29+
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
30+
T2 *val,
31+
size_t k,
32+
size_t arrsize,
33+
bool hasnan = false,
34+
bool descending = false);
2435
// partial sort
2536
template <typename T>
2637
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
2738
size_t k,
2839
size_t arrsize,
2940
bool hasnan = false,
3041
bool descending = false);
42+
// key-value partial sort
43+
template <typename T1, typename T2>
44+
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
45+
T2 *val,
46+
size_t k,
47+
size_t arrsize,
48+
bool hasnan = false,
49+
bool descending = false);
3150
// argsort
3251
template <typename T>
3352
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
@@ -46,22 +65,41 @@ namespace avx2 {
4665
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
4766
// key-value quicksort
4867
template <typename T1, typename T2>
49-
XSS_HIDE_SYMBOL void
50-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
68+
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
69+
T2 *val,
70+
size_t arrsize,
71+
bool hasnan = false,
72+
bool descending = false);
5173
// quickselect
5274
template <typename T>
5375
XSS_HIDE_SYMBOL void qselect(T *arr,
5476
size_t k,
5577
size_t arrsize,
5678
bool hasnan = false,
5779
bool descending = false);
80+
// key-value select
81+
template <typename T1, typename T2>
82+
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
83+
T2 *val,
84+
size_t k,
85+
size_t arrsize,
86+
bool hasnan = false,
87+
bool descending = false);
5888
// partial sort
5989
template <typename T>
6090
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
6191
size_t k,
6292
size_t arrsize,
6393
bool hasnan = false,
6494
bool descending = false);
95+
// key-value partial sort
96+
template <typename T1, typename T2>
97+
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
98+
T2 *val,
99+
size_t k,
100+
size_t arrsize,
101+
bool hasnan = false,
102+
bool descending = false);
65103
// argsort
66104
template <typename T>
67105
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
@@ -80,22 +118,41 @@ namespace scalar {
80118
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
81119
// key-value quicksort
82120
template <typename T1, typename T2>
83-
XSS_HIDE_SYMBOL void
84-
keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false);
121+
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
122+
T2 *val,
123+
size_t arrsize,
124+
bool hasnan = false,
125+
bool descending = false);
85126
// quickselect
86127
template <typename T>
87128
XSS_HIDE_SYMBOL void qselect(T *arr,
88129
size_t k,
89130
size_t arrsize,
90131
bool hasnan = false,
91132
bool descending = false);
133+
// key-value select
134+
template <typename T1, typename T2>
135+
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
136+
T2 *val,
137+
size_t k,
138+
size_t arrsize,
139+
bool hasnan = false,
140+
bool descending = false);
92141
// partial sort
93142
template <typename T>
94143
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
95144
size_t k,
96145
size_t arrsize,
97146
bool hasnan = false,
98147
bool descending = false);
148+
// key-value partial sort
149+
template <typename T1, typename T2>
150+
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
151+
T2 *val,
152+
size_t k,
153+
size_t arrsize,
154+
bool hasnan = false,
155+
bool descending = false);
99156
// argsort
100157
template <typename T>
101158
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,

lib/x86simdsort-scalar.h

+27-2
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,37 @@ namespace scalar {
100100
return arg;
101101
}
102102
template <typename T1, typename T2>
103-
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan)
103+
void keyvalue_qsort(
104+
T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending)
104105
{
105-
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
106+
std::vector<size_t> arg = argsort(key, arrsize, hasnan, descending);
106107
utils::apply_permutation_in_place(key, arg);
107108
utils::apply_permutation_in_place(val, arg);
108109
}
110+
// Scalar key-value select.
// This implementation does not use k: it forwards key, val, arrsize, hasnan
// and descending unchanged to keyvalue_qsort, which operates on all arrsize
// elements of both arrays.
template <typename T1, typename T2>
111+
void keyvalue_select(T1 *key,
112+
T2 *val,
113+
size_t k,
114+
size_t arrsize,
115+
bool hasnan,
116+
bool descending)
117+
{
118+
// Note that this does a full kv-sort
119+
// k is deliberately ignored; UNUSED marks it so the parameter stays in the signature.
UNUSED(k);
120+
keyvalue_qsort(key, val, arrsize, hasnan, descending);
121+
}
122+
// Scalar key-value partial sort.
// This implementation does not use k: it forwards key, val, arrsize, hasnan
// and descending unchanged to keyvalue_qsort, which operates on all arrsize
// elements of both arrays.
template <typename T1, typename T2>
123+
void keyvalue_partial_sort(T1 *key,
124+
T2 *val,
125+
size_t k,
126+
size_t arrsize,
127+
bool hasnan,
128+
bool descending)
129+
{
130+
// Note that this does a full kv-sort
131+
// k is deliberately ignored; UNUSED marks it so the parameter stays in the signature.
UNUSED(k);
132+
keyvalue_qsort(key, val, arrsize, hasnan, descending);
133+
}
109134

110135
} // namespace scalar
111136
} // namespace xss

lib/x86simdsort-skx.cpp

+32-22
Original file line numberDiff line numberDiff line change
@@ -34,38 +34,48 @@
3434
return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \
3535
}
3636

37-
#define DEFINE_KEYVALUE_METHODS(type) \
38-
template <> \
39-
void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \
40-
{ \
41-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
42-
} \
43-
template <> \
44-
void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \
45-
{ \
46-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
47-
} \
48-
template <> \
49-
void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \
50-
{ \
51-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
52-
} \
37+
#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \
5338
template <> \
54-
void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \
39+
void keyvalue_qsort(type1 *key, \
40+
type2 *val, \
41+
size_t arrsize, \
42+
bool hasnan, \
43+
bool descending) \
5544
{ \
56-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
45+
x86simdsortStatic::keyvalue_qsort( \
46+
key, val, arrsize, hasnan, descending); \
5747
} \
5848
template <> \
59-
void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \
49+
void keyvalue_select(type1 *key, \
50+
type2 *val, \
51+
size_t k, \
52+
size_t arrsize, \
53+
bool hasnan, \
54+
bool descending) \
6055
{ \
61-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
56+
x86simdsortStatic::keyvalue_select( \
57+
key, val, k, arrsize, hasnan, descending); \
6258
} \
6359
template <> \
64-
void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \
60+
void keyvalue_partial_sort(type1 *key, \
61+
type2 *val, \
62+
size_t k, \
63+
size_t arrsize, \
64+
bool hasnan, \
65+
bool descending) \
6566
{ \
66-
x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \
67+
x86simdsortStatic::keyvalue_partial_sort( \
68+
key, val, k, arrsize, hasnan, descending); \
6769
}
6870

71+
// Expands DEFINE_KEYVALUE_METHODS_BASE once per value type (uint64_t, int64_t,
// double, uint32_t, int32_t, float) for the given key `type`, so a single
// invocation emits the key-value specializations for every key/value pairing.
#define DEFINE_KEYVALUE_METHODS(type) \
72+
DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \
73+
DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \
74+
DEFINE_KEYVALUE_METHODS_BASE(type, double) \
75+
DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \
76+
DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \
77+
DEFINE_KEYVALUE_METHODS_BASE(type, float)
78+
6979
namespace xss {
7080
namespace avx512 {
7181
DEFINE_ALL_METHODS(uint32_t)

0 commit comments

Comments
 (0)