Skip to content

Commit e0103be

Browse files
committed
Enable fp16 sorting without AVX512FP16
1 parent 745324c commit e0103be

6 files changed

+144
-28
lines changed

lib/x86simdsort-icl.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,30 @@ namespace avx512 {
5050
{
5151
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
5252
}
53+
#ifdef __FLT16_MAX__
54+
template <>
55+
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
56+
{
57+
x86simdsortStatic::qsort(arr, size, hasnan, descending);
58+
}
59+
template <>
60+
void qselect(_Float16 *arr,
61+
size_t k,
62+
size_t arrsize,
63+
bool hasnan,
64+
bool descending)
65+
{
66+
x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending);
67+
}
68+
template <>
69+
void partial_qsort(_Float16 *arr,
70+
size_t k,
71+
size_t arrsize,
72+
bool hasnan,
73+
bool descending)
74+
{
75+
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
76+
}
77+
#endif
5378
} // namespace avx512
5479
} // namespace xss

lib/x86simdsort.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ namespace x86simdsort {
137137
}
138138

139139
#ifdef __FLT16_MAX__
140-
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr"))
141-
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr"))
142-
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr"))
140+
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
141+
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
142+
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
143143
DISPATCH(argsort, _Float16, ISA_LIST("none"))
144144
DISPATCH(argselect, _Float16, ISA_LIST("none"))
145145
#endif

src/avx512-16bit-qsort.hpp

+85-25
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
#include "avx512-16bit-common.h"
1111

12-
struct float16 {
13-
uint16_t val;
14-
};
15-
1612
template <>
1713
struct zmm_vector<float16> {
1814
using type_t = uint16_t;
@@ -545,10 +541,65 @@ replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, arrsize_t arrsize)
545541
return nan_count;
546542
}
547543

548-
template <>
549-
X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<uint16_t>(uint16_t elem)
544+
X86_SIMD_SORT_INLINE_ONLY void replace_inf_with_nan_fp16(_Float16 *arr,
545+
arrsize_t size,
546+
arrsize_t nan_count,
547+
bool descending
548+
= false)
549+
{
550+
if (descending) {
551+
for (arrsize_t ii = 0; nan_count > 0; ++ii) {
552+
arr[ii] = xss::fp::quiet_NaN<_Float16>();
553+
nan_count -= 1;
554+
}
555+
}
556+
else {
557+
for (arrsize_t ii = size - 1; nan_count > 0; --ii) {
558+
arr[ii] = xss::fp::quiet_NaN<_Float16>();
559+
nan_count -= 1;
560+
}
561+
}
562+
}
563+
564+
template <typename comparator>
565+
[[maybe_unused]] X86_SIMD_SORT_INLINE void
566+
avx512_qsort_fp16_helper(uint16_t *arr, arrsize_t arrsize)
550567
{
551-
return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0);
568+
using T = uint16_t;
569+
using vtype = zmm_vector<float16>;
570+
571+
#ifdef XSS_COMPILE_OPENMP
572+
bool use_parallel = arrsize > 100000;
573+
574+
if (use_parallel) {
575+
// This thread limit was determined experimentally; it may be better for it to be the number of physical cores on the system
576+
constexpr int thread_limit = 8;
577+
int thread_count = std::min(thread_limit, omp_get_max_threads());
578+
arrsize_t task_threshold = std::max((arrsize_t)100000, arrsize / 100);
579+
580+
// We use omp parallel and then omp single to setup the threads that will run the omp task calls in qsort_
581+
// The omp single prevents multiple threads from running the initial qsort_ simultaneously and causing problems
582+
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
583+
#pragma omp parallel num_threads(thread_count)
584+
#pragma omp single
585+
qsort_<vtype, comparator, T>(arr,
586+
0,
587+
arrsize - 1,
588+
2 * (arrsize_t)log2(arrsize),
589+
task_threshold);
590+
}
591+
else {
592+
qsort_<vtype, comparator, T>(arr,
593+
0,
594+
arrsize - 1,
595+
2 * (arrsize_t)log2(arrsize),
596+
std::numeric_limits<arrsize_t>::max());
597+
}
598+
#pragma omp taskwait
599+
#else
600+
qsort_<vtype, comparator, T>(
601+
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
602+
#endif
552603
}
553604

