Skip to content

Commit 724e92e

Browse files
authored
Merge pull request #200 from sterrettm2/fp16_nonnative
Enable fp16 nonnative support for dynamic dispatch, make more ergonomic for static dispatch
2 parents 745324c + 2c39de4 commit 724e92e

10 files changed

+219
-193
lines changed

.github/workflows/c-cpp.yml

+3
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ jobs:
208208
209209
- name: Run test suite on SPR
210210
run: sde -spr -- ./builddir/testexe
211+
- name: Run ICL fp16 tests
212+
# Note: This selects the _Float16 tests by the numeric index that appears in their test names (the 2 in the filter below), which could change in the future; update the filter if it does
213+
run: sde -icx -- ./builddir/testexe --gtest_filter="*/simdsort/2*"
211214

212215
SKX-SKL-openmp:
213216

lib/x86simdsort-icl.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,31 @@ namespace avx512 {
5151
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
5252
}
5353
} // namespace avx512
54+
// _Float16 specializations living in the fp16_icl namespace (the dispatcher
// refers to them as xss::fp16_icl::func<TYPE>). Each one forwards its
// arguments unchanged to the x86simdsortStatic function of the same name.
namespace fp16_icl {
55+
// Everything below is compiled only when __FLT16_MAX__ is defined.
// NOTE(review): presumably that macro signals compiler support for _Float16 -- confirm.
#ifdef __FLT16_MAX__
56+
// Explicit specialization of qsort for _Float16.
template <>
57+
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
58+
{
59+
x86simdsortStatic::qsort(arr, size, hasnan, descending);
60+
}
61+
// Explicit specialization of qselect for _Float16.
template <>
62+
void qselect(_Float16 *arr,
63+
size_t k,
64+
size_t arrsize,
65+
bool hasnan,
66+
bool descending)
67+
{
68+
x86simdsortStatic::qselect(arr, k, arrsize, hasnan, descending);
69+
}
70+
// Explicit specialization of partial_qsort for _Float16.
template <>
71+
void partial_qsort(_Float16 *arr,
72+
size_t k,
73+
size_t arrsize,
74+
bool hasnan,
75+
bool descending)
76+
{
77+
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
78+
}
79+
#endif
80+
} // namespace fp16_icl
5481
} // namespace xss

lib/x86simdsort-internal.h

+54-159
Original file line numberDiff line numberDiff line change
@@ -4,165 +4,60 @@
44
#include <stdint.h>
55
#include <vector>
66

7+
/* DECLAREALLFUNCS(name): expands to `namespace name { ... }` holding template
 * declarations (no definitions) for qsort, keyvalue_qsort, qselect,
 * keyvalue_select, partial_qsort, keyvalue_partial_sort, argsort and
 * argselect, each tagged XSS_HIDE_SYMBOL. `hasnan` and `descending` default
 * to false wherever they appear; argselect takes no `descending` argument.
 * Every line except the closing brace ends in a backslash so that the whole
 * namespace forms a single macro body. */
