Skip to content

Commit 2b913b8

Browse files
committed
Rebase with main and fix formatting
1 parent af30caa commit 2b913b8

File tree

4 files changed

+89
-78
lines changed

4 files changed

+89
-78
lines changed

src/xss-common-qsort.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size)
8787
else {
8888
in = vtype::loadu(arr + ii);
8989
}
90-
auto nanmask = vtype::convert_mask_to_int(vtype::template fpclass<0x01 | 0x80>(in));
90+
auto nanmask = vtype::convert_mask_to_int(
91+
vtype::template fpclass<0x01 | 0x80>(in));
9192
if (nanmask != 0x00) {
9293
found_nan = true;
9394
break;
@@ -498,24 +499,20 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
498499
arr + left, (int32_t)(right + 1 - left));
499500
return;
500501
}
501-
502+
502503
auto pivot_result = get_pivot_smart<vtype, type_t>(arr, left, right);
503504
type_t pivot = pivot_result.pivot;
504-
505-
if (pivot_result.result == pivot_result_t::Sorted){
506-
return;
507-
}
508-
505+
506+
if (pivot_result.result == pivot_result_t::Sorted) { return; }
507+
509508
type_t smallest = vtype::type_max();
510509
type_t biggest = vtype::type_min();
511510

512511
arrsize_t pivot_index
513512
= partition_avx512_unrolled<vtype, vtype::partition_unroll_factor>(
514513
arr, left, right + 1, pivot, &smallest, &biggest);
515-
516-
if (pivot_result.result == pivot_result_t::Only2Values){
517-
return;
518-
}
514+
515+
if (pivot_result.result == pivot_result_t::Only2Values) { return; }
519516

520517
if (pivot != smallest)
521518
qsort_<vtype>(arr, left, pivot_index - 1, max_iters - 1);

src/xss-network-keyvaluesort.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,8 @@ bitonic_fullmerge_n_vec(typename keyType::reg_t *keys,
441441
}
442442

443443
template <typename keyType, typename indexType, int numVecs>
444-
X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
445-
arrsize_t *indices,
446-
int N)
444+
X86_SIMD_SORT_INLINE void
445+
argsort_n_vec(typename keyType::type_t *keys, arrsize_t *indices, int N)
447446
{
448447
using kreg_t = typename keyType::reg_t;
449448
using ireg_t = typename indexType::reg_t;
@@ -586,9 +585,8 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,
586585
}
587586

588587
template <typename keyType, typename indexType, int maxN>
589-
X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys,
590-
arrsize_t *indices,
591-
int N)
588+
X86_SIMD_SORT_INLINE void
589+
argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N)
592590
{
593591
static_assert(keyType::numlanes == indexType::numlanes,
594592
"invalid pairing of value/index types");

src/xss-network-qsort.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ X86_SIMD_SORT_FINLINE void merge_n_vec(reg_t *regs)
144144
}
145145

146146
template <typename vtype, int numVecs, typename reg_t = typename vtype::reg_t>
147-
X86_SIMD_SORT_FINLINE void sort_vectors(reg_t * vecs){
147+
X86_SIMD_SORT_FINLINE void sort_vectors(reg_t *vecs)
148+
{
148149
/* Run the initial sorting network to sort the columns of the [numVecs x
149150
* num_lanes] matrix
150151
*/
@@ -188,7 +189,7 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N)
188189
vecs[i] = vtype::mask_loadu(
189190
vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes);
190191
}
191-
192+
192193
sort_vectors<vtype, numVecs>(vecs);
193194

194195
// Unmasked part of the store

src/xss-pivot-selection.hpp

Lines changed: 74 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,29 @@
66
enum class pivot_result_t : int { Normal, Sorted, Only2Values };
77

