@@ -64,25 +64,33 @@ struct zmm_vector<float16> {
6464
6565 static opmask_t ge (zmm_t x, zmm_t y)
6666 {
67- zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68- zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69- zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70- zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71- zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72- zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
73-
74- __mmask32 mask_ge = _mm512_cmp_epu16_mask (sign_x, sign_y, _MM_CMPINT_LT); // only greater than
75- __mmask32 sign_eq = _mm512_cmpeq_epu16_mask (sign_x, sign_y);
76- __mmask32 neg = _mm512_mask_cmpeq_epu16_mask (sign_eq, sign_x, _mm512_set1_epi16 (0x8000 )); // both numbers are -ve
77-
78- // compare exponents only if signs are equal:
79- mask_ge = mask_ge | _mm512_mask_cmp_epu16_mask (sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
80- // get mask for elements for which both sign and exponents are equal:
81- __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask (sign_eq, exp_x, exp_y);
82-
83- // compare mantissa for elements for which both sign and expponent are equal:
84- mask_ge = mask_ge | _mm512_mask_cmp_epu16_mask (exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
85- return _kxor_mask32 (mask_ge, neg);
67+ zmm_t sign_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x8000 ));
68+ zmm_t sign_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x8000 ));
69+ zmm_t exp_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x7c00 ));
70+ zmm_t exp_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x7c00 ));
71+ zmm_t mant_x = _mm512_and_si512 (x, _mm512_set1_epi16 (0x3ff ));
72+ zmm_t mant_y = _mm512_and_si512 (y, _mm512_set1_epi16 (0x3ff ));
73+
74+ __mmask32 mask_ge = _mm512_cmp_epu16_mask (
75+ sign_x, sign_y, _MM_CMPINT_LT); // only greater than
76+ __mmask32 sign_eq = _mm512_cmpeq_epu16_mask (sign_x, sign_y);
77+ __mmask32 neg = _mm512_mask_cmpeq_epu16_mask (
78+ sign_eq,
79+ sign_x,
80+ _mm512_set1_epi16 (0x8000 )); // both numbers are -ve
81+
82+ // compare exponents only if signs are equal:
83+ mask_ge = mask_ge
84+ | _mm512_mask_cmp_epu16_mask (
85+ sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
86+ // get mask for elements for which both sign and exponents are equal:
87+ __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask (sign_eq, exp_x, exp_y);
88+
89+ // compare mantissa for elements for which both sign and exponent are equal:
90+ mask_ge = mask_ge
91+ | _mm512_mask_cmp_epu16_mask (
92+ exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
93+ return _kxor_mask32 (mask_ge, neg);
8694 }
8795 static zmm_t loadu (void const *mem)
8896 {
@@ -549,8 +557,8 @@ X86_SIMD_SORT_FINLINE void sort_128_16bit(type_t *arr, int32_t N)
549557
550558template <typename vtype, typename type_t >
551559X86_SIMD_SORT_FINLINE type_t get_pivot_16bit (type_t *arr,
552- const int64_t left,
553- const int64_t right)
560+ const int64_t left,
561+ const int64_t right)
554562{
555563 // median of 32
556564 int64_t size = (right - left) / 32 ;
@@ -598,26 +606,22 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
598606 uint16_t expa = a & 0x7c00 , expb = b & 0x7c00 ;
599607 uint16_t manta = a & 0x3ff , mantb = b & 0x3ff ;
600608 if (signa != signb) {
601- // opposite signs
602- return a > b;
609+ // opposite signs
610+ return a > b;
603611 }
604612 else if (signa > 0 ) {
605- // both -ve
606- if (expa != expb) {
607- return expa > expb;
608- }
609- else {
610- return manta > mantb;
611- }
613+ // both -ve
614+ if (expa != expb) { return expa > expb; }
615+ else {
616+ return manta > mantb;
617+ }
612618 }
613619 else {
614- // both +ve
615- if (expa != expb) {
616- return expa < expb;
617- }
618- else {
619- return manta < mantb;
620- }
620+ // both +ve
621+ if (expa != expb) { return expa < expb; }
622+ else {
623+ return manta < mantb;
624+ }
621625 }
622626
623627 // return npy_half_to_float(a) < npy_half_to_float(b);
@@ -653,7 +657,8 @@ qsort_16bit_(type_t *arr, int64_t left, int64_t right, int64_t max_iters)
653657 qsort_16bit_<vtype>(arr, pivot_index, right, max_iters - 1 );
654658}
655659
656- X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf (uint16_t *arr, int64_t arrsize)
660+ X86_SIMD_SORT_FINLINE int64_t replace_nan_with_inf (uint16_t *arr,
661+ int64_t arrsize)
657662{
658663 int64_t nan_count = 0 ;
659664 __mmask16 loadmask = 0xFFFF ;
0 commit comments