@@ -349,6 +349,12 @@ struct zmm_vector<uint16_t> {
349
349
}
350
350
};
351
351
352
+ template <>
353
+ bool is_a_nan<uint16_t >(uint16_t elem)
354
+ {
355
+ return (elem & 0x7c00 ) == 0x7c00 ;
356
+ }
357
+
352
358
template <>
353
359
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
354
360
{
@@ -377,34 +383,6 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
377
383
// return npy_half_to_float(a) < npy_half_to_float(b);
378
384
}
379
385
380
// replace_nan_with_inf (deleted by this change; lines kept verbatim).
// Walks the half-precision array in chunks of 16 elements. Each chunk is
// loaded under `loadmask` (for the final partial chunk only the low `arrsize`
// lanes are loaded, the rest are zeroed), widened to float with
// _mm512_cvtph_ps, and compared with itself using _CMP_NEQ_UQ so that only
// NaN lanes set a bit in `nanmask`. Those lanes are counted with
// _mm_popcnt_u32 and overwritten in place with YMM_MAX_HALF.
// Returns the total number of NaN elements that were overwritten.
// NOTE(review): YMM_MAX_HALF is defined outside this view; presumably it
// holds the +infinity half bit pattern in every lane - confirm.
- X86_SIMD_SORT_INLINE int64_t replace_nan_with_inf (uint16_t *arr,
381
- int64_t arrsize)
382
- {
383
- int64_t nan_count = 0 ;
384
- __mmask16 loadmask = 0xFFFF ;
385
- while (arrsize > 0 ) {
386
- if (arrsize < 16 ) { loadmask = (0x0001 << arrsize) - 0x0001 ; }
387
- __m256i in_zmm = _mm256_maskz_loadu_epi16 (loadmask, arr);
388
- __m512 in_zmm_asfloat = _mm512_cvtph_ps (in_zmm);
389
- __mmask16 nanmask = _mm512_cmp_ps_mask (
390
- in_zmm_asfloat, in_zmm_asfloat, _CMP_NEQ_UQ);
391
- nan_count += _mm_popcnt_u32 ((int32_t )nanmask);
392
- _mm256_mask_storeu_epi16 (arr, nanmask, YMM_MAX_HALF);
393
- arr += 16 ;
394
- arrsize -= 16 ;
395
- }
396
- return nan_count;
397
- }
398
-
399
// replace_inf_with_nan (deleted by this change; lines kept verbatim).
// Writes the bit pattern 0xFFFF (exponent and mantissa all ones, i.e. a NaN
// in IEEE binary16) into the last `nan_count` slots of the array, walking
// backwards from index arrsize - 1.
// The loop is bounded only by `nan_count`; it does not check `ii` against 0,
// so the caller must pass nan_count <= arrsize.
- X86_SIMD_SORT_INLINE void
400
- replace_inf_with_nan (uint16_t *arr, int64_t arrsize, int64_t nan_count)
401
- {
402
- for (int64_t ii = arrsize - 1 ; nan_count > 0 ; --ii) {
403
- arr[ii] = 0xFFFF ;
404
- nan_count -= 1 ;
405
- }
406
- }
407
-
408
386
template <>
409
387
void avx512_qselect (int16_t *arr, int64_t k, int64_t arrsize)
410
388
{
@@ -425,11 +403,10 @@ void avx512_qselect(uint16_t *arr, int64_t k, int64_t arrsize)
425
403
426
404
void avx512_qselect_fp16 (uint16_t *arr, int64_t k, int64_t arrsize)
427
405
{
428
- if (arrsize > 1 ) {
429
- int64_t nan_count = replace_nan_with_inf (arr, arrsize);
406
+ int64_t indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
407
+ if (indx_last_elem >= k) {
430
408
qselect_16bit_<zmm_vector<float16>, uint16_t >(
431
- arr, k, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
432
- replace_inf_with_nan (arr, arrsize, nan_count);
409
+ arr, k, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
433
410
}
434
411
}
435
412
@@ -453,11 +430,10 @@ void avx512_qsort(uint16_t *arr, int64_t arrsize)
453
430
454
431
void avx512_qsort_fp16 (uint16_t *arr, int64_t arrsize)
455
432
{
456
- if (arrsize > 1 ) {
457
- int64_t nan_count = replace_nan_with_inf (arr, arrsize);
433
+ int64_t indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
434
+ if (indx_last_elem > 0 ) {
458
435
qsort_16bit_<zmm_vector<float16>, uint16_t >(
459
- arr, 0 , arrsize - 1 , 2 * (int64_t )log2 (arrsize));
460
- replace_inf_with_nan (arr, arrsize, nan_count);
436
+ arr, 0 , indx_last_elem, 2 * (int64_t )log2 (indx_last_elem));
461
437
}
462
438
}
463
439
0 commit comments