Skip to content

Commit 5385d3c

Browse files
committed
Preserve NANs in the array for QSORT
1 parent 1473250 commit 5385d3c

5 files changed

+62
-112
lines changed

src/avx512-16bit-qsort.hpp

+12-36
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,12 @@ struct zmm_vector<uint16_t> {
349349
}
350350
};
351351

352+
template <>
353+
bool is_a_nan<uint16_t>(uint16_t elem)
354+
{
355+
return (elem & 0x7c00) == 0x7c00;
356+
}
357+
352358
template <>
353359
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
354360
{
@@ -377,34 +383,6 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
377383
//return npy_half_to_float(a) < npy_half_to_float(b);
378384
}
379385

380-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
381-
int64_t arrsize)
382-
{
383-
int64_t nan_count = 0;
384-
__mmask16 loadmask = 0xFFFF;
385-
while (arrsize > 0) {
386-
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
387-
__m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr);
388-
__m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm);
389-
__mmask16 nanmask = _mm512_cmp_ps_mask(
390-
in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
391-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
392-
_mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF);
393-
arr += 16;
394-
arrsize -= 16;
395-
}
396-
return nan_count;
397-
}
398-
399-
X86_SIMD_SORT_INLINE void
400-
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
401-
{
402-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
403-
arr[ii] = 0xFFFF;
404-
nan_count -= 1;
405-
}
406-
}
407-
408386
template <>
409387
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
410388
{
@@ -425,11 +403,10 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
425403

426404
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
427405
{
428-
if (arrsize > 1) {
429-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
406+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
407+
if (indx_last_elem >= k) {
430408
qselect_16bit_<zmm_vector<float16>, uint16_t>(
431-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
432-
replace_inf_with_nan(arr, arrsize, nan_count);
409+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
433410
}
434411
}
435412

@@ -453,11 +430,10 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
453430

454431
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
455432
{
456-
if (arrsize > 1) {
457-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
433+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
434+
if (indx_last_elem > 0) {
458435
qsort_16bit_<zmm_vector<float16>, uint16_t>(
459-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
460-
replace_inf_with_nan(arr, arrsize, nan_count);
436+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
461437
}
462438
}
463439

src/avx512-32bit-qsort.hpp

+6-33
Original file line numberDiff line numberDiff line change
@@ -689,31 +689,6 @@ static void qselect_32bit_(type_t *arr,
689689
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
690690
}
691691

692-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
693-
{
694-
int64_t nan_count = 0;
695-
__mmask16 loadmask = 0xFFFF;
696-
while (arrsize > 0) {
697-
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
698-
__m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
699-
__mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
700-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
701-
_mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
702-
arr += 16;
703-
arrsize -= 16;
704-
}
705-
return nan_count;
706-
}
707-
708-
X86_SIMD_SORT_INLINE void
709-
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
710-
{
711-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
712-
arr[ii] = std::nanf("1");
713-
nan_count -= 1;
714-
}
715-
}
716-
717692
template <>
718693
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
719694
{
@@ -735,11 +710,10 @@ void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
735710
template <>
736711
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize)
737712
{
738-
if (arrsize > 1) {
739-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
713+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
714+
if (indx_last_elem >= k) {
740715
qselect_32bit_<zmm_vector<float>, float>(
741-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
742-
replace_inf_with_nan(arr, arrsize, nan_count);
716+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
743717
}
744718
}
745719

@@ -764,11 +738,10 @@ void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
764738
template <>
765739
void avx512_qsort<float>(float *arr, int64_t arrsize)
766740
{
767-
if (arrsize > 1) {
768-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
741+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
742+
if (indx_last_elem > 0) {
769743
qsort_32bit_<zmm_vector<float>, float>(
770-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
771-
replace_inf_with_nan(arr, arrsize, nan_count);
744+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
772745
}
773746
}
774747

src/avx512-64bit-qsort.hpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -804,11 +804,10 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
804804
template <>
805805
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize)
806806
{
807-
if (arrsize > 1) {
808-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
807+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
808+
if (indx_last_elem >= k) {
809809
qselect_64bit_<zmm_vector<double>, double>(
810-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
811-
replace_inf_with_nan(arr, arrsize, nan_count);
810+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
812811
}
813812
}
814813

@@ -833,11 +832,10 @@ void avx512_qsort<uint64_t>(uint64_t *arr, int64_t arrsize)
833832
template <>
834833
void avx512_qsort<double>(double *arr, int64_t arrsize)
835834
{
836-
if (arrsize > 1) {
837-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
835+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
836+
if (indx_last_elem > 0) {
838837
qsort_64bit_<zmm_vector<double>, double>(
839-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
840-
replace_inf_with_nan(arr, arrsize, nan_count);
838+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
841839
}
842840
}
843841
#endif // AVX512_QSORT_64BIT

src/avx512-common-qsort.h

+27
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,33 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
116116
template <typename T>
117117
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
118118

119+
/*
 * NaN predicate used by put_nans_at_end_of_array. The primary template defers
 * to std::isnan; element types that need a bit-level test provide explicit
 * specializations.
 */
template <typename T>
bool is_a_nan(T elem)
{
    return std::isnan(elem);
}

/*
 * Move every NaN in arr[0 .. arrsize-1] to the end of the array and return
 * the index of the last element that is not a NaN. Returns -1 when the array
 * is empty or contains only NaNs. The relative order of the non-NaN elements
 * is not preserved.
 */
template <typename T>
int64_t put_nans_at_end_of_array(T *arr, int64_t arrsize)
{
    int64_t jj = arrsize - 1; // everything after jj is known to be NaN
    int64_t ii = 0; // everything before ii is known to be non-NaN
    // The comparison must be `<=`: with `<` the element where the two cursors
    // meet is never inspected, so a NaN sitting there would be reported as
    // the last non-NaN element (e.g. {1, NaN} would return index 1).
    while (ii <= jj) {
        if (is_a_nan(arr[ii])) {
            // arr[jj] may itself be a NaN; ii is not advanced, so whatever
            // was swapped into arr[ii] is examined on the next iteration.
            std::swap(arr[ii], arr[jj]);
            jj -= 1;
        }
        else {
            ii += 1;
        }
    }
    // On exit ii == jj + 1, so jj is the last non-NaN index (or -1).
    return jj;
}
145+
119146
template <typename vtype, typename T = typename vtype::type_t>
120147
bool comparison_func(const T &a, const T &b)
121148
{

src/avx512fp16-16bit-qsort.hpp

+11-35
Original file line numberDiff line numberDiff line change
@@ -114,55 +114,31 @@ struct zmm_vector<_Float16> {
114114
}
115115
};
116116

117-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr,
118-
int64_t arrsize)
119-
{
120-
int64_t nan_count = 0;
121-
__mmask32 loadmask = 0xFFFFFFFF;
122-
__m512h in_zmm;
123-
while (arrsize > 0) {
124-
if (arrsize < 32) {
125-
loadmask = (0x00000001 << arrsize) - 0x00000001;
126-
in_zmm = _mm512_castsi512_ph(
127-
_mm512_maskz_loadu_epi16(loadmask, arr));
128-
}
129-
else {
130-
in_zmm = _mm512_loadu_ph(arr);
131-
}
132-
__mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
133-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
134-
_mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF);
135-
arr += 32;
136-
arrsize -= 32;
137-
}
138-
return nan_count;
139-
}
140-
141-
X86_SIMD_SORT_INLINE void
142-
replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
117+
template <>
118+
bool is_a_nan<_Float16>(_Float16 elem)
143119
{
144-
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
120+
Fp16Bits temp;
121+
temp.f_ = elem;
122+
return (temp.i_ & 0x7c00) == 0x7c00;
145123
}
146124

147125
template <>
148126
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
149127
{
150-
if (arrsize > 1) {
151-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
128+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
129+
if (indx_last_elem >= k) {
152130
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
153-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
154-
replace_inf_with_nan(arr, arrsize, nan_count);
131+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
155132
}
156133
}
157134

158135
template <>
159136
void avx512_qsort(_Float16 *arr, int64_t arrsize)
160137
{
161-
if (arrsize > 1) {
162-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
138+
int64_t indx_last_elem = put_nans_at_end_of_array(arr, arrsize);
139+
if (indx_last_elem > 0) {
163140
qsort_16bit_<zmm_vector<_Float16>, _Float16>(
164-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
165-
replace_inf_with_nan(arr, arrsize, nan_count);
141+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
166142
}
167143
}
168144
#endif // AVX512FP16_QSORT_16BIT

0 commit comments

Comments
 (0)