#define DECLAREALLFUNCS(name) \
8+
namespace name { \
9+
template <typename T> \
10+
XSS_HIDE_SYMBOL void qsort(T *arr, \
11+
size_t arrsize, \
12+
bool hasnan = false, \
13+
bool descending = false); \
14+
template <typename T1, typename T2> \
15+
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \
16+
T2 *val, \
17+
size_t arrsize, \
18+
bool hasnan = false, \
19+
bool descending = false); \
20+
template <typename T> \
21+
XSS_HIDE_SYMBOL void qselect(T *arr, \
22+
size_t k, \
23+
size_t arrsize, \
24+
bool hasnan = false, \
25+
bool descending = false); \
26+
template <typename T1, typename T2> \
27+
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \
28+
T2 *val, \
29+
size_t k, \
30+
size_t arrsize, \
31+
bool hasnan = false, \
32+
bool descending = false); \
33+
template <typename T> \
34+
XSS_HIDE_SYMBOL void partial_qsort(T *arr, \
35+
size_t k, \
36+
size_t arrsize, \
37+
bool hasnan = false, \
38+
bool descending = false); \
39+
template <typename T1, typename T2> \
40+
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \
41+
T2 *val, \
42+
size_t k, \
43+
size_t arrsize, \
44+
bool hasnan = false, \
45+
bool descending = false); \
46+
template <typename T> \
47+
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, \
48+
size_t arrsize, \
49+
bool hasnan = false, \
50+
bool descending = false); \
51+
template <typename T> \
52+
XSS_HIDE_SYMBOL std::vector<size_t> \
53+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \
54+
}
55+
756
namespace xss {
8-
namespace avx512 {
9-
// quicksort
10-
template <typename T>
11-
XSS_HIDE_SYMBOL void
12-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
13-
// key-value quicksort
14-
template <typename T1, typename T2>
15-
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
16-
T2 *val,
17-
size_t arrsize,
18-
bool hasnan = false,
19-
bool descending = false);
20-
// quickselect
21-
template <typename T>
22-
XSS_HIDE_SYMBOL void qselect(T *arr,
23-
size_t k,
24-
size_t arrsize,
25-
bool hasnan = false,
26-
bool descending = false);
27-
// key-value select
28-
template <typename T1, typename T2>
29-
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
30-
T2 *val,
31-
size_t k,
32-
size_t arrsize,
33-
bool hasnan = false,
34-
bool descending = false);
35-
// partial sort
36-
template <typename T>
37-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
38-
size_t k,
39-
size_t arrsize,
40-
bool hasnan = false,
41-
bool descending = false);
42-
// key-value partial sort
43-
template <typename T1, typename T2>
44-
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
45-
T2 *val,
46-
size_t k,
47-
size_t arrsize,
48-
bool hasnan = false,
49-
bool descending = false);
50-
// argsort
51-
template <typename T>
52-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
53-
size_t arrsize,
54-
bool hasnan = false,
55-
bool descending = false);
56-
// argselect
57-
template <typename T>
58-
XSS_HIDE_SYMBOL std::vector<size_t>
59-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
60-
} // namespace avx512
61-
namespace avx2 {
62-
// quicksort
63-
template <typename T>
64-
XSS_HIDE_SYMBOL void
65-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
66-
// key-value quicksort
67-
template <typename T1, typename T2>
68-
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
69-
T2 *val,
70-
size_t arrsize,
71-
bool hasnan = false,
72-
bool descending = false);
73-
// quickselect
74-
template <typename T>
75-
XSS_HIDE_SYMBOL void qselect(T *arr,
76-
size_t k,
77-
size_t arrsize,
78-
bool hasnan = false,
79-
bool descending = false);
80-
// key-value select
81-
template <typename T1, typename T2>
82-
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
83-
T2 *val,
84-
size_t k,
85-
size_t arrsize,
86-
bool hasnan = false,
87-
bool descending = false);
88-
// partial sort
89-
template <typename T>
90-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
91-
size_t k,
92-
size_t arrsize,
93-
bool hasnan = false,
94-
bool descending = false);
95-
// key-value partial sort
96-
template <typename T1, typename T2>
97-
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
98-
T2 *val,
99-
size_t k,
100-
size_t arrsize,
101-
bool hasnan = false,
102-
bool descending = false);
103-
// argsort
104-
template <typename T>
105-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
106-
size_t arrsize,
107-
bool hasnan = false,
108-
bool descending = false);
109-
// argselect
110-
template <typename T>
111-
XSS_HIDE_SYMBOL std::vector<size_t>
112-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
113-
} // namespace avx2
114-
namespace scalar {
115-
// quicksort
116-
template <typename T>
117-
XSS_HIDE_SYMBOL void
118-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
119-
// key-value quicksort
120-
template <typename T1, typename T2>
121-
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
122-
T2 *val,
123-
size_t arrsize,
124-
bool hasnan = false,
125-
bool descending = false);
126-
// quickselect
127-
template <typename T>
128-
XSS_HIDE_SYMBOL void qselect(T *arr,
129-
size_t k,
130-
size_t arrsize,
131-
bool hasnan = false,
132-
bool descending = false);
133-
// key-value select
134-
template <typename T1, typename T2>
135-
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
136-
T2 *val,
137-
size_t k,
138-
size_t arrsize,
139-
bool hasnan = false,
140-
bool descending = false);
141-
// partial sort
142-
template <typename T>
143-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
144-
size_t k,
145-
size_t arrsize,
146-
bool hasnan = false,
147-
bool descending = false);
148-
// key-value partial sort
149-
template <typename T1, typename T2>
150-
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
151-
T2 *val,
152-
size_t k,
153-
size_t arrsize,
154-
bool hasnan = false,
155-
bool descending = false);
156-
// argsort
157-
template <typename T>
158-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
159-
size_t arrsize,
160-
bool hasnan = false,
161-
bool descending = false);
162-
// argselect
163-
template <typename T>
164-
XSS_HIDE_SYMBOL std::vector<size_t>
165-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
166-
} // namespace scalar
57+
DECLAREALLFUNCS(avx512)
58+
DECLAREALLFUNCS(avx2)
59+
DECLAREALLFUNCS(scalar)
60+
DECLAREALLFUNCS(fp16_spr)
61+
DECLAREALLFUNCS(fp16_icl)
16762
} // namespace xss
16863
#endif

