16
16
#include < emscripten.h>
17
17
#endif
18
18
19
- #include < math.h>
20
19
#include < algorithm>
21
20
#include < cmath>
22
21
#include < cstddef>
25
24
#include " tfjs-backend-wasm/src/cc/backend.h"
26
25
#include " tfjs-backend-wasm/src/cc/util.h"
27
26
27
+ using std::swap;
28
+ using std::vector;
29
+
28
30
namespace tfjs {
29
31
30
32
// A value paired with the index it was read from.
//
// The comparison operators define the ranking used for top-k: `a < b` means
// "a ranks ahead of b". Larger values rank first, and equal values are ordered
// by ascending index.
//
// NaN handling: a comparison written only with `==` and `>` reports a NaN as
// unordered against every other element. That is not a strict weak ordering,
// so std::sort is undefined for floating-point input containing NaN. To keep
// the ordering total, NaN ranks ahead of every number and NaNs are ordered
// among themselves by ascending index.
// NOTE(review): ranking NaN first is a convention chosen here; confirm it is
// the behaviour callers want.
template <typename T>
struct ValAndInd {
  T value;
  size_t index;

  bool operator<(const ValAndInd& other) const {
    // `v != v` holds only for NaN; it is always false for integral T.
    const bool this_is_nan = value != value;
    const bool other_is_nan = other.value != other.value;
    if (this_is_nan || other_is_nan) {
      if (this_is_nan != other_is_nan) return this_is_nan;
      return index < other.index;
    }
    if (value == other.value) return index < other.index;
    return value > other.value;
  }

  // Two entries are equal exactly when neither ranks ahead of the other, which
  // keeps `==` consistent with `<` (including for NaN values).
  bool operator==(const ValAndInd& other) const {
    return !(*this < other) && !(other < *this);
  }
};
35
43
44
// Returns the sign of `value`, expressed in the same type T: -1 when it is
// negative, 0 when it compares equal to zero, and 1 otherwise.
template <typename T>
T sign(T value) {
  if (value == 0) return static_cast<T>(0);
  return value < 0 ? static_cast<T>(-1) : static_cast<T>(1);
}
49
+
50
+ template <typename T>
51
+ void select (ValAndInd<T>* array, int k, int left, int right) {
52
+ while (right > left) {
53
+ // Use select recursively to sample a smaller set of size s
54
+ // the arbitrary constants 600 and 0.5 are used in the original
55
+ // version to minimize execution time.
56
+ if (right - left > 600 ) {
57
+ const int n = right - left + 1 ;
58
+ const int i = k - left + 1 ;
59
+ const auto z = log (n);
60
+ const auto s = 0.5 * exp (2 * z / 3 );
61
+ const auto sd = 0.5 * sqrt (z * s * (n - s) / n) * sign (i - n / 2 );
62
+ const int newLeft = std::max (left, static_cast <int >(k - i * s / n + sd));
63
+ const int newRight =
64
+ std::min (right, static_cast <int >(k + (n - i) * s / n + sd));
65
+ select (array, k, newLeft, newRight);
66
+ }
67
+ // partition the elements between left and right around t
68
+ auto t = array[k];
69
+ int i = left;
70
+ int j = right;
71
+
72
+ swap (array[left], array[k]);
73
+
74
+ if (t < array[right]) {
75
+ swap (array[left], array[right]);
76
+ }
77
+ while (i < j) {
78
+ swap (array[i], array[j]);
79
+ i++;
80
+ j--;
81
+ while (array[i] < t) {
82
+ i = i + 1 ;
83
+ }
84
+ while (t < array[j]) {
85
+ j = j - 1 ;
86
+ }
87
+ }
88
+ if (array[left] == t) {
89
+ swap (array[left], array[j]);
90
+ } else {
91
+ j = j + 1 ;
92
+ swap (array[j], array[right]);
93
+ }
94
+ // Adjust left and right towards the boundaries of the subset
95
+ // containing the (k - left + 1)th smallest element.
96
+ if (j <= k) {
97
+ left = j + 1 ;
98
+ }
99
+ if (k <= j) {
100
+ right = j - 1 ;
101
+ }
102
+ }
103
+ }
104
+
36
105
// Based on tfjs-core/src/backends/topk_impl.ts
37
106
template <typename T>
38
- void topk (const T* x_data, const size_t x_len,
39
- const std::vector<size_t >& x_shape, const int k, const bool sorted,
40
- T* out_values_data, int32_t * out_indices_data) {
41
- int last_dim = x_shape.back ();
42
- int batch = x_len / last_dim;
43
- int size = last_dim;
44
-
45
- for (int b = 0 ; b < batch; b++) {
46
- int offset = b * size;
47
- std::vector<ValAndInd<T>> val_and_ind;
48
- for (int i = offset; i < offset + size; i++) {
107
+ void topk (const T* x_data, const size_t x_len, const vector<size_t >& x_shape,
108
+ const int k, const bool sorted, T* out_values_data,
109
+ int32_t * out_indices_data) {
110
+ size_t last_dim = x_shape.back ();
111
+ size_t batch = x_len / last_dim;
112
+ size_t size = last_dim;
113
+
114
+ for (size_t b = 0 ; b < batch; b++) {
115
+ size_t offset = b * size;
116
+ vector<ValAndInd<T>> val_and_ind;
117
+ val_and_ind.reserve (size);
118
+ for (size_t i = offset; i < offset + size; i++) {
49
119
val_and_ind.push_back ({.value = x_data[i], .index = i - offset});
50
120
}
51
- std::sort (val_and_ind.begin (), val_and_ind.end (),
52
- [](const ValAndInd<T>& a, const ValAndInd<T>& b) -> bool {
53
- return a.value == b.value ? a.index < b.index
54
- : a.value > b.value ;
55
- });
56
- int out_offset = b * k;
57
- for (int i = 0 ; i < k; i++) {
58
- int index = out_offset + i;
121
+
122
+ if (k < size) {
123
+ select (val_and_ind.data (), k, 0 , size - 1 );
124
+ val_and_ind.resize (k);
125
+ }
126
+
127
+ if (sorted) {
128
+ std::sort (val_and_ind.begin (), val_and_ind.end ());
129
+ }
130
+
131
+ size_t out_offset = b * k;
132
+ for (size_t i = 0 ; i < k; i++) {
133
+ size_t index = out_offset + i;
59
134
out_values_data[index] = val_and_ind[i].value ;
60
- out_indices_data[index] = val_and_ind[i].index ;
135
+ out_indices_data[index] = static_cast < int32_t >( val_and_ind[i].index ) ;
61
136
}
62
137
}
63
138
}
@@ -74,7 +149,7 @@ void TopK(const size_t x_id, const size_t* x_shape_ptr,
74
149
const size_t x_shape_length, const DType x_dtype, const int k,
75
150
const bool sorted, const size_t out_values_id,
76
151
const size_t out_indices_id) {
77
- auto x_shape = std:: vector<size_t >(x_shape_ptr, x_shape_ptr + x_shape_length);
152
+ auto x_shape = vector<size_t >(x_shape_ptr, x_shape_ptr + x_shape_length);
78
153
auto & x_info = backend::get_tensor_info (x_id);
79
154
auto & out_values_info = backend::get_tensor_info_out (out_values_id);
80
155
auto & out_indices_info = backend::get_tensor_info_out (out_indices_id);
0 commit comments