Skip to content

Commit 9396adf

Browse files
committed
Use macros to reduce duplicated code
1 parent 1faa106 commit 9396adf

File tree

2 files changed

+68
-224
lines changed

2 files changed

+68
-224
lines changed

lib/x86simdsort-internal.h

+52-199
Original file line numberDiff line numberDiff line change
@@ -4,205 +4,58 @@
44
#include <stdint.h>
55
#include <vector>
66

7+
#define DECLAREALLFUNCS(name) \
8+
namespace name { \
9+
template <typename T> \
10+
XSS_HIDE_SYMBOL void \
11+
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); \
12+
template <typename T1, typename T2> \
13+
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, \
14+
T2 *val, \
15+
size_t arrsize, \
16+
bool hasnan = false, \
17+
bool descending = false); \
18+
template <typename T> \
19+
XSS_HIDE_SYMBOL void qselect(T *arr, \
20+
size_t k, \
21+
size_t arrsize, \
22+
bool hasnan = false, \
23+
bool descending = false); \
24+
template <typename T1, typename T2> \
25+
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, \
26+
T2 *val, \
27+
size_t k, \
28+
size_t arrsize, \
29+
bool hasnan = false, \
30+
bool descending = false); \
31+
template <typename T> \
32+
XSS_HIDE_SYMBOL void partial_qsort(T *arr, \
33+
size_t k, \
34+
size_t arrsize, \
35+
bool hasnan = false, \
36+
bool descending = false); \
37+
template <typename T1, typename T2> \
38+
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, \
39+
T2 *val, \
40+
size_t k, \
41+
size_t arrsize, \
42+
bool hasnan = false, \
43+
bool descending = false); \
44+
template <typename T> \
45+
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr, \
46+
size_t arrsize, \
47+
bool hasnan = false, \
48+
bool descending = false); \
49+
template <typename T> \
50+
XSS_HIDE_SYMBOL std::vector<size_t> \
51+
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); \
52+
} \
53+
754
namespace xss {
8-
namespace avx512 {
9-
// quicksort
10-
template <typename T>
11-
XSS_HIDE_SYMBOL void
12-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
13-
// key-value quicksort
14-
template <typename T1, typename T2>
15-
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
16-
T2 *val,
17-
size_t arrsize,
18-
bool hasnan = false,
19-
bool descending = false);
20-
// quickselect
21-
template <typename T>
22-
XSS_HIDE_SYMBOL void qselect(T *arr,
23-
size_t k,
24-
size_t arrsize,
25-
bool hasnan = false,
26-
bool descending = false);
27-
// key-value select
28-
template <typename T1, typename T2>
29-
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
30-
T2 *val,
31-
size_t k,
32-
size_t arrsize,
33-
bool hasnan = false,
34-
bool descending = false);
35-
// partial sort
36-
template <typename T>
37-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
38-
size_t k,
39-
size_t arrsize,
40-
bool hasnan = false,
41-
bool descending = false);
42-
// key-value partial sort
43-
template <typename T1, typename T2>
44-
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
45-
T2 *val,
46-
size_t k,
47-
size_t arrsize,
48-
bool hasnan = false,
49-
bool descending = false);
50-
// argsort
51-
template <typename T>
52-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
53-
size_t arrsize,
54-
bool hasnan = false,
55-
bool descending = false);
56-
// argselect
57-
template <typename T>
58-
XSS_HIDE_SYMBOL std::vector<size_t>
59-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
60-
} // namespace avx512
61-
namespace avx2 {
62-
// quicksort
63-
template <typename T>
64-
XSS_HIDE_SYMBOL void
65-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
66-
// key-value quicksort
67-
template <typename T1, typename T2>
68-
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
69-
T2 *val,
70-
size_t arrsize,
71-
bool hasnan = false,
72-
bool descending = false);
73-
// quickselect
74-
template <typename T>
75-
XSS_HIDE_SYMBOL void qselect(T *arr,
76-
size_t k,
77-
size_t arrsize,
78-
bool hasnan = false,
79-
bool descending = false);
80-
// key-value select
81-
template <typename T1, typename T2>
82-
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
83-
T2 *val,
84-
size_t k,
85-
size_t arrsize,
86-
bool hasnan = false,
87-
bool descending = false);
88-
// partial sort
89-
template <typename T>
90-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
91-
size_t k,
92-
size_t arrsize,
93-
bool hasnan = false,
94-
bool descending = false);
95-
// key-value partial sort
96-
template <typename T1, typename T2>
97-
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
98-
T2 *val,
99-
size_t k,
100-
size_t arrsize,
101-
bool hasnan = false,
102-
bool descending = false);
103-
// argsort
104-
template <typename T>
105-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
106-
size_t arrsize,
107-
bool hasnan = false,
108-
bool descending = false);
109-
// argselect
110-
template <typename T>
111-
XSS_HIDE_SYMBOL std::vector<size_t>
112-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
113-
} // namespace avx2
114-
namespace scalar {
115-
// quicksort
116-
template <typename T>
117-
XSS_HIDE_SYMBOL void
118-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
119-
// key-value quicksort
120-
template <typename T1, typename T2>
121-
XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key,
122-
T2 *val,
123-
size_t arrsize,
124-
bool hasnan = false,
125-
bool descending = false);
126-
// quickselect
127-
template <typename T>
128-
XSS_HIDE_SYMBOL void qselect(T *arr,
129-
size_t k,
130-
size_t arrsize,
131-
bool hasnan = false,
132-
bool descending = false);
133-
// key-value select
134-
template <typename T1, typename T2>
135-
XSS_HIDE_SYMBOL void keyvalue_select(T1 *key,
136-
T2 *val,
137-
size_t k,
138-
size_t arrsize,
139-
bool hasnan = false,
140-
bool descending = false);
141-
// partial sort
142-
template <typename T>
143-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
144-
size_t k,
145-
size_t arrsize,
146-
bool hasnan = false,
147-
bool descending = false);
148-
// key-value partial sort
149-
template <typename T1, typename T2>
150-
XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key,
151-
T2 *val,
152-
size_t k,
153-
size_t arrsize,
154-
bool hasnan = false,
155-
bool descending = false);
156-
// argsort
157-
template <typename T>
158-
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
159-
size_t arrsize,
160-
bool hasnan = false,
161-
bool descending = false);
162-
// argselect
163-
template <typename T>
164-
XSS_HIDE_SYMBOL std::vector<size_t>
165-
argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false);
166-
} // namespace scalar
167-
namespace fp16_spr {
168-
// quicksort
169-
template <typename T>
170-
XSS_HIDE_SYMBOL void
171-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
172-
// quickselect
173-
template <typename T>
174-
XSS_HIDE_SYMBOL void qselect(T *arr,
175-
size_t k,
176-
size_t arrsize,
177-
bool hasnan = false,
178-
bool descending = false);
179-
// partial sort
180-
template <typename T>
181-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
182-
size_t k,
183-
size_t arrsize,
184-
bool hasnan = false,
185-
bool descending = false);
186-
} // namespace fp16_spr
187-
namespace fp16_icl {
188-
// quicksort
189-
template <typename T>
190-
XSS_HIDE_SYMBOL void
191-
qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);
192-
// quickselect
193-
template <typename T>
194-
XSS_HIDE_SYMBOL void qselect(T *arr,
195-
size_t k,
196-
size_t arrsize,
197-
bool hasnan = false,
198-
bool descending = false);
199-
// partial sort
200-
template <typename T>
201-
XSS_HIDE_SYMBOL void partial_qsort(T *arr,
202-
size_t k,
203-
size_t arrsize,
204-
bool hasnan = false,
205-
bool descending = false);
206-
} // namespace fp16_icl
55+
DECLAREALLFUNCS(avx512)
56+
DECLAREALLFUNCS(avx2)
57+
DECLAREALLFUNCS(scalar)
58+
DECLAREALLFUNCS(fp16_spr)
59+
DECLAREALLFUNCS(fp16_icl)
20760
} // namespace xss
20861
#endif

