Skip to content

Commit 0041c05

Browse files
sterrettm2r-devulap
authored andcommitted
New pivot selection to improve performance in many special cases
1 parent 5b5884c commit 0041c05

10 files changed

+301
-14
lines changed

src/avx2-32bit-qsort.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ struct avx2_vector<int32_t> {
8686
{
8787
return _mm256_set1_epi32(type_max());
8888
} // TODO: this should broadcast bits as is?
89+
static opmask_t knot_opmask(opmask_t x)
90+
{
91+
auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1);
92+
return _mm256_xor_si256(x, allOnes);
93+
}
8994
static opmask_t get_partial_loadmask(uint64_t num_to_read)
9095
{
9196
auto mask = ((0x1ull << num_to_read) - 0x1ull);
@@ -204,6 +209,9 @@ struct avx2_vector<int32_t> {
204209
{
205210
return v;
206211
}
212+
static bool all_false(opmask_t k){
213+
return _mm256_movemask_ps(_mm256_castsi256_ps(k)) == 0;
214+
}
207215
static int double_compressstore(type_t *left_addr,
208216
type_t *right_addr,
209217
opmask_t k,
@@ -242,6 +250,11 @@ struct avx2_vector<uint32_t> {
242250
{
243251
return _mm256_set1_epi32(type_max());
244252
}
253+
static opmask_t knot_opmask(opmask_t x)
254+
{
255+
auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1);
256+
return _mm256_xor_si256(x, allOnes);
257+
}
245258
static opmask_t get_partial_loadmask(uint64_t num_to_read)
246259
{
247260
auto mask = ((0x1ull << num_to_read) - 0x1ull);
@@ -349,6 +362,9 @@ struct avx2_vector<uint32_t> {
349362
{
350363
return v;
351364
}
365+
static bool all_false(opmask_t k){
366+
return _mm256_movemask_ps(_mm256_castsi256_ps(k)) == 0;
367+
}
352368
static int double_compressstore(type_t *left_addr,
353369
type_t *right_addr,
354370
opmask_t k,
@@ -387,7 +403,11 @@ struct avx2_vector<float> {
387403
{
388404
return _mm256_set1_ps(type_max());
389405
}
390-
406+
static opmask_t knot_opmask(opmask_t x)
407+
{
408+
auto allOnes = seti(-1, -1, -1, -1, -1, -1, -1, -1);
409+
return _mm256_xor_si256(x, allOnes);
410+
}
391411
static ymmi_t
392412
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
393413
{
@@ -514,6 +534,9 @@ struct avx2_vector<float> {
514534
{
515535
return _mm256_castps_si256(v);
516536
}
537+
static bool all_false(opmask_t k){
538+
return _mm256_movemask_ps(_mm256_castsi256_ps(k)) == 0;
539+
}
517540
static int double_compressstore(type_t *left_addr,
518541
type_t *right_addr,
519542
opmask_t k,

src/avx2-64bit-qsort.hpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,17 @@ struct avx2_vector<int64_t> {
6868
{
6969
return _mm256_set1_epi64x(type_max());
7070
} // TODO: this should broadcast bits as is?
71+
static opmask_t knot_opmask(opmask_t x)
72+
{
73+
auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF);
74+
return _mm256_xor_si256(x, allTrue);
75+
}
7176
static opmask_t get_partial_loadmask(uint64_t num_to_read)
7277
{
7378
auto mask = ((0x1ull << num_to_read) - 0x1ull);
7479
return convert_int_to_avx2_mask_64bit(mask);
7580
}
76-
static ymmi_t seti(int v1, int v2, int v3, int v4)
81+
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
7782
{
7883
return _mm256_set_epi64x(v1, v2, v3, v4);
7984
}
@@ -209,6 +214,9 @@ struct avx2_vector<int64_t> {
209214
{
210215
return v;
211216
}
217+
static bool all_false(opmask_t k){
218+
return _mm256_movemask_pd(_mm256_castsi256_pd(k)) == 0;
219+
}
212220
};
213221
template <>
214222
struct avx2_vector<uint64_t> {
@@ -239,12 +247,17 @@ struct avx2_vector<uint64_t> {
239247
{
240248
return _mm256_set1_epi64x(type_max());
241249
}
250+
static opmask_t knot_opmask(opmask_t x)
251+
{
252+
auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF);
253+
return _mm256_xor_si256(x, allTrue);
254+
}
242255
static opmask_t get_partial_loadmask(uint64_t num_to_read)
243256
{
244257
auto mask = ((0x1ull << num_to_read) - 0x1ull);
245258
return convert_int_to_avx2_mask_64bit(mask);
246259
}
247-
static ymmi_t seti(int v1, int v2, int v3, int v4)
260+
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
248261
{
249262
return _mm256_set_epi64x(v1, v2, v3, v4);
250263
}
@@ -378,6 +391,9 @@ struct avx2_vector<uint64_t> {
378391
{
379392
return v;
380393
}
394+
static bool all_false(opmask_t k){
395+
return _mm256_movemask_pd(_mm256_castsi256_pd(k)) == 0;
396+
}
381397
};
382398

383399
/*
@@ -421,6 +437,11 @@ struct avx2_vector<double> {
421437
{
422438
return _mm256_set1_pd(type_max());
423439
}
440+
static opmask_t knot_opmask(opmask_t x)
441+
{
442+
auto allTrue = _mm256_set1_epi64x(0xFFFF'FFFF);
443+
return _mm256_xor_si256(x, allTrue);
444+
}
424445
static opmask_t get_partial_loadmask(uint64_t num_to_read)
425446
{
426447
auto mask = ((0x1ull << num_to_read) - 0x1ull);
@@ -440,7 +461,7 @@ struct avx2_vector<double> {
440461
static_assert(type == (0x01 | 0x80), "should not reach here");
441462
}
442463
}
443-
static ymmi_t seti(int v1, int v2, int v3, int v4)
464+
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
444465
{
445466
return _mm256_set_epi64x(v1, v2, v3, v4);
446467
}
@@ -571,6 +592,9 @@ struct avx2_vector<double> {
571592
{
572593
return _mm256_castpd_si256(v);
573594
}
595+
static bool all_false(opmask_t k){
596+
return _mm256_movemask_pd(_mm256_castsi256_pd(k)) == 0;
597+
}
574598
};
575599

576600
struct avx2_64bit_swizzle_ops {

src/avx512-16bit-qsort.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ struct zmm_vector<float16> {
8181
exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
8282
return _kxor_mask32(mask_ge, neg);
8383
}
84+
static opmask_t eq(reg_t x, reg_t y)
85+
{
86+
return _mm512_cmpeq_epu16_mask(x, y);
87+
}
8488
static opmask_t get_partial_loadmask(uint64_t num_to_read)
8589
{
8690
return ((0x1ull << num_to_read) - 0x1ull);
@@ -186,6 +190,9 @@ struct zmm_vector<float16> {
186190
{
187191
return v;
188192
}
193+
static bool all_false(opmask_t k){
194+
return k == 0;
195+
}
189196
static int double_compressstore(type_t *left_addr,
190197
type_t *right_addr,
191198
opmask_t k,
@@ -238,6 +245,10 @@ struct zmm_vector<int16_t> {
238245
{
239246
return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
240247
}
248+
static opmask_t eq(reg_t x, reg_t y)
249+
{
250+
return _mm512_cmpeq_epi16_mask(x, y);
251+
}
241252
static opmask_t get_partial_loadmask(uint64_t num_to_read)
242253
{
243254
return ((0x1ull << num_to_read) - 0x1ull);
@@ -323,6 +334,9 @@ struct zmm_vector<int16_t> {
323334
{
324335
return v;
325336
}
337+
static bool all_false(opmask_t k){
338+
return k == 0;
339+
}
326340
static int double_compressstore(type_t *left_addr,
327341
type_t *right_addr,
328342
opmask_t k,
@@ -374,6 +388,10 @@ struct zmm_vector<uint16_t> {
374388
{
375389
return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
376390
}
391+
static opmask_t eq(reg_t x, reg_t y)
392+
{
393+
return _mm512_cmpeq_epu16_mask(x, y);
394+
}
377395
static opmask_t get_partial_loadmask(uint64_t num_to_read)
378396
{
379397
return ((0x1ull << num_to_read) - 0x1ull);
@@ -457,6 +475,9 @@ struct zmm_vector<uint16_t> {
457475
{
458476
return v;
459477
}
478+
static bool all_false(opmask_t k){
479+
return k == 0;
480+
}
460481
static int double_compressstore(type_t *left_addr,
461482
type_t *right_addr,
462483
opmask_t k,

src/avx512-32bit-qsort.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ struct zmm_vector<int32_t> {
198198
{
199199
return v;
200200
}
201+
static bool all_false(opmask_t k){
202+
return k == 0;
203+
}
201204
static int double_compressstore(type_t *left_addr,
202205
type_t *right_addr,
203206
opmask_t k,
@@ -377,6 +380,9 @@ struct zmm_vector<uint32_t> {
377380
{
378381
return v;
379382
}
383+
static bool all_false(opmask_t k){
384+
return k == 0;
385+
}
380386
static int double_compressstore(type_t *left_addr,
381387
type_t *right_addr,
382388
opmask_t k,
@@ -570,6 +576,9 @@ struct zmm_vector<float> {
570576
{
571577
return _mm512_castps_si512(v);
572578
}
579+
static bool all_false(opmask_t k){
580+
return k == 0;
581+
}
573582
static int double_compressstore(type_t *left_addr,
574583
type_t *right_addr,
575584
opmask_t k,

src/avx512-64bit-common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,9 @@ struct zmm_vector<int64_t> {
732732
{
733733
return v;
734734
}
735+
static bool all_false(opmask_t k){
736+
return k == 0;
737+
}
735738
static int double_compressstore(type_t *left_addr,
736739
type_t *right_addr,
737740
opmask_t k,
@@ -903,6 +906,9 @@ struct zmm_vector<uint64_t> {
903906
{
904907
return v;
905908
}
909+
static bool all_false(opmask_t k){
910+
return k == 0;
911+
}
906912
static int double_compressstore(type_t *left_addr,
907913
type_t *right_addr,
908914
opmask_t k,
@@ -1093,6 +1099,9 @@ struct zmm_vector<double> {
10931099
{
10941100
return _mm512_castpd_si512(v);
10951101
}
1102+
static bool all_false(opmask_t k){
1103+
return k == 0;
1104+
}
10961105
static int double_compressstore(type_t *left_addr,
10971106
type_t *right_addr,
10981107
opmask_t k,

src/avx512fp16-16bit-qsort.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ struct zmm_vector<_Float16> {
150150
{
151151
return _mm512_castph_si512(v);
152152
}
153+
static bool all_false(opmask_t k){
154+
return k == 0;
155+
}
153156
static int double_compressstore(type_t *left_addr,
154157
type_t *right_addr,
155158
opmask_t k,

src/xss-common-qsort.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,14 +498,24 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
498498
arr + left, (int32_t)(right + 1 - left));
499499
return;
500500
}
501-
502-
type_t pivot = get_pivot_blocks<vtype, type_t>(arr, left, right);
501+
502+
auto pivot_result = get_pivot_smart<vtype, type_t>(arr, left, right);
503+
type_t pivot = pivot_result.pivot;
504+
505+
if (pivot_result.alreadySorted){
506+
return;
507+
}
508+
503509
type_t smallest = vtype::type_max();
504510
type_t biggest = vtype::type_min();
505511

506512
arrsize_t pivot_index
507513
= partition_avx512_unrolled<vtype, vtype::partition_unroll_factor>(
508514
arr, left, right + 1, pivot, &smallest, &biggest);
515+
516+
if (pivot_result.only2Values){
517+
return;
518+
}
509519

510520
if (pivot != smallest)
511521
qsort_<vtype>(arr, left, pivot_index - 1, max_iters - 1);

src/xss-network-qsort.hpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "xss-optimal-networks.hpp"
55
#include "xss-common-qsort.h"
66

7+
template <typename vtype, typename mm_t>
8+
X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b);
9+
710
template <typename vtype, int numVecs, typename reg_t = typename vtype::reg_t>
811
X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs)
912
{
@@ -140,6 +143,17 @@ X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs)
140143
}
141144
}
142145

146+
template <typename vtype, int numVecs, typename reg_t = typename vtype::reg_t>
147+
X86_SIMD_SORT_FINLINE void sort_vectors(reg_t * vecs){
148+
/* Run the initial sorting network to sort the columns of the [numVecs x
149+
* num_lanes] matrix
150+
*/
151+
bitonic_sort_n_vec<vtype, numVecs>(vecs);
152+
153+
// Merge the vectors using bitonic merging networks
154+
merge_n_vec<vtype, numVecs>(vecs);
155+
}
156+
143157
template <typename vtype, int numVecs, typename reg_t = typename vtype::reg_t>
144158
X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N)
145159
{
@@ -174,14 +188,8 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N)
174188
vecs[i] = vtype::mask_loadu(
175189
vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes);
176190
}
177-
178-
/* Run the initial sorting network to sort the columns of the [numVecs x
179-
* num_lanes] matrix
180-
*/
181-
bitonic_sort_n_vec<vtype, numVecs>(vecs);
182-
183-
// Merge the vectors using bitonic merging networks
184-
merge_n_vec<vtype, numVecs>(vecs);
191+
192+
sort_vectors<vtype, numVecs>(vecs);
185193

186194
// Unmasked part of the store
187195
X86_SIMD_SORT_UNROLL_LOOP(64)

src/xss-optimal-networks.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// All of these sources files are generated from the optimal networks described in
22
// https://bertdobbelaere.github.io/sorting_networks.html
33

4+
template <typename vtype, typename mm_t>
5+
X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b);
6+
47
template <typename vtype, typename reg_t = typename vtype::reg_t>
58
X86_SIMD_SORT_FINLINE void optimal_sort_4(reg_t *vecs)
69
{

0 commit comments

Comments
 (0)