3
3
#ifdef STAN_OPENCL
4
4
5
5
#include < stan/math/prim/err.hpp>
6
- #include < stan/math/opencl/kernel_generator/wrapper.hpp>
7
6
#include < stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
8
7
#include < stan/math/opencl/kernel_generator/name_generator.hpp>
9
8
#include < stan/math/opencl/kernel_generator/as_operation_cl.hpp>
@@ -46,12 +45,11 @@ struct multi_result_kernel_internal {
46
45
* @param expressions expressions
47
46
*/
48
47
static void get_clear_events (
49
- std::vector<cl::Event>& events,
50
- const std::tuple<wrapper<T_results>...>& results,
51
- const std::tuple<wrapper<T_expressions>...>& expressions) {
48
+ std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
49
+ const std::tuple<T_expressions...>& expressions) {
52
50
next::get_clear_events (events, results, expressions);
53
- std::get<N>(expressions).x . get_clear_write_events (events);
54
- std::get<N>(results).x . get_clear_read_write_events (events);
51
+ std::get<N>(expressions).get_clear_write_events (events);
52
+ std::get<N>(results).get_clear_read_write_events (events);
55
53
}
56
54
/* *
57
55
* Assigns the dimensions of expressions to matching results if possible.
@@ -65,12 +63,11 @@ struct multi_result_kernel_internal {
65
63
* @param expressions expressions
66
64
*/
67
65
static void check_assign_dimensions (
68
- int n_rows, int n_cols,
69
- const std::tuple<wrapper<T_results>...>& results,
70
- const std::tuple<wrapper<T_expressions>...>& expressions) {
66
+ int n_rows, int n_cols, const std::tuple<T_results...>& results,
67
+ const std::tuple<T_expressions...>& expressions) {
71
68
next::check_assign_dimensions (n_rows, n_cols, results, expressions);
72
- const auto & expression = std::get<N>(expressions). x ;
73
- const auto & result = std::get<N>(results). x ;
69
+ const auto & expression = std::get<N>(expressions);
70
+ const auto & result = std::get<N>(results);
74
71
const char * function = " results.operator=" ;
75
72
if (!is_without_output<T_current_expression>::value) {
76
73
check_size_match (function, " Rows of " , " expression" ,
@@ -102,17 +99,17 @@ struct multi_result_kernel_internal {
102
99
static kernel_parts generate (
103
100
std::set<const operation_cl_base*>& generated, name_generator& ng,
104
101
const std::string& row_index_name, const std::string& col_index_name,
105
- const std::tuple<wrapper< T_results> ...>& results,
106
- const std::tuple<wrapper< T_expressions> ...>& expressions) {
102
+ const std::tuple<T_results...>& results,
103
+ const std::tuple<T_expressions...>& expressions) {
107
104
kernel_parts parts = next::generate (generated, ng, row_index_name,
108
105
col_index_name, results, expressions);
109
106
if (is_without_output<T_current_expression>::value) {
110
107
return parts;
111
108
}
112
- kernel_parts parts0 = std::get<N>(expressions)
113
- . x . get_whole_kernel_parts (
114
- generated, ng, row_index_name,
115
- col_index_name, std::get<N>(results). x );
109
+ kernel_parts parts0
110
+ = std::get<N>(expressions)
111
+ . get_whole_kernel_parts ( generated, ng, row_index_name,
112
+ col_index_name, std::get<N>(results));
116
113
parts += parts0;
117
114
return parts;
118
115
}
@@ -125,18 +122,18 @@ struct multi_result_kernel_internal {
125
122
* @param results results
126
123
* @param expressions expressions
127
124
*/
128
- static void set_args (
129
- std::set< const operation_cl_base*>& generated, cl::Kernel& kernel,
130
- int & arg_num, const std::tuple<wrapper< T_results> ...>& results,
131
- const std::tuple<wrapper< T_expressions> ...>& expressions) {
125
+ static void set_args (std::set< const operation_cl_base*>& generated,
126
+ cl::Kernel& kernel, int & arg_num ,
127
+ const std::tuple<T_results...>& results,
128
+ const std::tuple<T_expressions...>& expressions) {
132
129
next::set_args (generated, kernel, arg_num, results, expressions);
133
130
134
131
if (is_without_output<T_current_expression>::value) {
135
132
return ;
136
133
}
137
134
138
- std::get<N>(expressions).x . set_args (generated, kernel, arg_num);
139
- std::get<N>(results).x . set_args (generated, kernel, arg_num);
135
+ std::get<N>(expressions).set_args (generated, kernel, arg_num);
136
+ std::get<N>(results).set_args (generated, kernel, arg_num);
140
137
}
141
138
142
139
/* *
@@ -145,13 +142,12 @@ struct multi_result_kernel_internal {
145
142
* @param results results
146
143
* @param expressions expressions
147
144
*/
148
- static void add_event (
149
- cl::Event e, const std::tuple<wrapper<T_results>...>& results,
150
- const std::tuple<wrapper<T_expressions>...>& expressions) {
145
+ static void add_event (cl::Event e, const std::tuple<T_results...>& results,
146
+ const std::tuple<T_expressions...>& expressions) {
151
147
next::add_event (e, results, expressions);
152
148
153
- std::get<N>(expressions).x . add_read_event (e);
154
- std::get<N>(results).x . add_write_event (e);
149
+ std::get<N>(expressions).add_read_event (e);
150
+ std::get<N>(results).add_write_event (e);
155
151
}
156
152
};
157
153
};
@@ -162,35 +158,32 @@ struct multi_result_kernel_internal<-1, T_results...> {
162
158
template <typename ... T_expressions>
163
159
struct inner {
164
160
static void get_clear_events (
165
- std::vector<cl::Event>& events,
166
- const std::tuple<wrapper<T_results>...>& results,
167
- const std::tuple<wrapper<T_expressions>...>& expressions) {}
161
+ std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
162
+ const std::tuple<T_expressions...>& expressions) {}
168
163
169
164
static void check_assign_dimensions (
170
- int n_rows, int n_cols,
171
- const std::tuple<wrapper<T_results>...>& results,
172
- const std::tuple<wrapper<T_expressions>...>& expressions) {
165
+ int n_rows, int n_cols, const std::tuple<T_results...>& results,
166
+ const std::tuple<T_expressions...>& expressions) {
173
167
return ;
174
168
}
175
169
176
170
static kernel_parts generate (
177
171
std::set<const operation_cl_base*>& generated, name_generator& ng,
178
172
const std::string& row_index_name, const std::string& col_index_name,
179
- const std::tuple<wrapper< T_results> ...>& results,
180
- const std::tuple<wrapper< T_expressions> ...>& expressions) {
173
+ const std::tuple<T_results...>& results,
174
+ const std::tuple<T_expressions...>& expressions) {
181
175
return {};
182
176
}
183
177
184
- static void set_args (
185
- std::set< const operation_cl_base*>& generated, cl::Kernel& kernel,
186
- int & arg_num, const std::tuple<wrapper< T_results> ...>& results,
187
- const std::tuple<wrapper< T_expressions> ...>& expressions) {
178
+ static void set_args (std::set< const operation_cl_base*>& generated,
179
+ cl::Kernel& kernel, int & arg_num ,
180
+ const std::tuple<T_results...>& results,
181
+ const std::tuple<T_expressions...>& expressions) {
188
182
return ;
189
183
}
190
184
191
- static void add_event (
192
- cl::Event e, const std::tuple<wrapper<T_results>...>& results,
193
- const std::tuple<wrapper<T_expressions>...>& expressions) {
185
+ static void add_event (cl::Event e, const std::tuple<T_results...>& results,
186
+ const std::tuple<T_expressions...>& expressions) {
194
187
return ;
195
188
}
196
189
};
@@ -215,11 +208,11 @@ class expressions_cl {
215
208
* @param expressions expressions that will be calculated in same kernel.
216
209
*/
217
210
explicit expressions_cl (T_expressions&&... expressions)
218
- : expressions_(internal::wrapper<T_expressions>(
219
- std::forward<T_expressions>(expressions))...) {}
211
+ : expressions_(
212
+ T_expressions ( std::forward<T_expressions>(expressions))...) {}
220
213
221
214
private:
222
- std::tuple<internal::wrapper< T_expressions> ...> expressions_;
215
+ std::tuple<T_expressions...> expressions_;
223
216
template <typename ... T_results>
224
217
friend class results_cl ;
225
218
};
@@ -241,17 +234,15 @@ expressions_cl<T_expressions...> expressions(T_expressions&&... expressions) {
241
234
*/
242
235
template <typename ... T_results>
243
236
class results_cl {
244
- std::tuple<internal::wrapper< T_results> ...> results_;
237
+ std::tuple<T_results...> results_;
245
238
246
239
public:
247
240
/* *
248
241
* Constructor.
249
242
* @param results results that will be calculated in same kernel
250
243
*/
251
244
explicit results_cl (T_results&&... results)
252
- : results_(
253
- internal::wrapper<T_results>(std::forward<T_results>(results))...) {
254
- }
245
+ : results_(std::forward<T_results>(results)...) {}
255
246
256
247
/* *
257
248
* Assigning \c expressions_cl object to \c results_ object generates and
@@ -295,14 +286,10 @@ class results_cl {
295
286
std::string get_kernel_source_for_evaluating_impl (
296
287
const expressions_cl<T_expressions...>& exprs,
297
288
std::index_sequence<Is...>) {
298
- auto expressions = std::make_tuple (
299
- internal::make_wrapper (std::forward<decltype (as_operation_cl (
300
- std::get<Is>(exprs.expressions_ ).x ))>(
301
- as_operation_cl (std::get<Is>(exprs.expressions_ ).x )))...);
302
- auto results = std::make_tuple (internal::make_wrapper (
303
- std::forward<decltype (as_operation_cl (std::get<Is>(results_).x ))>(
304
- as_operation_cl (std::get<Is>(results_).x )))...);
305
- return get_kernel_source_impl (results, expressions);
289
+ return get_kernel_source_impl (
290
+ std::forward_as_tuple (as_operation_cl (std::get<Is>(results_))...),
291
+ std::forward_as_tuple (
292
+ as_operation_cl (std::get<Is>(exprs.expressions_ ))...));
306
293
}
307
294
308
295
/* *
@@ -314,8 +301,8 @@ class results_cl {
314
301
*/
315
302
template <typename ... T_res, typename ... T_expressions>
316
303
static std::string get_kernel_source_impl (
317
- const std::tuple<internal::wrapper< T_res> ...>& results,
318
- const std::tuple<internal::wrapper< T_expressions> ...>& expressions) {
304
+ const std::tuple<T_res...>& results,
305
+ const std::tuple<T_expressions...>& expressions) {
319
306
using impl = typename internal::multi_result_kernel_internal<
320
307
std::tuple_size<std::tuple<T_expressions...>>::value - 1 ,
321
308
T_res...>::template inner<T_expressions...>;
@@ -378,14 +365,10 @@ class results_cl {
378
365
template <typename ... T_expressions, size_t ... Is>
379
366
void assignment (const expressions_cl<T_expressions...>& exprs,
380
367
std::index_sequence<Is...>) {
381
- auto expressions = std::make_tuple (
382
- internal::make_wrapper (std::forward<decltype (as_operation_cl (
383
- std::get<Is>(exprs.expressions_ ).x ))>(
384
- as_operation_cl (std::get<Is>(exprs.expressions_ ).x )))...);
385
- auto results = std::make_tuple (internal::make_wrapper (
386
- std::forward<decltype (as_operation_cl (std::get<Is>(results_).x ))>(
387
- as_operation_cl (std::get<Is>(results_).x )))...);
388
- assignment_impl (results, expressions);
368
+ assignment_impl (
369
+ std::forward_as_tuple (as_operation_cl (std::get<Is>(results_))...),
370
+ std::forward_as_tuple (
371
+ as_operation_cl (std::get<Is>(exprs.expressions_ ))...));
389
372
}
390
373
391
374
/* *
@@ -396,9 +379,8 @@ class results_cl {
396
379
* @param expressions expressions
397
380
*/
398
381
template <typename ... T_res, typename ... T_expressions>
399
- static void assignment_impl (
400
- const std::tuple<internal::wrapper<T_res>...>& results,
401
- const std::tuple<internal::wrapper<T_expressions>...>& expressions) {
382
+ static void assignment_impl (const std::tuple<T_res...>& results,
383
+ const std::tuple<T_expressions...>& expressions) {
402
384
using T_First_Expr = typename std::remove_reference_t <
403
385
std::tuple_element_t <0 , std::tuple<T_expressions...>>>;
404
386
using impl = typename internal::multi_result_kernel_internal<
@@ -414,8 +396,8 @@ class results_cl {
414
396
static const bool require_specific_local_size = std::max (
415
397
{std::decay_t <T_expressions>::Deriv::require_specific_local_size...});
416
398
417
- int n_rows = std::get<0 >(expressions).x . thread_rows ();
418
- int n_cols = std::get<0 >(expressions).x . thread_cols ();
399
+ int n_rows = std::get<0 >(expressions).thread_rows ();
400
+ int n_cols = std::get<0 >(expressions).thread_cols ();
419
401
const char * function = " results_cl.assignment" ;
420
402
impl::check_assign_dimensions (n_rows, n_cols, results, expressions);
421
403
if (n_rows * n_cols == 0 ) {
0 commit comments