lib/x86simdsort.cpp

+16-25
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,19 @@ namespace x86simdsort {
118118
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
119119
if constexpr (dispatch_requested("avx512", ISA)) { \
120120
if (preferred_cpu.find("avx512") != std::string_view::npos) { \
121-
CAT(CAT(internal_, func), TYPE) = &xss::avx512::func<TYPE>; \
121+
if constexpr (std::is_same_v<TYPE, _Float16>) { \
122+
if (preferred_cpu.find("avx512_spr") != std::string_view::npos) { \
123+
CAT(CAT(internal_, func), TYPE) = &xss::fp16_spr::func<TYPE>; \
124+
return; \
125+
} \
126+
if (preferred_cpu.find("avx512_icl") != std::string_view::npos) { \
127+
CAT(CAT(internal_, func), TYPE) = &xss::fp16_icl::func<TYPE>; \
128+
return; \
129+
} \
130+
} \
131+
else { \
132+
CAT(CAT(internal_, func), TYPE) = &xss::avx512::func<TYPE>; \
133+
}\
122134
return; \
123135
} \
124136
} \
@@ -130,37 +142,16 @@ namespace x86simdsort {
130142
} \
131143
}
132144

133-
#define DISPATCH_FP16(func, TYPE, ISA) \
134-
DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \
135-
CAT(CAT(resolve_, func), TYPE)(void) \
136-
{ \
137-
CAT(CAT(internal_, func), TYPE) = &xss::scalar::func<TYPE>; \
138-
__builtin_cpu_init(); \
139-
std::string_view preferred_cpu = find_preferred_cpu(ISA); \
140-
if constexpr (dispatch_requested("avx512_spr", ISA)) { \
141-
if (preferred_cpu.find("avx512_spr") != std::string_view::npos) { \
142-
CAT(CAT(internal_, func), TYPE) = &xss::fp16_spr::func<TYPE>; \
143-
return; \
144-
} \
145-
} \
146-
if constexpr (dispatch_requested("avx512_icl", ISA)) { \
147-
if (preferred_cpu.find("avx512_icl") != std::string_view::npos) { \
148-
CAT(CAT(internal_, func), TYPE) = &xss::fp16_icl::func<TYPE>; \
149-
return; \
150-
} \
151-
} \
152-
}
153-
154145
#define ISA_LIST(...) \
155146
std::initializer_list<std::string_view> \
156147
{ \
157148
__VA_ARGS__ \
158149
}
159150

160151
#ifdef __FLT16_MAX__
161-
DISPATCH_FP16(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
162-
DISPATCH_FP16(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
163-
DISPATCH_FP16(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
152+
DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
153+
DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
154+
DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr", "avx512_icl"))
164155
DISPATCH(argsort, _Float16, ISA_LIST("none"))
165156
DISPATCH(argselect, _Float16, ISA_LIST("none"))
166157
#endif

0 commit comments

Comments
 (0)