554605
[[maybe_unused]] X86_SIMD_SORT_INLINE void
@@ -559,22 +610,19 @@ avx512_qsort_fp16(uint16_t *arr,
559610
{
560611
using vtype = zmm_vector<float16>;
561612

562-
// TODO multithreading support here
563613
if (arrsize > 1) {
564614
arrsize_t nan_count = 0;
565615
if (UNLIKELY(hasnan)) {
566-
nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
567-
arr, arrsize);
616+
nan_count = replace_nan_with_inf<vtype, uint16_t>(arr, arrsize);
568617
}
569618
if (descending) {
570-
qsort_<vtype, Comparator<vtype, true>, uint16_t>(
571-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
619+
avx512_qsort_fp16_helper<Comparator<vtype, true>>(arr, arrsize);
572620
}
573621
else {
574-
qsort_<vtype, Comparator<vtype, false>, uint16_t>(
575-
arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize), 0);
622+
avx512_qsort_fp16_helper<Comparator<vtype, false>>(arr, arrsize);
576623
}
577-
replace_inf_with_nan(arr, arrsize, nan_count, descending);
624+
replace_inf_with_nan_fp16(
625+
(_Float16 *)arr, arrsize, nan_count, descending);
578626
}
579627

580628
#ifdef __MMX__
@@ -592,26 +640,37 @@ avx512_qselect_fp16(uint16_t *arr,
592640
{
593641
using vtype = zmm_vector<float16>;
594642

595-
arrsize_t indx_last_elem = arrsize - 1;
643+
// Exit early if no work would be done
644+
if (arrsize <= 1) return;
645+
646+
arrsize_t index_first_elem = 0;
647+
arrsize_t index_last_elem = arrsize - 1;
648+
596649
if (UNLIKELY(hasnan)) {
597-
indx_last_elem = move_nans_to_end_of_array(arr, arrsize);
650+
if (descending) {
651+
index_first_elem = move_nans_to_start_of_array(arr, arrsize);
652+
}
653+
else {
654+
index_last_elem = move_nans_to_end_of_array(arr, arrsize);
655+
}
598656
}
599-
if (indx_last_elem >= k) {
657+
658+
if (index_first_elem <= k && index_last_elem >= k) {
600659
if (descending) {
601660
qselect_<vtype, Comparator<vtype, true>, uint16_t>(
602661
arr,
603662
k,
604-
0,
605-
indx_last_elem,
606-
2 * (arrsize_t)log2(indx_last_elem));
663+
index_first_elem,
664+
index_last_elem,
665+
2 * (arrsize_t)log2(arrsize));
607666
}
608667
else {
609668
qselect_<vtype, Comparator<vtype, false>, uint16_t>(
610669
arr,
611670
k,
612-
0,
613-
indx_last_elem,
614-
2 * (arrsize_t)log2(indx_last_elem));
671+
index_first_elem,
672+
index_last_elem,
673+
2 * (arrsize_t)log2(arrsize));
615674
}
616675
}
617676

@@ -628,7 +687,8 @@ avx512_partial_qsort_fp16(uint16_t *arr,
628687
bool hasnan = false,
629688
bool descending = false)
630689
{
690+
if (k == 0) return;
631691
avx512_qselect_fp16(arr, k - 1, arrsize, hasnan, descending);
632-
avx512_qsort_fp16(arr, k - 1, descending);
692+
avx512_qsort_fp16(arr, k - 1, hasnan, descending);
633693
}
634694
#endif // AVX512_QSORT_16BIT

src/x86simdsort-static-incl.h

+21
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,27 @@ X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key,
173173

174174
XSS_METHODS(avx512)
175175

176+
#if defined(__FLT16_MAX__) && defined(__AVX512BW__) && defined(__AVX512VBMI2__) && !defined(__AVX512FP16__)
177+
template <>
178+
void x86simdsortStatic::qsort<_Float16>(
179+
_Float16 *arr, size_t size, bool hasnan, bool descending)
180+
{
181+
avx512_qsort_fp16((uint16_t *)arr, size, hasnan, descending);
182+
}
183+
template <>
184+
void x86simdsortStatic::qselect<_Float16>(
185+
_Float16 *arr, size_t k, size_t size, bool hasnan, bool descending)
186+
{
187+
avx512_qselect_fp16((uint16_t *)arr, k, size, hasnan, descending);
188+
}
189+
template <>
190+
void x86simdsortStatic::partial_qsort<_Float16>(
191+
_Float16 *arr, size_t k, size_t size, bool hasnan, bool descending)
192+
{
193+
avx512_partial_qsort_fp16((uint16_t *)arr, k, size, hasnan, descending);
194+
}
195+
#endif
196+
176197
#elif defined(__AVX512F__)
177198
#error "x86simdsort requires AVX512DQ and AVX512VL to be enabled in addition to AVX512F to use AVX512"
178199

src/xss-common-includes.h

+4
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,8 @@ enum class simd_type : int { AVX2, AVX512 };
109109
template <typename vtype, typename T = typename vtype::type_t>
110110
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b);
111111

112+
struct float16 {
113+
uint16_t val;
114+
};
115+
112116
#endif // XSS_COMMON_INCLUDES

src/xss-common-qsort.h

+6
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ bool is_a_nan(T elem)
4545
return std::isnan(elem);
4646
}
4747

48+
template <>
49+
X86_SIMD_SORT_INLINE_ONLY bool is_a_nan<uint16_t>(uint16_t elem)
50+
{
51+
return ((elem & 0x7c00u) == 0x7c00u) && ((elem & 0x03ffu) != 0);
52+
}
53+
4854
template <typename vtype, typename T>
4955
X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size)
5056
{

0 commit comments

Comments
 (0)