9
9
10
10
#include " avx512-16bit-common.h"
11
11
12
- struct float16 {
13
- uint16_t val;
14
- };
15
-
16
12
template <>
17
13
struct zmm_vector <float16> {
18
14
using type_t = uint16_t ;
@@ -545,10 +541,65 @@ replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, arrsize_t arrsize)
545
541
return nan_count;
546
542
}
547
543
548
- template <>
549
- X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<uint16_t >(uint16_t elem)
544
+ X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan_fp16 (_Float16 *arr,
545
+ arrsize_t size,
546
+ arrsize_t nan_count,
547
+ bool descending
548
+ = false )
549
+ {
550
+ if (descending) {
551
+ for (arrsize_t ii = 0 ; nan_count > 0 ; ++ii) {
552
+ arr[ii] = xss::fp::quiet_NaN<_Float16>();
553
+ nan_count -= 1 ;
554
+ }
555
+ }
556
+ else {
557
+ for (arrsize_t ii = size - 1 ; nan_count > 0 ; --ii) {
558
+ arr[ii] = xss::fp::quiet_NaN<_Float16>();
559
+ nan_count -= 1 ;
560
+ }
561
+ }
562
+ }
563
+
564
+ template <typename comparator>
565
+ [[maybe_unused]] X86_SIMD_SORT_INLINE void
566
+ avx512_qsort_fp16_helper (uint16_t *arr, arrsize_t arrsize)
550
567
{
551
- return ((elem & 0x7c00u ) == 0x7c00u ) && ((elem & 0x03ffu ) != 0 );
568
+ using T = uint16_t ;
569
+ using vtype = zmm_vector<float16>;
570
+
571
+ #ifdef XSS_COMPILE_OPENMP
572
+ bool use_parallel = arrsize > 100000 ;
573
+
574
+ if (use_parallel) {
575
+ // This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
576
+ constexpr int thread_limit = 8 ;
577
+ int thread_count = std::min (thread_limit, omp_get_max_threads ());
578
+ arrsize_t task_threshold = std::max ((arrsize_t )100000 , arrsize / 100 );
579
+
580
+ // We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
581
+ // The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
582
+ // Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
583
+ #pragma omp parallel num_threads(thread_count)
584
+ #pragma omp single
585
+ qsort_<vtype, comparator, T>(arr,
586
+ 0 ,
587
+ arrsize - 1 ,
588
+ 2 * (arrsize_t )log2 (arrsize),
589
+ task_threshold);
590
+ }
591
+ else {
592
+ qsort_<vtype, comparator, T>(arr,
593
+ 0 ,
594
+ arrsize - 1 ,
595
+ 2 * (arrsize_t )log2 (arrsize),
596
+ std::numeric_limits<arrsize_t >::max ());
597
+ }
598
+ #pragma omp taskwait
599
+ #else
600
+ qsort_<vtype, comparator, T>(
601
+ arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize), 0 );
602
+ #endif
552
603
}
553
604
554
605
[[maybe_unused]] X86_SIMD_SORT_INLINE void
@@ -559,22 +610,19 @@ avx512_qsort_fp16(uint16_t *arr,
559
610
{
560
611
using vtype = zmm_vector<float16>;
561
612
562
- // TODO multithreading support here
563
613
if (arrsize > 1 ) {
564
614
arrsize_t nan_count = 0 ;
565
615
if (UNLIKELY (hasnan)) {
566
- nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t >(
567
- arr, arrsize);
616
+ nan_count = replace_nan_with_inf<vtype, uint16_t >(arr, arrsize);
568
617
}
569
618
if (descending) {
570
- qsort_<vtype, Comparator<vtype, true >, uint16_t >(
571
- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize), 0 );
619
+ avx512_qsort_fp16_helper<Comparator<vtype, true >>(arr, arrsize);
572
620
}
573
621
else {
574
- qsort_<vtype, Comparator<vtype, false >, uint16_t >(
575
- arr, 0 , arrsize - 1 , 2 * (arrsize_t )log2 (arrsize), 0 );
622
+ avx512_qsort_fp16_helper<Comparator<vtype, false >>(arr, arrsize);
576
623
}
577
- replace_inf_with_nan (arr, arrsize, nan_count, descending);
624
+ replace_inf_with_nan_fp16 (
625
+ (_Float16 *)arr, arrsize, nan_count, descending);
578
626
}
579
627
580
628
#ifdef __MMX__
@@ -592,26 +640,37 @@ avx512_qselect_fp16(uint16_t *arr,
592
640
{
593
641
using vtype = zmm_vector<float16>;
594
642
595
- arrsize_t indx_last_elem = arrsize - 1 ;
643
+ // Exit early if no work would be done
644
+ if (arrsize <= 1 ) return ;
645
+
646
+ arrsize_t index_first_elem = 0 ;
647
+ arrsize_t index_last_elem = arrsize - 1 ;
648
+
596
649
if (UNLIKELY (hasnan)) {
597
- indx_last_elem = move_nans_to_end_of_array (arr, arrsize);
650
+ if (descending) {
651
+ index_first_elem = move_nans_to_start_of_array (arr, arrsize);
652
+ }
653
+ else {
654
+ index_last_elem = move_nans_to_end_of_array (arr, arrsize);
655
+ }
598
656
}
599
- if (indx_last_elem >= k) {
657
+
658
+ if (index_first_elem <= k && index_last_elem >= k) {
600
659
if (descending) {
601
660
qselect_<vtype, Comparator<vtype, true >, uint16_t >(
602
661
arr,
603
662
k,
604
- 0 ,
605
- indx_last_elem ,
606
- 2 * (arrsize_t )log2 (indx_last_elem ));
663
+ index_first_elem ,
664
+ index_last_elem ,
665
+ 2 * (arrsize_t )log2 (arrsize ));
607
666
}
608
667
else {
609
668
qselect_<vtype, Comparator<vtype, false >, uint16_t >(
610
669
arr,
611
670
k,
612
- 0 ,
613
- indx_last_elem ,
614
- 2 * (arrsize_t )log2 (indx_last_elem ));
671
+ index_first_elem ,
672
+ index_last_elem ,
673
+ 2 * (arrsize_t )log2 (arrsize ));
615
674
}
616
675
}
617
676
@@ -628,7 +687,8 @@ avx512_partial_qsort_fp16(uint16_t *arr,
628
687
bool hasnan = false ,
629
688
bool descending = false )
630
689
{
690
+ if (k == 0 ) return ;
631
691
avx512_qselect_fp16 (arr, k - 1 , arrsize, hasnan, descending);
632
- avx512_qsort_fp16 (arr, k - 1 , descending);
692
+ avx512_qsort_fp16 (arr, k - 1 , hasnan, descending);
633
693
}
634
694
#endif // AVX512_QSORT_16BIT
0 commit comments