lib/x86simdsort-spr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "x86simdsort-internal.h"
44

55
namespace xss {
6-
namespace avx512 {
6+
namespace fp16_spr {
77
template <>
88
void qsort(_Float16 *arr, size_t size, bool hasnan, bool descending)
99
{
@@ -27,5 +27,5 @@ namespace avx512 {
2727
{
2828
x86simdsortStatic::partial_qsort(arr, k, arrsize, hasnan, descending);
2929
}
30-
} // namespace avx512
30+
} // namespace fp16_spr
3131
} // namespace xss

lib/x86simdsort.cpp

+32-4
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ namespace x86simdsort {
108108
return (*internal_argselect##TYPE)(arr, k, arrsize, hasnan); \
109109
}
110110

111+
/* Compile-time predicate: yields true only when T is exactly _Float16, and
 * false for every other type (or whenever __FLT16_MAX__ is not defined).
 * Keeping the preprocessor check in this helper lets the DISPATCH macro
 * branch on the type without an #ifdef __FLT16_MAX__ block in its body. */
template <typename T>
constexpr bool IS_TYPE_FLOAT16()
{
#ifdef __FLT16_MAX__
    return std::is_same_v<T, _Float16>;
#else
    return false;
#endif
}
121+
111122
/* runtime dispatch mechanism */
112123
#define DISPATCH(func, TYPE, ISA) \
113124
DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \
@@ -118,7 +129,24 @@ namespace x86simdsort {
118129
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
119130
if constexpr (dispatch_requested("avx512", ISA)) { \
120131
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
121-
CAT(CAT(internal_, func), TYPE) = &xss::avx512::func<TYPE>; \
132+
if constexpr (IS_TYPE_FLOAT16<TYPE>()) { \
133+
if (preferred_cpu.find("avx512_spr") \
134+
!= std::string_view::npos) { \
135+
CAT(CAT(internal_, func), TYPE) \
136+
= &xss::fp16_spr::func<TYPE>; \
137+
return; \
138+
} \
139+
if (preferred_cpu.find("avx512_icl") \
140+
!= std::string_view::npos) { \
141+
CAT(CAT(internal_, func), TYPE) \
142+
= &xss::fp16_icl::func<TYPE>; \
143+
return; \
144+
} \
145+
} \
146+
else { \
147+
CAT(CAT(internal_, func), TYPE) \
148+
= &xss::avx512::func<TYPE>; \
149+
} \
122150
return; \
123151
} \
124152
} \
@@ -137,9 +165,9 @@ namespace x86simdsort {
137165
}
138166

139167
#ifdef __FLT16_MAX__
140-
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr"))
141-
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr"))
142-
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr"))
168+
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
169+
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
170+
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
143171
DISPATCH(argsort, _Float16, ISA_LIST("none"))
144172
DISPATCH(argselect, _Float16, ISA_LIST("none"))
145173
#endif

0 commit comments

Comments
 (0)