Skip to content

Commit b6cac57

Browse files
authored
Switch WASM TopK Algorithm to use Floyd-Rivest (#5244)
1 parent 2d16dc9 commit b6cac57

File tree

1 file changed

+98
-23
lines changed
  • tfjs-backend-wasm/src/cc/kernels

1 file changed

+98
-23
lines changed

tfjs-backend-wasm/src/cc/kernels/TopK.cc

Lines changed: 98 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include <emscripten.h>
1717
#endif
1818

19-
#include <math.h>
2019
#include <algorithm>
2120
#include <cmath>
2221
#include <cstddef>
@@ -25,39 +24,115 @@
2524
#include "tfjs-backend-wasm/src/cc/backend.h"
2625
#include "tfjs-backend-wasm/src/cc/util.h"
2726

27+
using std::swap;
28+
using std::vector;
29+
2830
namespace tfjs {
2931

3032
// A value paired with its position inside the row it was read from.
//
// Kept as a plain aggregate (value first, index second) so it can be
// brace-initialized by callers.
template <typename T>
struct ValAndInd {
  T value;
  size_t index;

  // Orders entries for a descending top-k: an entry with a larger value comes
  // first, and entries with equal values are ordered by ascending index.
  // NOTE(review): for floating-point T a NaN value makes both the equality
  // and the greater-than test false, so such an entry never orders before
  // another one -- confirm whether callers can pass NaN.
  bool operator<(const ValAndInd& other) const {
    return value == other.value ? index < other.index : value > other.value;
  }

  // Two entries are equal only when both the value and the index match.
  bool operator==(const ValAndInd& other) const {
    return value == other.value && index == other.index;
  }
};
3543

44+
// Returns the sign of `value` expressed in the same type T:
// 0 when value equals zero, -1 when it is negative, and 1 otherwise.
template <typename T>
T sign(T value) {
  if (value == 0) return 0;
  return value < 0 ? -1 : 1;
}
49+
50+
template <typename T>
51+
void select(ValAndInd<T>* array, int k, int left, int right) {
52+
while (right > left) {
53+
// Use select recursively to sample a smaller set of size s
54+
// the arbitrary constants 600 and 0.5 are used in the original
55+
// version to minimize execution time.
56+
if (right - left > 600) {
57+
const int n = right - left + 1;
58+
const int i = k - left + 1;
59+
const auto z = log(n);
60+
const auto s = 0.5 * exp(2 * z / 3);
61+
const auto sd = 0.5 * sqrt(z * s * (n - s) / n) * sign(i - n / 2);
62+
const int newLeft = std::max(left, static_cast<int>(k - i * s / n + sd));
63+
const int newRight =
64+
std::min(right, static_cast<int>(k + (n - i) * s / n + sd));
65+
select(array, k, newLeft, newRight);
66+
}
67+
// partition the elements between left and right around t
68+
auto t = array[k];
69+
int i = left;
70+
int j = right;
71+
72+
swap(array[left], array[k]);
73+
74+
if (t < array[right]) {
75+
swap(array[left], array[right]);
76+
}
77+
while (i < j) {
78+
swap(array[i], array[j]);
79+
i++;
80+
j--;
81+
while (array[i] < t) {
82+
i = i + 1;
83+
}
84+
while (t < array[j]) {
85+
j = j - 1;
86+
}
87+
}
88+
if (array[left] == t) {
89+
swap(array[left], array[j]);
90+
} else {
91+
j = j + 1;
92+
swap(array[j], array[right]);
93+
}
94+
// Adjust left and right towards the boundaries of the subset
95+
// containing the (k - left + 1)th smallest element.
96+
if (j <= k) {
97+
left = j + 1;
98+
}
99+
if (k <= j) {
100+
right = j - 1;
101+
}
102+
}
103+
}
104+
36105
// Based on tfjs-core/src/backends/topk_impl.ts
37106
template <typename T>
38-
void topk(const T* x_data, const size_t x_len,
39-
const std::vector<size_t>& x_shape, const int k, const bool sorted,
40-
T* out_values_data, int32_t* out_indices_data) {
41-
int last_dim = x_shape.back();
42-
int batch = x_len / last_dim;
43-
int size = last_dim;
44-
45-
for (int b = 0; b < batch; b++) {
46-
int offset = b * size;
47-
std::vector<ValAndInd<T>> val_and_ind;
48-
for (int i = offset; i < offset + size; i++) {
107+
void topk(const T* x_data, const size_t x_len, const vector<size_t>& x_shape,
108+
const int k, const bool sorted, T* out_values_data,
109+
int32_t* out_indices_data) {
110+
size_t last_dim = x_shape.back();
111+
size_t batch = x_len / last_dim;
112+
size_t size = last_dim;
113+
114+
for (size_t b = 0; b < batch; b++) {
115+
size_t offset = b * size;
116+
vector<ValAndInd<T>> val_and_ind;
117+
val_and_ind.reserve(size);
118+
for (size_t i = offset; i < offset + size; i++) {
49119
val_and_ind.push_back({.value = x_data[i], .index = i - offset});
50120
}
51-
std::sort(val_and_ind.begin(), val_and_ind.end(),
52-
[](const ValAndInd<T>& a, const ValAndInd<T>& b) -> bool {
53-
return a.value == b.value ? a.index < b.index
54-
: a.value > b.value;
55-
});
56-
int out_offset = b * k;
57-
for (int i = 0; i < k; i++) {
58-
int index = out_offset + i;
121+
122+
if (k < size) {
123+
select(val_and_ind.data(), k, 0, size - 1);
124+
val_and_ind.resize(k);
125+
}
126+
127+
if (sorted) {
128+
std::sort(val_and_ind.begin(), val_and_ind.end());
129+
}
130+
131+
size_t out_offset = b * k;
132+
for (size_t i = 0; i < k; i++) {
133+
size_t index = out_offset + i;
59134
out_values_data[index] = val_and_ind[i].value;
60-
out_indices_data[index] = val_and_ind[i].index;
135+
out_indices_data[index] = static_cast<int32_t>(val_and_ind[i].index);
61136
}
62137
}
63138
}
@@ -74,7 +149,7 @@ void TopK(const size_t x_id, const size_t* x_shape_ptr,
74149
const size_t x_shape_length, const DType x_dtype, const int k,
75150
const bool sorted, const size_t out_values_id,
76151
const size_t out_indices_id) {
77-
auto x_shape = std::vector<size_t>(x_shape_ptr, x_shape_ptr + x_shape_length);
152+
auto x_shape = vector<size_t>(x_shape_ptr, x_shape_ptr + x_shape_length);
78153
auto& x_info = backend::get_tensor_info(x_id);
79154
auto& out_values_info = backend::get_tensor_info_out(out_values_id);
80155
auto& out_indices_info = backend::get_tensor_info_out(out_indices_id);

0 commit comments

Comments
 (0)