@@ -349,12 +349,6 @@ struct zmm_vector<uint16_t> {
     }
 };
 
-template <>
-bool is_a_nan<uint16_t>(uint16_t elem)
-{
-    return (elem & 0x7c00) == 0x7c00;
-}
-
 template <>
 bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
 {
@@ -383,6 +377,34 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
     // return npy_half_to_float(a) < npy_half_to_float(b);
 }
 
+X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf(uint16_t *arr,
+                                                  int64_t arrsize)
+{
+    int64_t nan_count = 0;
+    __mmask16 loadmask = 0xFFFF;
+    while (arrsize > 0) {
+        if (arrsize < 16) { loadmask = (0x0001 << arrsize) - 0x0001; }
+        __m256i in_zmm = _mm256_maskz_loadu_epi16(loadmask, arr);
+        __m512 in_zmm_asfloat = _mm512_cvtph_ps(in_zmm);
+        __mmask16 nanmask = _mm512_cmp_ps_mask(
+                in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
+        nan_count += _mm_popcnt_u32((int32_t)nanmask);
+        _mm256_mask_storeu_epi16(arr, nanmask, YMM_MAX_HALF);
+        arr += 16;
+        arrsize -= 16;
+    }
+    return nan_count;
+}
+
+X86_SIMD_SORT_INLINE void
+replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
+{
+    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
+        arr[ii] = 0xFFFF;
+        nan_count -= 1;
+    }
+}
+
 template <>
 void avx512_qselect(int16_t *arr, int64_t k, int64_t arrsize)
 {
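For readers without the AVX-512 intrinsics in front of them, the NaN handling added above can be summarised in scalar form: each NaN half is rewritten as the +inf bit pattern (0x7c00) and counted, the sort then pushes those slots to the tail, and the tail is rewritten as NaN afterwards. A minimal scalar sketch of that idea (illustrative only; the helper names and the strict IEEE NaN test are assumptions, not part of this patch):

#include <cstdint>

// A half is NaN when the exponent bits are all set and the mantissa is
// non-zero. (The deleted is_a_nan above also matched +/-inf; this sketch
// uses the strict test.)
static bool half_is_nan(uint16_t h)
{
    return (h & 0x7c00) == 0x7c00 && (h & 0x03ff) != 0;
}

// Rewrite every NaN as +inf (0x7c00) and report how many were rewritten.
static int64_t scalar_replace_nan_with_inf(uint16_t *arr, int64_t arrsize)
{
    int64_t nan_count = 0;
    for (int64_t i = 0; i < arrsize; ++i) {
        if (half_is_nan(arr[i])) {
            arr[i] = 0x7c00;
            ++nan_count;
        }
    }
    return nan_count;
}

// After a full sort the +inf slots (including the former NaNs) sit at the
// end of the array; the last nan_count of them are rewritten as NaN.
static void scalar_replace_inf_with_nan(uint16_t *arr, int64_t arrsize, int64_t nan_count)
{
    for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
        arr[ii] = 0xFFFF;
        nan_count -= 1;
    }
}

The vectorised version above does the same bookkeeping 16 halves at a time: the self-comparison with _CMP_NEQ_UQ flags NaNs (only NaN compares unequal to itself), the popcount of that mask accumulates nan_count, and the masked store overwrites the flagged lanes with YMM_MAX_HALF.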
@@ -403,10 +425,11 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
 
 void avx512_qselect_fp16(uint16_t *arr, int64_t k, int64_t arrsize)
 {
-    int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
-    if (indx_last_elem >= k) {
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
         qselect_16bit_<zmm_vector<float16>, uint16_t>(
-                arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
+                arr, k, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
     }
 }
 
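A small caller sketch for the select path, assuming the entry point shown above. The header name is an assumption, and the stated contract (arr[k] ends up holding the value a full sort would place at index k) is the usual quickselect semantics rather than something this diff spells out:

#include <cstdint>
#include <vector>

#include "avx512-16bit-qsort.hpp" // assumed header name

int main()
{
    // Raw fp16 bit patterns: 4.0, 1.0, NaN, 3.0, 2.0
    std::vector<uint16_t> vals = {0x4400, 0x3c00, 0x7e00, 0x4200, 0x4000};
    int64_t k = 2;
    avx512_qselect_fp16(vals.data(), k, static_cast<int64_t>(vals.size()));
    // vals[2] should now be 0x4200 (3.0): smaller values sit before index 2,
    // larger values (the NaN is ordered as +inf) sit after it.
    return 0;
}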
@@ -430,10 +453,11 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
 
 void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
 {
-    int64_t indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
-    if (indx_last_elem > 0) {
+    if (arrsize > 1) {
+        int64_t nan_count = replace_nan_with_inf(arr, arrsize);
         qsort_16bit_<zmm_vector<float16>, uint16_t>(
-                arr, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
+                arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
+        replace_inf_with_nan(arr, arrsize, nan_count);
     }
 }
 
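And a matching sketch for the full sort; again the include is an assumption, and the call requires a CPU with the AVX-512 features the library targets:

#include <cstdint>
#include <vector>

#include "avx512-16bit-qsort.hpp" // assumed header name

int main()
{
    // Raw fp16 bit patterns: 1.0, NaN, -2.0, +inf, 0.5
    std::vector<uint16_t> vals = {0x3c00, 0x7e00, 0xc000, 0x7c00, 0x3800};
    avx512_qsort_fp16(vals.data(), static_cast<int64_t>(vals.size()));
    // Expected order: -2.0, 0.5, 1.0, +inf, NaN -- the NaN is counted,
    // sorted as +inf, and written back at the tail by replace_inf_with_nan.
    return 0;
}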