Skip to content

Commit 85fbe7d

Browse files
authored
Merge pull request #47 from r-devulap/dont-use-qsort-nans
Revert PR #45
2 parents 3c07df9 + 0b2d89e commit 85fbe7d

5 files changed

+112
-64
lines changed

src/avx512-16bit-qsort.hpp

+36-12
Original file line number | Diff line number | Diff line change
@@ -349,12 +349,6 @@ struct zmm_vector<uint16_t> {
349349
}
350350
};
351351

352-
template <>
353-
bool is_a_nan<uint16_t>(uint16_t elem)
354-
{
355-
return (elem & 0x7c00) == 0x7c00;
356-
}
357-
358352
template <>
359353
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
360354
{
@@ -383,6 +377,34 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
383377
//return npy_half_to_float(a) < npy_half_to_float(b);
384378
}
385379

380+
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
381+
int64_t arrsize)
382+
{
383+
int64_t nan_count = 0;
384+
__mmask16 loadmask = 0xFFFF;
385+
while (arrsize > 0) {
386+
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
387+
__m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr);
388+
__m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm);
389+
__mmask16 nanmask = _mm512_cmp_ps_mask(
390+
in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
391+
nan_count += _mm_popcnt_u32((int32_t)nanmask);
392+
_mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF);
393+
arr += 16;
394+
arrsize -= 16;
395+
}
396+
return nan_count;
397+
}
398+
399+
X86_SIMD_SORT_INLINE void
400+
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
401+
{
402+
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
403+
arr[ii] = 0xFFFF;
404+
nan_count -= 1;
405+
}
406+
}
407+
386408
template <>
387409
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
388410
{
@@ -403,10 +425,11 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
403425

404426
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
405427
{
406-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
407-
if (indx_last_elem >= k) {
428+
if (arrsize > 1) {
429+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
408430
qselect_16bit_<zmm_vector<float16>, uint16_t>(
409-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
431+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
432+
replace_inf_with_nan(arr, arrsize, nan_count);
410433
}
411434
}
412435

@@ -430,10 +453,11 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
430453

431454
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
432455
{
433-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
434-
if (indx_last_elem > 0) {
456+
if (arrsize > 1) {
457+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
435458
qsort_16bit_<zmm_vector<float16>, uint16_t>(
436-
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
459+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
460+
replace_inf_with_nan(arr, arrsize, nan_count);
437461
}
438462
}
439463

src/avx512-32bit-qsort.hpp

+33-6
Original file line number | Diff line number | Diff line change
@@ -689,6 +689,31 @@ static void qselect_32bit_(type_t *arr,
689689
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
690690
}
691691

692+
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
693+
{
694+
int64_t nan_count = 0;
695+
__mmask16 loadmask = 0xFFFF;
696+
while (arrsize > 0) {
697+
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
698+
__m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
699+
__mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
700+
nan_count += _mm_popcnt_u32((int32_t)nanmask);
701+
_mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
702+
arr += 16;
703+
arrsize -= 16;
704+
}
705+
return nan_count;
706+
}
707+
708+
X86_SIMD_SORT_INLINE void
709+
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
710+
{
711+
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
712+
arr[ii] = std::nanf("1");
713+
nan_count -= 1;
714+
}
715+
}
716+
692717
template <>
693718
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
694719
{
@@ -710,10 +735,11 @@ void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
710735
template <>
711736
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize)
712737
{
713-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
714-
if (indx_last_elem >= k) {
738+
if (arrsize > 1) {
739+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
715740
qselect_32bit_<zmm_vector<float>, float>(
716-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
741+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
742+
replace_inf_with_nan(arr, arrsize, nan_count);
717743
}
718744
}
719745

@@ -738,10 +764,11 @@ void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
738764
template <>
739765
void avx512_qsort<float>(float *arr, int64_t arrsize)
740766
{
741-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
742-
if (indx_last_elem > 0) {
767+
if (arrsize > 1) {
768+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
743769
qsort_32bit_<zmm_vector<float>, float>(
744-
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
770+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
771+
replace_inf_with_nan(arr, arrsize, nan_count);
745772
}
746773
}
747774

src/avx512-64bit-qsort.hpp

+8-6
Original file line number | Diff line number | Diff line change
@@ -804,10 +804,11 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
804804
template <>
805805
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize)
806806
{
807-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
808-
if (indx_last_elem >= k) {
807+
if (arrsize > 1) {
808+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
809809
qselect_64bit_<zmm_vector<double>, double>(
810-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
810+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
811+
replace_inf_with_nan(arr, arrsize, nan_count);
811812
}
812813
}
813814

@@ -832,10 +833,11 @@ void avx512_qsort<uint64_t>(uint64_t *arr, int64_t arrsize)
832833
template <>
833834
void avx512_qsort<double>(double *arr, int64_t arrsize)
834835
{
835-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
836-
if (indx_last_elem > 0) {
836+
if (arrsize > 1) {
837+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
837838
qsort_64bit_<zmm_vector<double>, double>(
838-
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
839+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
840+
replace_inf_with_nan(arr, arrsize, nan_count);
839841
}
840842
}
841843
#endif // AVX512_QSORT_64BIT

src/avx512-common-qsort.h

-29
Original file line number | Diff line number | Diff line change
@@ -116,35 +116,6 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
116116
template <typename T>
117117
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
118118

119-
template <typename T>
120-
bool is_a_nan(T elem)
121-
{
122-
return std::isnan(elem);
123-
}
124-
125-
/*
126-
* Sort all the NAN's to end of the array and return the index of the last elem
127-
* in the array which is not a nan
128-
*/
129-
template <typename T>
130-
int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
131-
{
132-
int64_t jj = arrsize - 1;
133-
int64_t ii = 0;
134-
int64_t count = 0;
135-
while (ii <= jj) {
136-
if (is_a_nan(arr[ii])) {
137-
std::swap(arr[ii], arr[jj]);
138-
jj -= 1;
139-
count++;
140-
}
141-
else {
142-
ii += 1;
143-
}
144-
}
145-
return arrsize-count-1;
146-
}
147-
148119
template <typename vtype, typename T = typename vtype::type_t>
149120
bool comparison_func(const T &a, const T &b)
150121
{

src/avx512fp16-16bit-qsort.hpp

+35-11
Original file line number | Diff line number | Diff line change
@@ -114,31 +114,55 @@ struct zmm_vector<_Float16> {
114114
}
115115
};
116116

117-
template <>
118-
bool is_a_nan<_Float16>(_Float16 elem)
117+
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr,
118+
int64_t arrsize)
119+
{
120+
int64_t nan_count = 0;
121+
__mmask32 loadmask = 0xFFFFFFFF;
122+
__m512h in_zmm;
123+
while (arrsize > 0) {
124+
if (arrsize < 32) {
125+
loadmask = (0x00000001 << arrsize) - 0x00000001;
126+
in_zmm = _mm512_castsi512_ph(
127+
_mm512_maskz_loadu_epi16(loadmask, arr));
128+
}
129+
else {
130+
in_zmm = _mm512_loadu_ph(arr);
131+
}
132+
__mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
133+
nan_count += _mm_popcnt_u32((int32_t)nanmask);
134+
_mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF);
135+
arr += 32;
136+
arrsize -= 32;
137+
}
138+
return nan_count;
139+
}
140+
141+
X86_SIMD_SORT_INLINE void
142+
replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
119143
{
120-
Fp16Bits temp;
121-
temp.f_ = elem;
122-
return (temp.i_ & 0x7c00) == 0x7c00;
144+
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
123145
}
124146

125147
template <>
126148
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
127149
{
128-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
129-
if (indx_last_elem >= k) {
150+
if (arrsize > 1) {
151+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
130152
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
131-
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
153+
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
154+
replace_inf_with_nan(arr, arrsize, nan_count);
132155
}
133156
}
134157

135158
template <>
136159
void avx512_qsort(_Float16 *arr, int64_t arrsize)
137160
{
138-
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
139-
if (indx_last_elem > 0) {
161+
if (arrsize > 1) {
162+
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
140163
qsort_16bit_<zmm_vector<_Float16>, _Float16>(
141-
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
164+
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
165+
replace_inf_with_nan(arr, arrsize, nan_count);
142166
}
143167
}
144168
#endif // AVX512FP16_QSORT_16BIT

0 commit comments

Comments
 (0)