@@ -349,12 +349,6 @@ struct zmm_vector<uint16_t> {
349349 }
350350};
351351
352- template <>
353- bool is_a_nan<uint16_t >(uint16_t elem)
354- {
355- return (elem & 0x7c00 ) == 0x7c00 ;
356- }
357-
358352template <>
359353bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
360354{
@@ -383,6 +377,34 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
383377 // return npy_half_to_float(a) < npy_half_to_float(b);
384378}
385379
380+ X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf (uint16_t *arr,
381+ int64_t arrsize)
382+ {
383+ int64_t nan_count = 0 ;
384+ __mmask16 loadmask = 0xFFFF ;
385+ while (arrsize > 0 ) {
386+ if (arrsize < 16 ) { loadmask = (0x0001 << arrsize) - 0x0001 ; }
387+ __m256i in_zmm = _mm256_maskz_loadu_epi16 (loadmask, arr);
388+ __m512 in_zmm_asfloat = _mm512_cvtph_ps (in_zmm);
389+ __mmask16 nanmask = _mm512_cmp_ps_mask (
390+ in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
391+ nan_count += _mm_popcnt_u32 ((int32_t )nanmask);
392+ _mm256_mask_storeu_epi16 (arr, nanmask, YMM_MAX_HALF);
393+ arr += 16 ;
394+ arrsize -= 16 ;
395+ }
396+ return nan_count;
397+ }
398+
399+ X86_SIMD_SORT_INLINE void
400+ replace_inf_with_nan (uint16_t *arr, int64_t arrsize, int64_t nan_count)
401+ {
402+ for (int64_t ii = arrsize - 1 ; nan_count > 0 ; --ii) {
403+ arr[ii] = 0xFFFF ;
404+ nan_count -= 1 ;
405+ }
406+ }
407+
386408template <>
387409void avx512_qselect (int16_t *arr, int64_t k, int64_t arrsize)
388410{
@@ -403,10 +425,11 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
403425
404426void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize)
405427{
406- int64_t indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
407- if (indx_last_elem >= k) {
428+ if ( arrsize > 1 ) {
429+ int64_t nan_count = replace_nan_with_inf (arr, arrsize);
408430 qselect_16bit_<zmm_vector<float16>, uint16_t >(
409- arr, k, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
431+ arr, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
432+ replace_inf_with_nan (arr, arrsize, nan_count);
410433 }
411434}
412435
@@ -430,10 +453,11 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
430453
431454void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize)
432455{
433- int64_t indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
434- if (indx_last_elem > 0 ) {
456+ if ( arrsize > 1 ) {
457+ int64_t nan_count = replace_nan_with_inf (arr, arrsize);
435458 qsort_16bit_<zmm_vector<float16>, uint16_t >(
436- arr, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
459+ arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
460+ replace_inf_with_nan (arr, arrsize, nan_count);
437461 }
438462}
439463
0 commit comments