85
85
#define X86_SIMD_SORT_FINLINE static
86
86
#endif
87
87
88
/*
 * Branch-prediction hints. Both macros evaluate to the value of (x); the
 * second argument of __builtin_expect only tells the compiler which outcome
 * to expect. There must be no whitespace between the macro name and its
 * parameter list, otherwise the macro becomes object-like.
 */
#define LIKELY(x) __builtin_expect((x), 1)
#define UNLIKELY(x) __builtin_expect((x), 0)
90
+
88
91
template <typename type>
89
92
struct zmm_vector ;
90
93
@@ -97,25 +100,54 @@ void avx512_qsort(T *arr, int64_t arrsize);
97
100
void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize);
98
101
99
102
template <typename T>
100
- void avx512_qselect (T *arr, int64_t k, int64_t arrsize);
101
- void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize);
103
+ void avx512_qselect (T *arr, int64_t k, int64_t arrsize, bool hasnan = false );
104
+ void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false );
102
105
103
106
template <typename T>
104
- inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize)
107
+ inline void avx512_partial_qsort (T *arr, int64_t k, int64_t arrsize, bool hasnan = false )
105
108
{
106
- avx512_qselect<T>(arr, k - 1 , arrsize);
109
+ avx512_qselect<T>(arr, k - 1 , arrsize, hasnan );
107
110
avx512_qsort<T>(arr, k - 1 );
108
111
}
109
- inline void avx512_partial_qsort_fp16 (uint16_t *arr, int64_t k, int64_t arrsize)
112
+ inline void avx512_partial_qsort_fp16 (uint16_t *arr, int64_t k, int64_t arrsize, bool hasnan = false )
110
113
{
111
- avx512_qselect_fp16 (arr, k - 1 , arrsize);
114
+ avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan );
112
115
avx512_qsort_fp16 (arr, k - 1 );
113
116
}
114
117
115
118
// key-value sort routines
116
119
template <typename T>
117
120
void avx512_qsort_kv (T *keys, uint64_t *indexes, int64_t arrsize);
118
121
122
/*
 * Returns true when elem is a NaN.
 *
 * A NaN is the only value that compares unequal to itself, so the check is
 * written as a self-comparison. For floating point types this matches
 * std::isnan; for integer types it is always false and, unlike std::isnan,
 * does not depend on the standard library providing integer overloads.
 */
template <typename T>
bool is_a_nan(T elem)
{
    return elem != elem;
}
127
+
128
/*
 * Move every NaN in arr[0 .. arrsize-1] to the end of the array and return
 * the index of the last element that is not a NaN. Returns -1 when the array
 * is empty or contains only NaNs.
 *
 * The relative order of the non-NaN elements is not preserved: each NaN found
 * while scanning from the front is swapped with the element at the current
 * back position.
 */
template <typename T>
int64_t move_nans_to_end_of_array(T *arr, int64_t arrsize)
{
    int64_t jj = arrsize - 1; // last slot not yet known to hold a NaN
    int64_t ii = 0; // next slot to examine
    while (ii <= jj) {
        if (is_a_nan(arr[ii])) {
            // Park the NaN at the back. The element swapped in has not been
            // examined yet, so ii is deliberately not advanced here.
            std::swap(arr[ii], arr[jj]);
            jj -= 1;
        }
        else {
            ii += 1;
        }
    }
    // jj dropped by one per NaN found, so it is now arrsize - 1 - (#NaNs),
    // i.e. the index of the last non-NaN element.
    return jj;
}
150
+
119
151
template <typename vtype, typename T = typename vtype::type_t >
120
152
bool comparison_func (const T &a, const T &b)
121
153
{
0 commit comments