Skip to content

Commit a4e57cb

Browse files
authored
Merge pull request #49 from r-devulap/nan
Add an optional argument "hasnan" for qselect and partialsort
2 parents c6ddd9c + 9783a9b commit a4e57cb

5 files changed

+86
-32
lines changed

src/avx512-16bit-qsort.hpp

+15-7
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,13 @@ replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
406406
}
407407

408408
template <>
409-
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
409+
bool is_a_nan<uint16_t>(uint16_t elem)
410+
{
411+
return (elem & 0x7c00) == 0x7c00;
412+
}
413+
414+
template <>
415+
void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
410416
{
411417
if (arrsize > 1) {
412418
qselect_16bit_<zmm_vector<int16_t>, int16_t>(
@@ -415,21 +421,23 @@ void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
415421
}
416422

417423
template <>
418-
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
424+
void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
419425
{
420426
if (arrsize > 1) {
421427
qselect_16bit_<zmm_vector<uint16_t>, uint16_t>(
422428
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
423429
}
424430
}
425431

426-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
432+
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan)
427433
{
428-
if (arrsize > 1) {
429-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
434+
int64_t indx_last_elem = arrsize - 1;
435+
if (UNLIKELY(hasnan)) {
436+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
437+
}
438+
if (indx_last_elem >= k) {
430439
qselect_16bit_<zmm_vector<float16>, uint16_t>(
431-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
432-
replace_inf_with_nan(arr, arrsize, nan_count);
440+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
433441
}
434442
}
435443

src/avx512-32bit-qsort.hpp

+9-7
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
715715
}
716716

717717
template <>
718-
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
718+
void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
719719
{
720720
if (arrsize > 1) {
721721
qselect_32bit_<zmm_vector<int32_t>, int32_t>(
@@ -724,7 +724,7 @@ void avx512_qselect<int32_t>(int32_t *arr, int64_t k, int64_t arrsize)
724724
}
725725

726726
template <>
727-
void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
727+
void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize, bool hasnan)
728728
{
729729
if (arrsize > 1) {
730730
qselect_32bit_<zmm_vector<uint32_t>, uint32_t>(
@@ -733,13 +733,15 @@ void avx512_qselect<uint32_t>(uint32_t *arr, int64_t k, int64_t arrsize)
733733
}
734734

735735
template <>
736-
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize)
736+
void avx512_qselect<float>(float *arr, int64_t k, int64_t arrsize, bool hasnan)
737737
{
738-
if (arrsize > 1) {
739-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
738+
int64_t indx_last_elem = arrsize - 1;
739+
if (UNLIKELY(hasnan)) {
740+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
741+
}
742+
if (indx_last_elem >= k) {
740743
qselect_32bit_<zmm_vector<float>, float>(
741-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
742-
replace_inf_with_nan(arr, arrsize, nan_count);
744+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
743745
}
744746
}
745747

src/avx512-64bit-qsort.hpp

+9-7
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ static void qselect_64bit_(type_t *arr,
784784
}
785785

786786
template <>
787-
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize)
787+
void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
788788
{
789789
if (arrsize > 1) {
790790
qselect_64bit_<zmm_vector<int64_t>, int64_t>(
@@ -793,7 +793,7 @@ void avx512_qselect<int64_t>(int64_t *arr, int64_t k, int64_t arrsize)
793793
}
794794

795795
template <>
796-
void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
796+
void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize, bool hasnan)
797797
{
798798
if (arrsize > 1) {
799799
qselect_64bit_<zmm_vector<uint64_t>, uint64_t>(
@@ -802,13 +802,15 @@ void avx512_qselect<uint64_t>(uint64_t *arr, int64_t k, int64_t arrsize)
802802
}
803803

804804
template <>
805-
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize)
805+
void avx512_qselect<double>(double *arr, int64_t k, int64_t arrsize, bool hasnan)
806806
{
807-
if (arrsize > 1) {
808-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
807+
int64_t indx_last_elem = arrsize - 1;
808+
if (UNLIKELY(hasnan)) {
809+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
810+
}
811+
if (indx_last_elem >= k) {
809812
qselect_64bit_<zmm_vector<double>, double>(
810-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
811-
replace_inf_with_nan(arr, arrsize, nan_count);
813+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
812814
}
813815
}
814816

src/avx512-common-qsort.h

+38-6
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@
8585
#define X86_SIMD_SORT_FINLINE static
8686
#endif
8787

88+
#define LIKELY(x) __builtin_expect((x),1)
89+
#define UNLIKELY(x) __builtin_expect((x),0)
90+
8891
template <typename type>
8992
struct zmm_vector;
9093

@@ -97,25 +100,54 @@ void avx512_qsort(T *arr, int64_t arrsize);
97100
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize);
98101

99102
template <typename T>
100-
void avx512_qselect(T *arr, int64_t k, int64_t arrsize);
101-
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize);
103+
void avx512_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false);
104+
void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false);
102105

103106
template <typename T>
104-
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize)
107+
inline void avx512_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
105108
{
106-
avx512_qselect<T>(arr, k - 1, arrsize);
109+
avx512_qselect<T>(arr, k - 1, arrsize, hasnan);
107110
avx512_qsort<T>(arr, k - 1);
108111
}
109-
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
112+
inline void avx512_partial_qsort_fp16(uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false)
110113
{
111-
avx512_qselect_fp16(arr, k - 1, arrsize);
114+
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan);
112115
avx512_qsort_fp16(arr, k - 1);
113116
}
114117

115118
// key-value sort routines
116119
template <typename T>
117120
void avx512_qsort_kv(T *keys, uint64_t *indexes, int64_t arrsize);
118121

122+
template <typename T>
123+
bool is_a_nan(T elem)
124+
{
125+
return std::isnan(elem);
126+
}
127+
128+
/*
129+
* Sort all the NAN's to end of the array and return the index of the last elem
130+
* in the array which is not a nan
131+
*/
132+
template <typename T>
133+
int64_t move_nans_to_end_of_array(T* arr, int64_t arrsize)
134+
{
135+
int64_t jj = arrsize - 1;
136+
int64_t ii = 0;
137+
int64_t count = 0;
138+
while (ii <= jj) {
139+
if (is_a_nan(arr[ii])) {
140+
std::swap(arr[ii], arr[jj]);
141+
jj -= 1;
142+
count++;
143+
}
144+
else {
145+
ii += 1;
146+
}
147+
}
148+
return arrsize-count-1;
149+
}
150+
119151
template <typename vtype, typename T = typename vtype::type_t>
120152
bool comparison_func(const T &a, const T &b)
121153
{

src/avx512fp16-16bit-qsort.hpp

+15-5
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,23 @@ replace_inf_with_nan(_Float16 *arr, int64_t arrsize, int64_t nan_count)
145145
}
146146

147147
template <>
148-
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize)
148+
bool is_a_nan<_Float16>(_Float16 elem)
149149
{
150-
if (arrsize > 1) {
151-
int64_t nan_count = replace_nan_with_inf(arr, arrsize);
150+
Fp16Bits temp;
151+
temp.f_ = elem;
152+
return (temp.i_ & 0x7c00) == 0x7c00;
153+
}
154+
155+
template <>
156+
void avx512_qselect(_Float16 *arr, int64_t k, int64_t arrsize, bool hasnan)
157+
{
158+
int64_t indx_last_elem = arrsize - 1;
159+
if (UNLIKELY(hasnan)) {
160+
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
161+
}
162+
if (indx_last_elem >= k) {
152163
qselect_16bit_<zmm_vector<_Float16>, _Float16>(
153-
arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
154-
replace_inf_with_nan(arr, arrsize, nan_count);
164+
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
155165
}
156166
}
157167

0 commit comments

Comments
 (0)