Skip to content

Commit 3c07df9

Browse files
authored
Merge pull request #45 from r-devulap/nan-fix
Preserve NAN's in avx512_qsort
2 parents 1473250 + 880f5a2 commit 3c07df9

7 files changed

+116
-112
lines changed

src/avx512-16bit-qsort.hpp

+12-36
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,12 @@ struct zmm_vector<uint16_t> {
349349
}
350350
};
351351

352+
template <>
353+
bool is_a_nan<uint16_t>(uint16_t elem)
354+
{
355+
return (elem & 0x7c00) == 0x7c00;
356+
}
357+
352358
template <>
353359
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
354360
{
@@ -377,34 +383,6 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
377383
//return npy_half_to_float(a) < npy_half_to_float(b);
378384
}
379385

380-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
381-
int64_t arrsize)
382-
{
383-
int64_t nan_count = 0;
384-
__mmask16 loadmask = 0xFFFF;
385-
while (arrsize > 0) {
386-
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
387-
__m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr);
388-
__m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm);
389-
__mmask16 nanmask = _mm512_cmp_ps_mask(
390-
in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
391-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
392-
_mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF);
393-
arr += 16;
394-
arrsize -= 16;
395-
}
396-
return nan_count;
397-
}
398-
399-
X86_SIMD_SORT_INLINE void
400-
replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
401-
{
402-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
403-
arr[ii] = 0xFFFF;
404-
nan_count -= 1;
405-
}
406-
}
407-
408386
template <>
409387
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
410388
{
@@ -425,11 +403,10 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
425403

426404
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
427405
{
428-
if (arrsize > 1) {
429-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
406+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
407+
if (indx_last_elem >= k) {
430408
qselect_16bit_<zmm_vector<float16>, uint16_t>(
431-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
432-
replace_inf_with_nan(arr, arrsize, nan_count);
409+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
433410
}
434411
}
435412

@@ -453,11 +430,10 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
453430

454431
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
455432
{
456-
if (arrsize > 1) {
457-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
433+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
434+
if (indx_last_elem > 0) {
458435
qsort_16bit_<zmm_vector<float16>, uint16_t>(
459-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
460-
replace_inf_with_nan(arr, arrsize, nan_count);
436+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
461437
}
462438
}
463439

src/avx512-32bit-qsort.hpp

+6-33
Original file line numberDiff line numberDiff line change
@@ -689,31 +689,6 @@ static void qselect_32bit_(type_t *arr,
689689
qselect_32bit_<vtype>(arr, pos, pivot_index, right, max_iters - 1);
690690
}
691691

692-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
693-
{
694-
int64_t nan_count = 0;
695-
__mmask16 loadmask = 0xFFFF;
696-
while (arrsize > 0) {
697-
if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
698-
__m512 in_zmm = _mm512_maskz_loadu_ps(loadmask, arr);
699-
__mmask16 nanmask = _mm512_cmp_ps_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
700-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
701-
_mm512_mask_storeu_ps(arr, nanmask, ZMM_MAX_FLOAT);
702-
arr += 16;
703-
arrsize -= 16;
704-
}
705-
return nan_count;
706-
}
707-
708-
X86_SIMD_SORT_INLINE void
709-
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
710-
{
711-
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
712-
arr[ii] = std::nanf("1");
713-
nan_count -= 1;
714-
}
715-
}
716-
717692
template <>
718693
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
719694
{
@@ -735,11 +710,10 @@ void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
735710
template <>
736711
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize)
737712
{
738-
if (arrsize > 1) {
739-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
713+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
714+
if (indx_last_elem >= k) {
740715
qselect_32bit_<zmm_vector<float>, float>(
741-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
742-
replace_inf_with_nan(arr, arrsize, nan_count);
716+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
743717
}
744718
}
745719

@@ -764,11 +738,10 @@ void avx512_qsort<uint32_t>(uint32_t *arr, int64_t arrsize)
764738
template <>
765739
void avx512_qsort<float>(float *arr, int64_t arrsize)
766740
{
767-
if (arrsize > 1) {
768-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
741+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
742+
if (indx_last_elem > 0) {
769743
qsort_32bit_<zmm_vector<float>, float>(
770-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
771-
replace_inf_with_nan(arr, arrsize, nan_count);
744+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
772745
}
773746
}
774747

src/avx512-64bit-qsort.hpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -804,11 +804,10 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
804804
template <>
805805
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize)
806806
{
807-
if (arrsize > 1) {
808-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
807+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
808+
if (indx_last_elem >= k) {
809809
qselect_64bit_<zmm_vector<double>, double>(
810-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
811-
replace_inf_with_nan(arr, arrsize, nan_count);
810+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
812811
}
813812
}
814813

@@ -833,11 +832,10 @@ void avx512_qsort<uint64_t>(uint64_t *arr, int64_t arrsize)
833832
template <>
834833
void avx512_qsort<double>(double *arr, int64_t arrsize)
835834
{
836-
if (arrsize > 1) {
837-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
835+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
836+
if (indx_last_elem > 0) {
838837
qsort_64bit_<zmm_vector<double>, double>(
839-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
840-
replace_inf_with_nan(arr, arrsize, nan_count);
838+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
841839
}
842840
}
843841
#endif // AVX512_QSORT_64BIT

src/avx512-common-qsort.h

+29
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,35 @@ inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
116116
template <typename T>
117117
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
118118

119+
template <typename T>
120+
bool is_a_nan(T elem)
121+
{
122+
return std::isnan(elem);
123+
}
124+
125+
/*
126+
* Sort all the NAN's to end of the array and return the index of the last elem
127+
* in the array which is not a nan
128+
*/
129+
template <typename T>
130+
int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
131+
{
132+
int64_t jj = arrsize - 1;
133+
int64_t ii = 0;
134+
int64_t count = 0;
135+
while (ii <= jj) {
136+
if (is_a_nan(arr[ii])) {
137+
std::swap(arr[ii], arr[jj]);
138+
jj -= 1;
139+
count++;
140+
}
141+
else {
142+
ii += 1;
143+
}
144+
}
145+
return arrsize-count-1;
146+
}
147+
119148
template <typename vtype, typename T = typename vtype::type_t>
120149
bool comparison_func(const T &a, const T &b)
121150
{

src/avx512fp16-16bit-qsort.hpp

+11-35
Original file line numberDiff line numberDiff line change
@@ -114,55 +114,31 @@ struct zmm_vector<_Float16> {
114114
}
115115
};
116116

117-
X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(_Float16 *arr,
118-
int64_t arrsize)
119-
{
120-
int64_t nan_count = 0;
121-
__mmask32 loadmask = 0xFFFFFFFF;
122-
__m512h in_zmm;
123-
while (arrsize > 0) {
124-
if (arrsize < 32) {
125-
loadmask = (0x00000001 << arrsize) - 0x00000001;
126-
in_zmm = _mm512_castsi512_ph(
127-
_mm512_maskz_loadu_epi16(loadmask, arr));
128-
}
129-
else {
130-
in_zmm = _mm512_loadu_ph(arr);
131-
}
132-
__mmask32 nanmask = _mm512_cmp_ph_mask(in_zmm, in_zmm, _CMP_NEQ_UQ);
133-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
134-
_mm512_mask_storeu_epi16(arr, nanmask, ZMM_MAX_HALF);
135-
arr += 32;
136-
arrsize -= 32;
137-
}
138-
return nan_count;
139-
}
140-
141-
X86_SIMD_SORT_INLINE void
142-
replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
117+
template <>
118+
bool is_a_nan<_Float16>(_Float16 elem)
143119
{
144-
memset(arr + arrsize - nan_count, 0xFF, nan_count * 2);
120+
Fp16Bits temp;
121+
temp.f_ = elem;
122+
return (temp.i_ & 0x7c00) == 0x7c00;
145123
}
146124

147125
template <>
148126
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
149127
{
150-
if (arrsize > 1) {
151-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
128+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
129+
if (indx_last_elem >= k) {
152130
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
153-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
154-
replace_inf_with_nan(arr, arrsize, nan_count);
131+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
155132
}
156133
}
157134

158135
template <>
159136
void avx512_qsort(_Float16 *arr, int64_t arrsize)
160137
{
161-
if (arrsize > 1) {
162-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
138+
int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
139+
if (indx_last_elem > 0) {
163140
qsort_16bit_<zmm_vector<_Float16>, _Float16>(
164-
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
165-
replace_inf_with_nan(arr, arrsize, nan_count);
141+
arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
166142
}
167143
}
168144
#endif // AVX512FP16_QSORT_16BIT

tests/test-qsort-fp.hpp

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*******************************************
2+
* * Copyright (C) 2022 Intel Corporation
3+
* * SPDX-License-Identifier: BSD-3-Clause
4+
* *******************************************/
5+
6+
#include "test-qsort-common.h"
7+
8+
template <typename T>
9+
class avx512_sort_fp : public ::testing::Test {
10+
};
11+
TYPED_TEST_SUITE_P(avx512_sort_fp);
12+
13+
TYPED_TEST_P(avx512_sort_fp, test_random_nan)
14+
{
15+
const int num_nans = 3;
16+
if (!cpu_has_avx512bw()) {
17+
GTEST_SKIP() << "Skipping this test, it requires avx512bw";
18+
}
19+
std::vector<int64_t> arrsizes;
20+
for (int64_t ii = num_nans; ii < 1024; ++ii) {
21+
arrsizes.push_back((TypeParam)ii);
22+
}
23+
std::vector<TypeParam> arr;
24+
std::vector<TypeParam> sortedarr;
25+
for (auto &size : arrsizes) {
26+
/* Random array */
27+
arr = get_uniform_rand_array<TypeParam>(size);
28+
for (auto ii = 1; ii <= num_nans; ++ii) {
29+
arr[size-ii] = std::numeric_limits<TypeParam>::quiet_NaN();
30+
}
31+
sortedarr = arr;
32+
std::sort(sortedarr.begin(), sortedarr.end()-3);
33+
std::random_shuffle(arr.begin(), arr.end());
34+
avx512_qsort<TypeParam>(arr.data(), arr.size());
35+
for (auto ii = 1; ii <= num_nans; ++ii) {
36+
if (!std::isnan(arr[size-ii])) {
37+
ASSERT_TRUE(false) << "NAN's aren't sorted to the end. Arr size = " << size;
38+
}
39+
}
40+
if (!std::is_sorted(arr.begin(), arr.end() - num_nans)) {
41+
ASSERT_TRUE(true) << "Array isn't sorted";
42+
}
43+
arr.clear();
44+
sortedarr.clear();
45+
}
46+
}
47+
REGISTER_TYPED_TEST_SUITE_P(avx512_sort_fp, test_random_nan);

tests/test-qsort.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "test-partial-qsort.hpp"
22
#include "test-qselect.hpp"
3+
#include "test-qsort-fp.hpp"
34
#include "test-qsort.hpp"
45

56
using QSortTestTypes = testing::Types<uint16_t,
@@ -10,6 +11,10 @@ using QSortTestTypes = testing::Types<uint16_t,
1011
int32_t,
1112
uint64_t,
1213
int64_t>;
14+
15+
using QSortTestFPTypes = testing::Types<float, double>;
16+
1317
INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_sort, QSortTestTypes);
18+
INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_sort_fp, QSortTestFPTypes);
1419
INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_select, QSortTestTypes);
1520
INSTANTIATE_TYPED_TEST_SUITE_P(T, avx512_partial_sort, QSortTestTypes);

0 commit comments

Comments
 (0)