@@ -29,10 +29,148 @@ static const uint16_t network[6][32]
2929 {16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 , 30 , 31 ,
3030 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }};
3131
32+ struct float16 {
33+ uint16_t val;
34+ };
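+ // float16 is a thin wrapper around the raw IEEE 754 binary16 bit pattern;
+ // the traits below keep the payload in uint16_t lanes and supply an
+ // ordering (ge, below) that matches the numeric fp16 order.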
35+
36+ template <>
37+ struct zmm_vector <float16> {
38+ using type_t = uint16_t ;
39+ using zmm_t = __m512i;
40+ using ymm_t = __m256i;
41+ using opmask_t = __mmask32;
42+ static const uint8_t numlanes = 32 ;
43+
44+ static zmm_t get_network (int index)
45+ {
46+ return _mm512_loadu_si512 (&network[index - 1 ][0 ]);
47+ }
48+ static type_t type_max ()
49+ {
50+ return X86_SIMD_SORT_INFINITYH;
51+ }
52+ static type_t type_min ()
53+ {
54+ return X86_SIMD_SORT_NEGINFINITYH;
55+ }
56+ static zmm_t zmm_max ()
57+ {
58+ return _mm512_set1_epi16 (type_max ());
59+ }
60+ static opmask_t knot_opmask (opmask_t x)
61+ {
62+ return _knot_mask32 (x);
63+ }
64+
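+ // ge() orders binary16 lanes with integer ops only: when the signs match,
+ // a larger biased exponent (or, on an exponent tie, a larger mantissa)
+ // means a larger magnitude; the final kxor with `neg` flips that outcome
+ // when both inputs are negative, where the larger magnitude is the smaller
+ // value. NaNs are not handled here; avx512_qsort_fp16 replaces them with
+ // +inf before sorting.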
65+ static opmask_t ge (zmm_t x, zmm_t y)
66+ {
67+ zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68+ zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69+ zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70+ zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71+ zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72+ zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
73+
74+ __mmask32 mask_ge = _mm512_cmp_epu16_mask (sign_x, sign_y, _MM_CMPINT_LT); // sign_x < sign_y: x is +ve, y is -ve, so x > y
75+ __mmask32 sign_eq = _mm512_cmpeq_epu16_mask (sign_x, sign_y);
76+ __mmask32 neg = _mm512_mask_cmpeq_epu16_mask (sign_eq, sign_x, _mm512_set1_epi16 (0x8000 )); // both numbers are -ve
77+
78+ // compare exponents only if signs are equal:
79+ mask_ge = mask_ge | _mm512_mask_cmp_epu16_mask (sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
80+ // get mask for elements for which both sign and exponents are equal:
81+ __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask (sign_eq, exp_x, exp_y);
82+
83+ // compare mantissa for elements for which both sign and exponent are equal:
84+ mask_ge = mask_ge | _mm512_mask_cmp_epu16_mask (exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
85+ return _kxor_mask32 (mask_ge, neg);
86+ }
87+ static zmm_t loadu (void const *mem)
88+ {
89+ return _mm512_loadu_si512 (mem);
90+ }
91+ static zmm_t max (zmm_t x, zmm_t y)
92+ {
93+ return _mm512_mask_mov_epi16 (y, ge (x, y), x);
94+ }
95+ static void mask_compressstoreu (void *mem, opmask_t mask, zmm_t x)
96+ {
97+ // AVX512_VBMI2
98+ return _mm512_mask_compressstoreu_epi16 (mem, mask, x);
99+ }
100+ static zmm_t mask_loadu (zmm_t x, opmask_t mask, void const *mem)
101+ {
102+ // AVX512BW
103+ return _mm512_mask_loadu_epi16 (x, mask, mem);
104+ }
105+ static zmm_t mask_mov (zmm_t x, opmask_t mask, zmm_t y)
106+ {
107+ return _mm512_mask_mov_epi16 (x, mask, y);
108+ }
109+ static void mask_storeu (void *mem, opmask_t mask, zmm_t x)
110+ {
111+ return _mm512_mask_storeu_epi16 (mem, mask, x);
112+ }
113+ static zmm_t min (zmm_t x, zmm_t y)
114+ {
115+ return _mm512_mask_mov_epi16 (x, ge (x, y), y);
116+ }
117+ static zmm_t permutexvar (__m512i idx, zmm_t zmm)
118+ {
119+ return _mm512_permutexvar_epi16 (idx, zmm);
120+ }
121+ // Apparently this is terrible for perf; npy_half_to_float seems to work
122+ // better
123+ // static float uint16_to_float(uint16_t val)
124+ // {
125+ // // Ideally use _mm_loadu_si16, but it's only available in gcc > 11.x
126+ // // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM
127+ // __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val);
128+ // __m128 xmm2 = _mm_cvtph_ps(xmm);
129+ // return _mm_cvtss_f32(xmm2);
130+ // }
131+ static type_t float_to_uint16 (float val)
132+ {
133+ __m128 xmm = _mm_load_ss (&val);
134+ __m128i xmm2 = _mm_cvtps_ph (xmm, _MM_FROUND_NO_EXC);
135+ return _mm_extract_epi16 (xmm2, 0 );
136+ }
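+ // reducemax/reducemin widen each 16-lane half of the register to fp32
+ // (vcvtph2ps) and reuse the existing fp32 reductions, then convert the
+ // scalar result back to a binary16 bit pattern via float_to_uint16().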
137+ static type_t reducemax (zmm_t v)
138+ {
139+ __m512 lo = _mm512_cvtph_ps (_mm512_extracti64x4_epi64 (v, 0 ));
140+ __m512 hi = _mm512_cvtph_ps (_mm512_extracti64x4_epi64 (v, 1 ));
141+ float lo_max = _mm512_reduce_max_ps (lo);
142+ float hi_max = _mm512_reduce_max_ps (hi);
143+ return float_to_uint16 (std::max (lo_max, hi_max));
144+ }
145+ static type_t reducemin (zmm_t v)
146+ {
147+ __m512 lo = _mm512_cvtph_ps (_mm512_extracti64x4_epi64 (v, 0 ));
148+ __m512 hi = _mm512_cvtph_ps (_mm512_extracti64x4_epi64 (v, 1 ));
149+ float lo_min = _mm512_reduce_min_ps (lo);
150+ float hi_min = _mm512_reduce_min_ps (hi);
151+ return float_to_uint16 (std::min (lo_min, hi_min));
152+ }
153+ static zmm_t set1 (type_t v)
154+ {
155+ return _mm512_set1_epi16 (v);
156+ }
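+ // shuffle<mask> applies the same 4-element word permutation to the low and
+ // high 64-bit half of every 128-bit lane (vpshuflw + vpshufhw); the 16-bit
+ // bitonic network below uses it for its in-lane exchanges.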
157+ template <uint8_t mask>
158+ static zmm_t shuffle (zmm_t zmm)
159+ {
160+ zmm = _mm512_shufflehi_epi16 (zmm, (_MM_PERM_ENUM)mask);
161+ return _mm512_shufflelo_epi16 (zmm, (_MM_PERM_ENUM)mask);
162+ }
163+ static void storeu (void *mem, zmm_t x)
164+ {
165+ return _mm512_storeu_si512 (mem, x);
166+ }
167+ };
168+
32169template <>
33170struct zmm_vector <int16_t > {
34171 using type_t = int16_t ;
35172 using zmm_t = __m512i;
173+ using ymm_t = __m256i;
36174 using opmask_t = __mmask32;
37175 static const uint8_t numlanes = 32 ;
38176
@@ -130,6 +268,7 @@ template <>
130268struct zmm_vector <uint16_t > {
131269 using type_t = uint16_t ;
132270 using zmm_t = __m512i;
271+ using ymm_t = __m256i;
133272 using opmask_t = __mmask32;
134273 static const uint8_t numlanes = 32 ;
135274
@@ -227,7 +366,7 @@ struct zmm_vector<uint16_t> {
227366 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
228367 */
229368template <typename vtype, typename zmm_t = typename vtype::zmm_t >
230- X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_16bit (zmm_t zmm)
369+ X86_SIMD_SORT_FINLINE zmm_t sort_zmm_16bit (zmm_t zmm)
231370{
232371 // Level 1
233372 zmm = cmp_merge<vtype>(
@@ -287,7 +426,7 @@ X86_SIMD_SORT_FORCEINLINE zmm_t sort_zmm_16bit(zmm_t zmm)
287426
288427// Assumes zmm is bitonic and performs a recursive half cleaner
289428template <typename vtype, typename zmm_t = typename vtype::zmm_t >
290- X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_16bit (zmm_t zmm)
429+ X86_SIMD_SORT_FINLINE zmm_t bitonic_merge_zmm_16bit (zmm_t zmm)
291430{
292431 // 1) half_cleaner[32]: compare 1-17, 2-18, 3-19 etc ..
293432 zmm = cmp_merge<vtype>(
@@ -313,8 +452,7 @@ X86_SIMD_SORT_FORCEINLINE zmm_t bitonic_merge_zmm_16bit(zmm_t zmm)
313452
314453// Assumes zmm1 and zmm2 are sorted and performs a recursive half cleaner
315454template <typename vtype, typename zmm_t = typename vtype::zmm_t >
316- X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_16bit (zmm_t &zmm1,
317- zmm_t &zmm2)
455+ X86_SIMD_SORT_FINLINE void bitonic_merge_two_zmm_16bit (zmm_t &zmm1, zmm_t &zmm2)
318456{
319457 // 1) First step of a merging network: coex of zmm1 and zmm2 reversed
320458 zmm2 = vtype::permutexvar (vtype::get_network (4 ), zmm2);
@@ -328,7 +466,7 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_two_zmm_16bit(zmm_t &zmm1,
328466// Assumes [zmm0, zmm1] and [zmm2, zmm3] are sorted and performs a recursive
329467// half cleaner
330468template <typename vtype, typename zmm_t = typename vtype::zmm_t >
331- X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_16bit (zmm_t *zmm)
469+ X86_SIMD_SORT_FINLINE void bitonic_merge_four_zmm_16bit (zmm_t *zmm)
332470{
333471 zmm_t zmm2r = vtype::permutexvar (vtype::get_network (4 ), zmm[2 ]);
334472 zmm_t zmm3r = vtype::permutexvar (vtype::get_network (4 ), zmm[3 ]);
@@ -349,7 +487,7 @@ X86_SIMD_SORT_FORCEINLINE void bitonic_merge_four_zmm_16bit(zmm_t *zmm)
349487}
350488
351489template <typename vtype, typename type_t >
352- X86_SIMD_SORT_FORCEINLINE void sort_32_16bit (type_t *arr, int32_t N)
490+ X86_SIMD_SORT_FINLINE void sort_32_16bit (type_t *arr, int32_t N)
353491{
354492 typename vtype::opmask_t load_mask = ((0x1ull << N) - 0x1ull ) & 0xFFFFFFFF ;
355493 typename vtype::zmm_t zmm
@@ -358,7 +496,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_32_16bit(type_t *arr, int32_t N)
358496}
359497
360498template <typename vtype, typename type_t >
361- X86_SIMD_SORT_FORCEINLINE void sort_64_16bit (type_t *arr, int32_t N)
499+ X86_SIMD_SORT_FINLINE void sort_64_16bit (type_t *arr, int32_t N)
362500{
363501 if (N <= 32 ) {
364502 sort_32_16bit<vtype>(arr, N);
@@ -377,7 +515,7 @@ X86_SIMD_SORT_FORCEINLINE void sort_64_16bit(type_t *arr, int32_t N)
377515}
378516
379517template <typename vtype, typename type_t >
380- X86_SIMD_SORT_FORCEINLINE void sort_128_16bit (type_t *arr, int32_t N)
518+ X86_SIMD_SORT_FINLINE void sort_128_16bit (type_t *arr, int32_t N)
381519{
382520 if (N <= 64 ) {
383521 sort_64_16bit<vtype>(arr, N);
@@ -410,9 +548,9 @@ X86_SIMD_SORT_FORCEINLINE void sort_128_16bit(type_t *arr, int32_t N)
410548}
411549
412550template <typename vtype, typename type_t >
413- X86_SIMD_SORT_FORCEINLINE type_t get_pivot_16bit (type_t *arr,
414- const int64_t left,
415- const int64_t right)
551+ X86_SIMD_SORT_FINLINE type_t get_pivot_16bit (type_t *arr,
552+ const int64_t left,
553+ const int64_t right)
416554{
417555 // median of 32
418556 int64_t size = (right - left) / 32 ;
@@ -453,6 +591,38 @@ X86_SIMD_SORT_FORCEINLINE type_t get_pivot_16bit(type_t *arr,
453591 return ((type_t *)&sort)[16 ];
454592}
455593
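+ // Scalar fallback comparator for float16: mirrors the vectorized ge() above
+ // (sign, then exponent, then mantissa, with the order reversed when both
+ // values are negative) so the std::sort fallback agrees with the SIMD order.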
594+ template <>
595+ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
596+ {
597+ uint16_t signa = a & 0x8000 , signb = b & 0x8000 ;
598+ uint16_t expa = a & 0x7c00 , expb = b & 0x7c00 ;
599+ uint16_t manta = a & 0x3ff , mantb = b & 0x3ff ;
600+ if (signa != signb) {
601+ // opposite signs
602+ return a > b;
603+ }
604+ else if (signa > 0 ) {
605+ // both -ve
606+ if (expa != expb) {
607+ return expa > expb;
608+ }
609+ else {
610+ return manta > mantb;
611+ }
612+ }
613+ else {
614+ // both +ve
615+ if (expa != expb) {
616+ return expa < expb;
617+ }
618+ else {
619+ return manta < mantb;
620+ }
621+ }
622+
623+ // return npy_half_to_float(a) < npy_half_to_float(b);
624+ }
625+
456626template <typename vtype, typename type_t >
457627static void
458628qsort_16bit_ (type_t *arr, int64_t left, int64_t right, int64_t max_iters)
@@ -461,7 +631,7 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
461631 * Resort to std::sort if quicksort isn't making any progress
462632 */
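+ // comparison_func<vtype> (rather than the default operator<) keeps this
+ // fallback consistent with the vector ordering; for float16 the raw
+ // uint16_t bit patterns do not sort numerically on their own.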
463633 if (max_iters <= 0 ) {
464- std::sort (arr + left, arr + right + 1 );
634+ std::sort (arr + left, arr + right + 1 , comparison_func<vtype> );
465635 return ;
466636 }
467637 /*
@@ -483,12 +653,39 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
483653 qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1 );
484654}
485655
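+ // Pre-pass for fp16 sorting: process 16 elements at a time, widen to fp32,
+ // and use the x != x self-compare (_CMP_NEQ_UQ) to find NaN lanes; those
+ // lanes are overwritten with YMM_MAX_HALF (+inf) so they sort to the end,
+ // and the NaN count is returned so they can be restored afterwards.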
656+ X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf (uint16_t *arr, int64_t arrsize)
657+ {
658+ int64_t nan_count = 0 ;
659+ __mmask16 loadmask = 0xFFFF ;
660+ while (arrsize > 0 ) {
661+ if (arrsize < 16 ) { loadmask = (0x0001 << arrsize) - 0x0001 ; }
662+ __m256i in_zmm = _mm256_maskz_loadu_epi16 (loadmask, arr);
663+ __m512 in_zmm_asfloat = _mm512_cvtph_ps (in_zmm);
664+ __mmask16 nanmask = _mm512_cmp_ps_mask (
665+ in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
666+ nan_count += _mm_popcnt_u32 ((int32_t )nanmask);
667+ _mm256_mask_storeu_epi16 (arr, nanmask, YMM_MAX_HALF);
668+ arr += 16 ;
669+ arrsize -= 16 ;
670+ }
671+ return nan_count;
672+ }
673+
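+ // Post-pass: after sorting, the +inf sentinels occupy the last nan_count
+ // slots, so rewrite them in place with a NaN bit pattern (0xFFFF).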
674+ X86_SIMD_SORT_FINLINE void
675+ replace_inf_with_nan (uint16_t *arr, int64_t arrsize, int64_t nan_count)
676+ {
677+ for (int64_t ii = arrsize - 1 ; nan_count > 0 ; --ii) {
678+ arr[ii] = 0xFFFF ;
679+ nan_count -= 1 ;
680+ }
681+ }
682+
486683template <>
487684void avx512_qsort (int16_t *arr, int64_t arrsize)
488685{
489686 if (arrsize > 1 ) {
490687 qsort_16bit_<zmm_vector<int16_t >, int16_t >(
491- arr, 0 , arrsize - 1 , 2 * (63 - __builtin_clzll (arrsize) ));
688+ arr, 0 , arrsize - 1 , 2 * (int64_t ) log2 (arrsize));
492689 }
493690}
494691
@@ -497,7 +694,17 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
497694{
498695 if (arrsize > 1 ) {
499696 qsort_16bit_<zmm_vector<uint16_t >, uint16_t >(
500- arr, 0 , arrsize - 1 , 2 * (63 - __builtin_clzll (arrsize)));
697+ arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
698+ }
699+ }
700+
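+ // fp16 entry point: (1) replace NaNs with +inf and remember how many,
+ // (2) run the 16-bit quicksort with the float16 traits (data stays in
+ // uint16_t, ordering comes from zmm_vector<float16>), (3) restore NaNs at
+ // the tail. The 2 * log2(n) recursion limit matches the other 16-bit paths
+ // before the std::sort fallback kicks in.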
701+ void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize)
702+ {
703+ if (arrsize > 1 ) {
704+ int64_t nan_count = replace_nan_with_inf (arr, arrsize);
705+ qsort_16bit_<zmm_vector<float16>, uint16_t >(
706+ arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
707+ replace_inf_with_nan (arr, arrsize, nan_count);
501708 }
502709}
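+ // Minimal usage sketch (illustrative only; assumes the caller's half-
+ // precision buffer is reinterpreted as raw uint16_t bit patterns):
+ //     std::vector<uint16_t> buf = load_fp16_bits(); // hypothetical helper
+ //     avx512_qsort_fp16(buf.data(), (int64_t)buf.size());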
503710#endif // AVX512_QSORT_16BIT