Skip to content

Commit bbb7906

Browse files
committed
Fix dispatch logic to use both ICL and SPR fp16
1 parent e0103be commit bbb7906

File tree

4 files changed

+69
-6
lines changed

4 files changed

+69
-6
lines changed

lib/x86simdsort-icl.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ namespace avx512 {
5050
{
5151
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
5252
}
53+
} // namespace avx512
54+
namespace fp16_icl {
5355
#ifdef __FLT16_MAX__
5456
template <>
5557
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
@@ -75,5 +77,5 @@ namespace avx512 {
7577
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
7678
}
7779
#endif
78-
} // namespace avx512
80+
} // namespace fp16_icl
7981
} // namespace xss

lib/x86simdsort-internal.h

+40
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,45 @@ namespace scalar {
164164
XSS_HIDE_SYMBOL std::vector<size_t>
165165
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
166166
} // namespace scalar
167+
namespace fp16_spr {
168+
// quicksort
169+
template <typename T>
170+
XSS_HIDE_SYMBOL void
171+
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
172+
// quickselect
173+
template <typename T>
174+
XSS_HIDE_SYMBOL void qselect(T *arr,
175+
size_t k,
176+
size_t arrsize,
177+
bool hasnan = false,
178+
bool descending = false);
179+
// partial sort
180+
template <typename T>
181+
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
182+
size_t k,
183+
size_t arrsize,
184+
bool hasnan = false,
185+
bool descending = false);
186+
} // namespace fp16_spr
187+
namespace fp16_icl {
188+
// quicksort
189+
template <typename T>
190+
XSS_HIDE_SYMBOL void
191+
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
192+
// quickselect
193+
template <typename T>
194+
XSS_HIDE_SYMBOL void qselect(T *arr,
195+
size_t k,
196+
size_t arrsize,
197+
bool hasnan = false,
198+
bool descending = false);
199+
// partial sort
200+
template <typename T>
201+
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
202+
size_t k,
203+
size_t arrsize,
204+
bool hasnan = false,
205+
bool descending = false);
206+
} // namespace fp16_icl
167207
} // namespace xss
168208
#endif

lib/x86simdsort-spr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "x86simdsort-internal.h"
44

55
namespace xss {
6-
namespace avx512 {
6+
namespace fp16_spr {
77
template <>
88
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
99
{
@@ -27,5 +27,5 @@ namespace avx512 {
2727
{
2828
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
2929
}
30-
} // namespace avx512
30+
} // namespace fp16_spr
3131
} // namespace xss

lib/x86simdsort.cpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,37 @@ namespace x86simdsort {
130130
} \
131131
}
132132

133+
#define DISPATCH_FP16(func, TYPE, ISA) \
134+
DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \
135+
CAT(CAT(resolve_, func), TYPE)(void) \
136+
{ \
137+
CAT(CAT(internal_, func), TYPE) = &xss::scalar::func<TYPE>; \
138+
__builtin_cpu_init(); \
139+
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
140+
if constexpr (dispatch_requested("avx512_spr", ISA)) { \
141+
if (preferred_cpu.find("avx512_spr") != std::string_view::npos) { \
142+
CAT(CAT(internal_, func), TYPE) = &xss::fp16_spr::func<TYPE>; \
143+
return; \
144+
} \
145+
} \
146+
if constexpr (dispatch_requested("avx512_icl", ISA)) { \
147+
if (preferred_cpu.find("avx512_icl") != std::string_view::npos) { \
148+
CAT(CAT(internal_, func), TYPE) = &xss::fp16_icl::func<TYPE>; \
149+
return; \
150+
} \
151+
} \
152+
}
153+
133154
#define ISA_LIST(...) \
134155
std::initializer_list<std::string_view> \
135156
{ \
136157
__VA_ARGS__ \
137158
}
138159

139160
#ifdef __FLT16_MAX__
140-
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
141-
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
142-
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
161+
DISPATCH_FP16(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
162+
DISPATCH_FP16(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
163+
DISPATCH_FP16(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
143164
DISPATCH(argsort, _Float16, ISA_LIST("none"))
144165
DISPATCH(argselect, _Float16, ISA_LIST("none"))
145166
#endif

0 commit comments

Comments
 (0)