88
template <typename type_t>
9-
struct pivot_results{
10-
9+
struct pivot_results {
10+
1111
pivot_result_t result = pivot_result_t::Normal;
1212
type_t pivot = 0;
13-
14-
pivot_results(type_t _pivot, pivot_result_t _result = pivot_result_t::Normal){
13+
14+
pivot_results(type_t _pivot,
15+
pivot_result_t _result = pivot_result_t::Normal)
16+
{
1517
pivot = _pivot;
1618
result = _result;
1719
}
1820
};
1921

2022
template <typename type_t>
21-
type_t next_value(type_t value){
23+
type_t next_value(type_t value)
24+
{
2225
// TODO this probably handles non-native float16 wrong
23-
if constexpr (std::is_floating_point<type_t>::value){
26+
if constexpr (std::is_floating_point<type_t>::value) {
2427
return std::nextafter(value, std::numeric_limits<type_t>::infinity());
25-
}else{
26-
if (value < std::numeric_limits<type_t>::max()){
27-
return value + 1;
28-
}else{
28+
}
29+
else {
30+
if (value < std::numeric_limits<type_t>::max()) { return value + 1; }
31+
else {
2932
return value;
3033
}
3134
}
@@ -96,23 +99,23 @@ X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
9699
}
97100

98101
template <typename vtype, typename type_t>
99-
X86_SIMD_SORT_INLINE pivot_results<type_t> get_pivot_near_constant(type_t *arr,
100-
type_t commonValue,
101-
const arrsize_t left,
102-
const arrsize_t right);
102+
X86_SIMD_SORT_INLINE pivot_results<type_t>
103+
get_pivot_near_constant(type_t *arr,
104+
type_t commonValue,
105+
const arrsize_t left,
106+
const arrsize_t right);
103107

104108
template <typename vtype, typename type_t>
105-
X86_SIMD_SORT_INLINE pivot_results<type_t> get_pivot_smart(type_t *arr,
106-
const arrsize_t left,
107-
const arrsize_t right)
109+
X86_SIMD_SORT_INLINE pivot_results<type_t>
110+
get_pivot_smart(type_t *arr, const arrsize_t left, const arrsize_t right)
108111
{
109112
using reg_t = typename vtype::reg_t;
110113
constexpr int numVecs = 4;
111-
112-
if (right - left + 1 <= 4 * numVecs * vtype::numlanes){
113-
return pivot_results<type_t>(get_pivot<vtype>(arr, left, right));
114+
115+
if (right - left + 1 <= 4 * numVecs * vtype::numlanes) {
116+
return pivot_results<type_t>(get_pivot<vtype>(arr, left, right));
114117
}
115-
118+
116119
constexpr int N = numVecs * vtype::numlanes;
117120

118121
arrsize_t width = (right - vtype::numlanes) - left;
@@ -122,111 +125,123 @@ X86_SIMD_SORT_INLINE pivot_results<type_t> get_pivot_smart(type_t *arr,
122125
for (int i = 0; i < numVecs; i++) {
123126
vecs[i] = vtype::loadu(arr + left + delta * i);
124127
}
125-
128+
126129
// Sort the samples
127130
sort_vectors<vtype, numVecs>(vecs);
128-
131+
129132
type_t samples[N];
130-
for (int i = 0; i < numVecs; i++){
133+
for (int i = 0; i < numVecs; i++) {
131134
vtype::storeu(samples + vtype::numlanes * i, vecs[i]);
132135
}
133-
136+
134137
type_t smallest = samples[0];
135138
type_t largest = samples[N - 1];
136139
type_t median = samples[N / 2];
137-
138-
if (smallest == largest){
140+
141+
if (smallest == largest) {
139142
// We have a very unlucky sample, or the array is constant / near constant
140143
// Run a special function meant to deal with this situation
141144
return get_pivot_near_constant<vtype, type_t>(arr, median, left, right);
142-
}else if (median != smallest && median != largest){
145+
}
146+
else if (median != smallest && median != largest) {
143147
// We have a normal sample; use its median
144148
return pivot_results<type_t>(median);
145-
}else if (median == smallest){
149+
}
150+
else if (median == smallest) {
146151
// If median == smallest, that implies approximately half the array is equal to smallest, unless we were very unlucky with our sample
147152
// Try just using the next value greater than this seemingly very common value to separate them out
148153
return pivot_results<type_t>(next_value<type_t>(median));
149-
}else if (median == largest){
154+
}
155+
else if (median == largest) {
150156
// If median == largest, that implies approximately half the array is equal to largest, unless we were very unlucky with our sample
151157
// Thus, median probably is a fine pivot, since it will move all of this common value into its own partition
152158
return pivot_results<type_t>(median);
153-
}else{
159+
}
160+
else {
154161
// Should be unreachable
155162
return pivot_results<type_t>(median);
156163
}
157-
164+
158165
// Should be unreachable
159166
return pivot_results<type_t>(median);
160167
}
161168

162169
// Handles the case where we seem to have a near-constant array, since our sample of the array was constant
163170
template <typename vtype, typename type_t>
164-
X86_SIMD_SORT_INLINE pivot_results<type_t> get_pivot_near_constant(type_t *arr,
165-
type_t commonValue,
166-
const arrsize_t left,
167-
const arrsize_t right)
171+
X86_SIMD_SORT_INLINE pivot_results<type_t>
172+
get_pivot_near_constant(type_t *arr,
173+
type_t commonValue,
174+
const arrsize_t left,
175+
const arrsize_t right)
168176
{
169177
using reg_t = typename vtype::reg_t;
170-
178+
171179
arrsize_t index = left;
172-
180+
173181
type_t value1 = 0;
174182
type_t value2 = 0;
175-
183+
176184
// First, search for any value not equal to the common value
177185
// First vectorized
178186
reg_t commonVec = vtype::set1(commonValue);
179-
for (; index <= right - vtype::numlanes; index += vtype::numlanes){
187+
for (; index <= right - vtype::numlanes; index += vtype::numlanes) {
180188
reg_t data = vtype::loadu(arr + index);
181-
if (!vtype::all_false(vtype::knot_opmask(vtype::eq(data, commonVec)))){
189+
if (!vtype::all_false(vtype::knot_opmask(vtype::eq(data, commonVec)))) {
182190
break;
183191
}
184192
}
185-
193+
186194
// Then scalar at the end
187-
for (; index <= right; index++){
188-
if (arr[index] != commonValue){
195+
for (; index <= right; index++) {
196+
if (arr[index] != commonValue) {
189197
value1 = arr[index];
190198
break;
191-
}
199+
}
192200
}
193-
194-
if (index == right + 1){
201+
202+
if (index == right + 1) {
195203
// The array is completely constant
196204
// Returning pivot_result_t::Sorted makes the caller skip partitioning, as the array is constant and thus sorted
197205
return pivot_results<type_t>(commonValue, pivot_result_t::Sorted);
198206
}
199-
207+
200208
// Secondly, search for a second value not equal to either of the previous two
201209
// First vectorized
202210
reg_t value1Vec = vtype::set1(value1);
203-
for (; index <= right - vtype::numlanes; index += vtype::numlanes){
211+
for (; index <= right - vtype::numlanes; index += vtype::numlanes) {
204212
reg_t data = vtype::loadu(arr + index);
205-
if (!vtype::all_false(vtype::knot_opmask(vtype::eq(data, commonVec))) && !vtype::all_false(vtype::knot_opmask(vtype::eq(data, value1Vec)))){
213+
if (!vtype::all_false(vtype::knot_opmask(vtype::eq(data, commonVec)))
214+
&& !vtype::all_false(
215+
vtype::knot_opmask(vtype::eq(data, value1Vec)))) {
206216
break;
207217
}
208218
}
209-
219+
210220
// Then scalar
211-
for (; index <= right; index++){
212-
if (arr[index] != commonValue && arr[index] != value1){
221+
for (; index <= right; index++) {
222+
if (arr[index] != commonValue && arr[index] != value1) {
213223
value2 = arr[index];
214224
break;
215-
}
225+
}
216226
}
217-
218-
if (index == right + 1){
227+
228+
if (index == right + 1) {
219229
// The array contains only 2 values
220230
// We must pick the larger one, else the right partition is empty
221231
// We can also skip recursing, as it is guaranteed both partitions are constant after partitioning with the larger value
222232
// TODO this logic now assumes we use greater than or equal to specifically when partitioning, might be worth noting that somewhere
223233
type_t pivot = std::max(value1, commonValue, comparison_func<vtype>);
224234
return pivot_results<type_t>(pivot, pivot_result_t::Only2Values);
225235
}
226-
236+
227237
// The array has at least 3 distinct values. Use the middle one as the pivot
228-
type_t median = std::max(std::min(value1,value2, comparison_func<vtype>), std::min(std::max(value1,value2, comparison_func<vtype>),commonValue, comparison_func<vtype>), comparison_func<vtype>);
238+
type_t median = std::max(
239+
std::min(value1, value2, comparison_func<vtype>),
240+
std::min(std::max(value1, value2, comparison_func<vtype>),
241+
commonValue,
242+
comparison_func<vtype>),
243+
comparison_func<vtype>);
229244
return pivot_results<type_t>(median);
230245
}
231246

232-
#endif
247+
#endif

0 commit comments

Comments
 (0)