6
6
/// Tag describing what a pivot-selection routine learned about the range.
/// As used by get_pivot_near_constant in this file:
///  - Normal:      an ordinary pivot (the default tag);
///  - Sorted:      every element of the range was equal, so the range is
///                 already sorted;
///  - Only2Values: the range holds exactly two distinct values.
enum class pivot_result_t : int { Normal, Sorted, Only2Values };

/// A selected pivot value together with its pivot_result_t tag.
///
/// The constructor is intentionally left non-explicit so a bare pivot value
/// still converts to a pivot_results tagged Normal.
template <typename type_t>
struct pivot_results {

    pivot_result_t result = pivot_result_t::Normal;
    type_t pivot = 0;

    /// @param _pivot  the chosen pivot value
    /// @param _result tag for the pivot; defaults to pivot_result_t::Normal
    pivot_results(type_t _pivot,
                  pivot_result_t _result = pivot_result_t::Normal)
        // Initialise members directly, in declaration order, rather than
        // default-initialising them and assigning in the body.
        : result(_result), pivot(_pivot)
    {
    }
};
19
21
20
22
template <typename type_t >
21
- type_t next_value (type_t value){
23
+ type_t next_value (type_t value)
24
+ {
22
25
// TODO this probably handles non-native float16 wrong
23
- if constexpr (std::is_floating_point<type_t >::value){
26
+ if constexpr (std::is_floating_point<type_t >::value) {
24
27
return std::nextafter (value, std::numeric_limits<type_t >::infinity ());
25
- }else {
26
- if (value < std::numeric_limits< type_t >:: max ()) {
27
- return value + 1 ;
28
- } else {
28
+ }
29
+ else {
30
+ if (value < std::numeric_limits< type_t >:: max ()) { return value + 1 ; }
31
+ else {
29
32
return value;
30
33
}
31
34
}
@@ -96,23 +99,23 @@ X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
96
99
}
97
100
98
101
template <typename vtype, typename type_t >
99
- X86_SIMD_SORT_INLINE pivot_results<type_t > get_pivot_near_constant (type_t *arr,
100
- type_t commonValue,
101
- const arrsize_t left,
102
- const arrsize_t right);
102
+ X86_SIMD_SORT_INLINE pivot_results<type_t >
103
+ get_pivot_near_constant (type_t *arr,
104
+ type_t commonValue,
105
+ const arrsize_t left,
106
+ const arrsize_t right);
103
107
104
108
template <typename vtype, typename type_t >
105
- X86_SIMD_SORT_INLINE pivot_results<type_t > get_pivot_smart (type_t *arr,
106
- const arrsize_t left,
107
- const arrsize_t right)
109
+ X86_SIMD_SORT_INLINE pivot_results<type_t >
110
+ get_pivot_smart (type_t *arr, const arrsize_t left, const arrsize_t right)
108
111
{
109
112
using reg_t = typename vtype::reg_t ;
110
113
constexpr int numVecs = 4 ;
111
-
112
- if (right - left + 1 <= 4 * numVecs * vtype::numlanes){
113
- return pivot_results<type_t >(get_pivot<vtype>(arr, left, right));
114
+
115
+ if (right - left + 1 <= 4 * numVecs * vtype::numlanes) {
116
+ return pivot_results<type_t >(get_pivot<vtype>(arr, left, right));
114
117
}
115
-
118
+
116
119
constexpr int N = numVecs * vtype::numlanes;
117
120
118
121
arrsize_t width = (right - vtype::numlanes) - left;
@@ -122,111 +125,123 @@ X86_SIMD_SORT_INLINE pivot_results<type_t> get_pivot_smart(type_t *arr,
122
125
for (int i = 0 ; i < numVecs; i++) {
123
126
vecs[i] = vtype::loadu (arr + left + delta * i);
124
127
}
125
-
128
+
126
129
// Sort the samples
127
130
sort_vectors<vtype, numVecs>(vecs);
128
-
131
+
129
132
type_t samples[N];
130
- for (int i = 0 ; i < numVecs; i++){
133
+ for (int i = 0 ; i < numVecs; i++) {
131
134
vtype::storeu (samples + vtype::numlanes * i, vecs[i]);
132
135
}
133
-
136
+
134
137
type_t smallest = samples[0 ];
135
138
type_t largest = samples[N - 1 ];
136
139
type_t median = samples[N / 2 ];
137
-
138
- if (smallest == largest){
140
+
141
+ if (smallest == largest) {
139
142
// We have a very unlucky sample, or the array is constant / near constant
140
143
// Run a special function meant to deal with this situation
141
144
return get_pivot_near_constant<vtype, type_t >(arr, median, left, right);
142
- }else if (median != smallest && median != largest){
145
+ }
146
+ else if (median != smallest && median != largest) {
143
147
// We have a normal sample; use its median
144
148
return pivot_results<type_t >(median);
145
- }else if (median == smallest){
149
+ }
150
+ else if (median == smallest) {
146
151
// If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample
147
152
// Try just doing the next largest value greater than this seemingly very common value to separate them out
148
153
return pivot_results<type_t >(next_value<type_t >(median));
149
- }else if (median == largest){
154
+ }
155
+ else if (median == largest) {
150
156
// If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample
151
157
// Thus, median probably is a fine pivot, since it will move all of this common value into its own partition
152
158
return pivot_results<type_t >(median);
153
- }else {
159
+ }
160
+ else {
154
161
// Should be unreachable
155
162
return pivot_results<type_t >(median);
156
163
}
157
-
164
+
158
165
// Should be unreachable
159
166
return pivot_results<type_t >(median);
160
167
}
161
168
162
169
// Handles the case where we seem to have a near-constant array, since our sample of the array was constant
163
170
template <typename vtype, typename type_t >
164
- X86_SIMD_SORT_INLINE pivot_results<type_t > get_pivot_near_constant (type_t *arr,
165
- type_t commonValue,
166
- const arrsize_t left,
167
- const arrsize_t right)
171
+ X86_SIMD_SORT_INLINE pivot_results<type_t >
172
+ get_pivot_near_constant (type_t *arr,
173
+ type_t commonValue,
174
+ const arrsize_t left,
175
+ const arrsize_t right)
168
176
{
169
177
using reg_t = typename vtype::reg_t ;
170
-
178
+
171
179
arrsize_t index = left;
172
-
180
+
173
181
type_t value1 = 0 ;
174
182
type_t value2 = 0 ;
175
-
183
+
176
184
// First, search for any value not equal to the common value
177
185
// First vectorized
178
186
reg_t commonVec = vtype::set1 (commonValue);
179
- for (; index <= right - vtype::numlanes; index += vtype::numlanes){
187
+ for (; index <= right - vtype::numlanes; index += vtype::numlanes) {
180
188
reg_t data = vtype::loadu (arr + index );
181
- if (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec)))){
189
+ if (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec)))) {
182
190
break ;
183
191
}
184
192
}
185
-
193
+
186
194
// Than scalar at the end
187
- for (; index <= right; index ++){
188
- if (arr[index ] != commonValue){
195
+ for (; index <= right; index ++) {
196
+ if (arr[index ] != commonValue) {
189
197
value1 = arr[index ];
190
198
break ;
191
- }
199
+ }
192
200
}
193
-
194
- if (index == right + 1 ){
201
+
202
+ if (index == right + 1 ) {
195
203
// The array is completely constant
196
204
// Setting the second flag to true skips partitioning, as the array is constant and thus sorted
197
205
return pivot_results<type_t >(commonValue, pivot_result_t ::Sorted);
198
206
}
199
-
207
+
200
208
// Secondly, search for a second value not equal to either of the previous two
201
209
// First vectorized
202
210
reg_t value1Vec = vtype::set1 (value1);
203
- for (; index <= right - vtype::numlanes; index += vtype::numlanes){
211
+ for (; index <= right - vtype::numlanes; index += vtype::numlanes) {
204
212
reg_t data = vtype::loadu (arr + index );
205
- if (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec))) && !vtype::all_false (vtype::knot_opmask (vtype::eq (data, value1Vec)))){
213
+ if (!vtype::all_false (vtype::knot_opmask (vtype::eq (data, commonVec)))
214
+ && !vtype::all_false (
215
+ vtype::knot_opmask (vtype::eq (data, value1Vec)))) {
206
216
break ;
207
217
}
208
218
}
209
-
219
+
210
220
// Then scalar
211
- for (; index <= right; index ++){
212
- if (arr[index ] != commonValue && arr[index ] != value1){
221
+ for (; index <= right; index ++) {
222
+ if (arr[index ] != commonValue && arr[index ] != value1) {
213
223
value2 = arr[index ];
214
224
break ;
215
- }
225
+ }
216
226
}
217
-
218
- if (index == right + 1 ){
227
+
228
+ if (index == right + 1 ) {
219
229
// The array contains only 2 values
220
230
// We must pick the larger one, else the right partition is empty
221
231
// We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the larger value
222
232
// TODO this logic now assumes we use greater than or equal to specifically when partitioning, might be worth noting that somewhere
223
233
type_t pivot = std::max (value1, commonValue, comparison_func<vtype>);
224
234
return pivot_results<type_t >(pivot, pivot_result_t ::Only2Values);
225
235
}
226
-
236
+
227
237
// The array has at least 3 distinct values. Use the middle one as the pivot
228
- type_t median = std::max (std::min (value1,value2, comparison_func<vtype>), std::min (std::max (value1,value2, comparison_func<vtype>),commonValue, comparison_func<vtype>), comparison_func<vtype>);
238
+ type_t median = std::max (
239
+ std::min (value1, value2, comparison_func<vtype>),
240
+ std::min (std::max (value1, value2, comparison_func<vtype>),
241
+ commonValue,
242
+ comparison_func<vtype>),
243
+ comparison_func<vtype>);
229
244
return pivot_results<type_t >(median);
230
245
}
231
246
232
- #endif
247
+ #endif
0 commit comments