Skip to content

Commit 0e9a9a1

Browse files
authored
Merge pull request #127 from sterrettm2/pivot_quality
New pivot selection algorithm to better handle many special cases
2 parents 5b5884c + cf7b0e5 commit 0e9a9a1

15 files changed

+320
-20
lines changed

src/avx2-32bit-qsort.hpp

+24-1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ struct avx2_vector<int32_t> {
8686
{
8787
return _mm256_set1_epi32(type_max());
8888
} // TODO: this should broadcast bits as is?
89+
static opmask_t knot_opmask(opmask_t x)
90+
{
91+
auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1);
92+
return _mm256_xor_si256(x, allOnes);
93+
}
8994
static opmask_t get_partial_loadmask(uint64_t num_to_read)
9095
{
9196
auto mask = ((0x1ull << num_to_read) - 0x1ull);
@@ -204,6 +209,9 @@ struct avx2_vector<int32_t> {
204209
{
205210
return v;
206211
}
212+
static bool all_false(opmask_t k){
213+
return _mm256_movemask_ps(_mm256_castsi256_ps(k)) == 0;
214+
}
207215
static int double_compressstore(type_t *left_addr,
208216
type_t *right_addr,
209217
opmask_t k,
@@ -242,6 +250,11 @@ struct avx2_vector<uint32_t> {
242250
{
243251
return _mm256_set1_epi32(type_max());
244252
}
253+
static opmask_t knot_opmask(opmask_t x)
254+
{
255+
auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1);
256+
return _mm256_xor_si256(x, allOnes);
257+
}
245258
static opmask_t get_partial_loadmask(uint64_t num_to_read)
246259
{
247260
auto mask = ((0x1ull << num_to_read) - 0x1ull);
@@ -349,6 +362,9 @@ struct avx2_vector<uint32_t> {
349362
{
350363
return v;
351364
}
365+
static bool all_false(opmask_t k){
366+
return _mm256_movemask_ps(_mm256_castsi256_ps(k)) == 0;
367+
}
352368
static int double_compressstore(type_t *left_addr,
353369
type_t *right_addr,
354370
opmask_t k,
@@ -387,7 +403,11 @@ struct avx2_vector<float> {
387403
{
388404
return _mm256_set1_ps(type_max());
389405
}
390-
406+
static opmask_t knot_opmask(opmask_t x)
407+
{
408+
auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1);
409+
return _mm256_xor_si256(x, allOnes);
410+
}
391411
static ymmi_t
392412
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
393413
{
@@ -514,6 +534,9 @@ struct avx2_vector<float> {
514534
{
515535
return _mm256_castps_si256(v);
516536
}
537+
static bool all_false(opmask_t k){
538+
return _mm256_movemask_ps(_mm256_castsi256_ps(k)) == 0;
539+
}
517540
static int double_compressstore(type_t *left_addr,
518541
type_t *right_addr,
519542
opmask_t k,

src/avx2-64bit-qsort.hpp

+27-3
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,17 @@ struct avx2_vector<int64_t> {
6868
{
6969
return _mm256_set1_epi64x(type_max());
7070
} // TODO: this should broadcast bits as is?
71+
static opmask_t knot_opmask(opmask_t x)
72+
{
73+
auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF);
74+
return _mm256_xor_si256(x, allTrue);
75+
}
7176
static opmask_t get_partial_loadmask(uint64_t num_to_read)
7277
{
7378
auto mask = ((0x1ull << num_to_read) - 0x1ull);
7479
return convert_int_to_avx2_mask_64bit(mask);
7580
}
76-
static ymmi_t seti(int v1, int v2, int v3, int v4)
81+
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
7782
{
7883
return _mm256_set_epi64x(v1, v2, v3, v4);
7984
}
@@ -209,6 +214,9 @@ struct avx2_vector<int64_t> {
209214
{
210215
return v;
211216
}
217+
static bool all_false(opmask_t k){
218+
return _mm256_movemask_pd(_mm256_castsi256_pd(k)) == 0;
219+
}
212220
};
213221
template <>
214222
struct avx2_vector<uint64_t> {
@@ -239,12 +247,17 @@ struct avx2_vector<uint64_t> {
239247
{
240248
return _mm256_set1_epi64x(type_max());
241249
}
250+
static opmask_t knot_opmask(opmask_t x)
251+
{
252+
auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF);
253+
return _mm256_xor_si256(x, allTrue);
254+
}
242255
static opmask_t get_partial_loadmask(uint64_t num_to_read)
243256
{
244257
auto mask = ((0x1ull << num_to_read) - 0x1ull);
245258
return convert_int_to_avx2_mask_64bit(mask);
246259
}
247-
static ymmi_t seti(int v1, int v2, int v3, int v4)
260+
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
248261
{
249262
return _mm256_set_epi64x(v1, v2, v3, v4);
250263
}
@@ -378,6 +391,9 @@ struct avx2_vector<uint64_t> {
378391
{
379392
return v;
380393
}
394+
static bool all_false(opmask_t k){
395+
return _mm256_movemask_pd(_mm256_castsi256_pd(k)) == 0;
396+
}
381397
};
382398

383399
/*
@@ -421,6 +437,11 @@ struct avx2_vector<double> {
421437
{
422438
return _mm256_set1_pd(type_max());
423439
}
440+
static opmask_t knot_opmask(opmask_t x)
441+
{
442+
auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF'FFFF'FFFF);
443+
return _mm256_xor_si256(x, allTrue);
444+
}
424445
static opmask_t get_partial_loadmask(uint64_t num_to_read)
425446
{
426447
auto mask = ((0x1ull << num_to_read) - 0x1ull);
@@ -440,7 +461,7 @@ struct avx2_vector<double> {
440461
static_assert(type == (0x01 | 0x80), "should not reach here");
441462
}
442463
}
443-
static ymmi_t seti(int v1, int v2, int v3, int v4)
464+
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
444465
{
445466
return _mm256_set_epi64x(v1, v2, v3, v4);
446467
}
@@ -571,6 +592,9 @@ struct avx2_vector<double> {
571592
{
572593
return _mm256_castpd_si256(v);
573594
}
595+
static bool all_false(opmask_t k){
596+
return _mm256_movemask_pd(_mm256_castsi256_pd(k)) == 0;
597+
}
574598
};
575599

576600
struct avx2_64bit_swizzle_ops {

src/avx512-16bit-qsort.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ struct zmm_vector<float16> {
8181
exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
8282
return _kxor_mask32(mask_ge, neg);
8383
}
84+
static opmask_t eq(reg_t x, reg_t y)
85+
{
86+
return _mm512_cmpeq_epu16_mask(x, y);
87+
}
8488
static opmask_t get_partial_loadmask(uint64_t num_to_read)
8589
{
8690
return ((0x1ull << num_to_read) - 0x1ull);
@@ -186,6 +190,9 @@ struct zmm_vector<float16> {
186190
{
187191
return v;
188192
}
193+
static bool all_false(opmask_t k){
194+
return k == 0;
195+
}
189196
static int double_compressstore(type_t *left_addr,
190197
type_t *right_addr,
191198
opmask_t k,
@@ -238,6 +245,10 @@ struct zmm_vector<int16_t> {
238245
{
239246
return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
240247
}
248+
static opmask_t eq(reg_t x, reg_t y)
249+
{
250+
return _mm512_cmpeq_epi16_mask(x, y);
251+
}
241252
static opmask_t get_partial_loadmask(uint64_t num_to_read)
242253
{
243254
return ((0x1ull << num_to_read) - 0x1ull);
@@ -323,6 +334,9 @@ struct zmm_vector<int16_t> {
323334
{
324335
return v;
325336
}
337+
static bool all_false(opmask_t k){
338+
return k == 0;
339+
}
326340
static int double_compressstore(type_t *left_addr,
327341
type_t *right_addr,
328342
opmask_t k,
@@ -374,6 +388,10 @@ struct zmm_vector<uint16_t> {
374388
{
375389
return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
376390
}
391+
static opmask_t eq(reg_t x, reg_t y)
392+
{
393+
return _mm512_cmpeq_epu16_mask(x, y);
394+
}
377395
static opmask_t get_partial_loadmask(uint64_t num_to_read)
378396
{
379397
return ((0x1ull << num_to_read) - 0x1ull);
@@ -457,6 +475,9 @@ struct zmm_vector<uint16_t> {
457475
{
458476
return v;
459477
}
478+
static bool all_false(opmask_t k){
479+
return k == 0;
480+
}
460481
static int double_compressstore(type_t *left_addr,
461482
type_t *right_addr,
462483
opmask_t k,

src/avx512-32bit-qsort.hpp

+9
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ struct zmm_vector<int32_t> {
198198
{
199199
return v;
200200
}
201+
static bool all_false(opmask_t k){
202+
return k == 0;
203+
}
201204
static int double_compressstore(type_t *left_addr,
202205
type_t *right_addr,
203206
opmask_t k,
@@ -377,6 +380,9 @@ struct zmm_vector<uint32_t> {
377380
{
378381
return v;
379382
}
383+
static bool all_false(opmask_t k){
384+
return k == 0;
385+
}
380386
static int double_compressstore(type_t *left_addr,
381387
type_t *right_addr,
382388
opmask_t k,
@@ -570,6 +576,9 @@ struct zmm_vector<float> {
570576
{
571577
return _mm512_castps_si512(v);
572578
}
579+
static bool all_false(opmask_t k){
580+
return k == 0;
581+
}
573582
static int double_compressstore(type_t *left_addr,
574583
type_t *right_addr,
575584
opmask_t k,

src/avx512-64bit-common.h

+9
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,9 @@ struct zmm_vector<int64_t> {
732732
{
733733
return v;
734734
}
735+
static bool all_false(opmask_t k){
736+
return k == 0;
737+
}
735738
static int double_compressstore(type_t *left_addr,
736739
type_t *right_addr,
737740
opmask_t k,
@@ -903,6 +906,9 @@ struct zmm_vector<uint64_t> {
903906
{
904907
return v;
905908
}
909+
static bool all_false(opmask_t k){
910+
return k == 0;
911+
}
906912
static int double_compressstore(type_t *left_addr,
907913
type_t *right_addr,
908914
opmask_t k,
@@ -1093,6 +1099,9 @@ struct zmm_vector<double> {
10931099
{
10941100
return _mm512_castpd_si512(v);
10951101
}
1102+
static bool all_false(opmask_t k){
1103+
return k == 0;
1104+
}
10961105
static int double_compressstore(type_t *left_addr,
10971106
type_t *right_addr,
10981107
opmask_t k,

src/avx512fp16-16bit-qsort.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ struct zmm_vector<_Float16> {
5555
{
5656
return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
5757
}
58+
static opmask_t eq(reg_t x, reg_t y)
59+
{
60+
return _mm512_cmp_ph_mask(x, y, _CMP_EQ_OQ);
61+
}
5862
static opmask_t get_partial_loadmask(uint64_t num_to_read)
5963
{
6064
return ((0x1ull << num_to_read) - 0x1ull);
@@ -150,6 +154,9 @@ struct zmm_vector<_Float16> {
150154
{
151155
return _mm512_castph_si512(v);
152156
}
157+
static bool all_false(opmask_t k){
158+
return k == 0;
159+
}
153160
static int double_compressstore(type_t *left_addr,
154161
type_t *right_addr,
155162
opmask_t k,

src/xss-common-includes.h

+3
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,7 @@ struct avx2_half_vector;
106106

107107
enum class simd_type : int { AVX2, AVX512 };
108108

109+
template <typename vtype, typename T = typename vtype::type_t>
110+
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b);
111+
109112
#endif // XSS_COMMON_INCLUDES

src/xss-common-qsort.h

+10-3
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size)
8787
else {
8888
in = vtype::loadu(arr + ii);
8989
}
90-
auto nanmask = vtype::convert_mask_to_int(vtype::template fpclass<0x01 | 0x80>(in));
90+
auto nanmask = vtype::convert_mask_to_int(
91+
vtype::template fpclass<0x01 | 0x80>(in));
9192
if (nanmask != 0x00) {
9293
found_nan = true;
9394
break;
@@ -136,7 +137,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size)
136137
return size - count - 1;
137138
}
138139

139-
template <typename vtype, typename T = typename vtype::type_t>
140+
template <typename vtype, typename T>
140141
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b)
141142
{
142143
return a < b;
@@ -499,14 +500,20 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
499500
return;
500501
}
501502

502-
type_t pivot = get_pivot_blocks<vtype, type_t>(arr, left, right);
503+
auto pivot_result = get_pivot_smart<vtype, type_t>(arr, left, right);
504+
type_t pivot = pivot_result.pivot;
505+
506+
if (pivot_result.result == pivot_result_t::Sorted) { return; }
507+
503508
type_t smallest = vtype::type_max();
504509
type_t biggest = vtype::type_min();
505510

506511
arrsize_t pivot_index
507512
= partition_avx512_unrolled<vtype, vtype::partition_unroll_factor>(
508513
arr, left, right + 1, pivot, &smallest, &biggest);
509514

515+
if (pivot_result.result == pivot_result_t::Only2Values) { return; }
516+
510517
if (pivot != smallest)
511518
qsort_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
512519
if (pivot != biggest) qsort_<vtype>(arr, pivot_index, right, max_iters - 1);

src/xss-network-keyvaluesort.hpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,8 @@ bitonic_fullmerge_n_vec(typename keyType::reg_t *keys,
441441
}
442442

443443
template <typename keyType, typename indexType, int numVecs>
444-
X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
445-
arrsize_t *indices,
446-
int N)
444+
X86_SIMD_SORT_INLINE void
445+
argsort_n_vec(typename keyType::type_t *keys, arrsize_t *indices, int N)
447446
{
448447
using kreg_t = typename keyType::reg_t;
449448
using ireg_t = typename indexType::reg_t;
@@ -586,9 +585,8 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,
586585
}
587586

588587
template <typename keyType, typename indexType, int maxN>
589-
X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys,
590-
arrsize_t *indices,
591-
int N)
588+
X86_SIMD_SORT_INLINE void
589+
argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N)
592590
{
593591
static_assert(keyType::numlanes == indexType::numlanes,
594592
"invalid pairing of value/index types");

0 commit comments

Comments
 (0)