7
7
#include < stan/math/opencl/kernel_generator/name_generator.hpp>
8
8
#include < stan/math/opencl/kernel_generator/as_operation_cl.hpp>
9
9
#include < stan/math/opencl/kernel_generator/calc_if.hpp>
10
+ #include < stan/math/opencl/kernel_generator/check_cl.hpp>
10
11
#include < stan/math/opencl/kernel_generator/load.hpp>
11
12
#include < stan/math/opencl/opencl_context.hpp>
12
13
#include < algorithm>
@@ -45,11 +46,12 @@ struct multi_result_kernel_internal {
45
46
* @param expressions expressions
46
47
*/
47
48
static void get_clear_events (
48
- std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
49
- const std::tuple<T_expressions...>& expressions) {
50
- next::get_clear_events (events, results, expressions);
51
- std::get<N>(expressions).get_clear_write_events (events);
52
- std::get<N>(results).get_clear_read_write_events (events);
49
+ std::vector<cl::Event>& events,
50
+ const std::tuple<std::pair<T_results, T_expressions>...>&
51
+ assignment_pairs) {
52
+ next::get_clear_events (events, assignment_pairs);
53
+ std::get<N>(assignment_pairs).second .get_clear_write_events (events);
54
+ std::get<N>(assignment_pairs).first .get_clear_read_write_events (events);
53
55
}
54
56
/* *
55
57
* Assigns the dimensions of expressions to matching results if possible.
@@ -63,11 +65,12 @@ struct multi_result_kernel_internal {
63
65
* @param expressions expressions
64
66
*/
65
67
static void check_assign_dimensions (
66
- int n_rows, int n_cols, const std::tuple<T_results...>& results,
67
- const std::tuple<T_expressions...>& expressions) {
68
- next::check_assign_dimensions (n_rows, n_cols, results, expressions);
69
- const auto & expression = std::get<N>(expressions);
70
- const auto & result = std::get<N>(results);
68
+ int n_rows, int n_cols,
69
+ const std::tuple<std::pair<T_results, T_expressions>...>&
70
+ assignment_pairs) {
71
+ next::check_assign_dimensions (n_rows, n_cols, assignment_pairs);
72
+ const auto & expression = std::get<N>(assignment_pairs).second ;
73
+ const auto & result = std::get<N>(assignment_pairs).first ;
71
74
const char * function = " results.operator=" ;
72
75
if (!is_without_output<T_current_expression>::value) {
73
76
check_size_match (function, " Rows of " , " expression" ,
@@ -99,17 +102,18 @@ struct multi_result_kernel_internal {
99
102
static kernel_parts generate (
100
103
std::set<const operation_cl_base*>& generated, name_generator& ng,
101
104
const std::string& row_index_name, const std::string& col_index_name,
102
- const std::tuple<T_results...>& results,
103
- const std::tuple<T_expressions...>& expressions ) {
105
+ const std::tuple<std::pair< T_results, T_expressions> ...>&
106
+ assignment_pairs ) {
104
107
kernel_parts parts = next::generate (generated, ng, row_index_name,
105
- col_index_name, results, expressions );
108
+ col_index_name, assignment_pairs );
106
109
if (is_without_output<T_current_expression>::value) {
107
110
return parts;
108
111
}
109
112
kernel_parts parts0
110
- = std::get<N>(expressions)
111
- .get_whole_kernel_parts (generated, ng, row_index_name,
112
- col_index_name, std::get<N>(results));
113
+ = std::get<N>(assignment_pairs)
114
+ .second .get_whole_kernel_parts (
115
+ generated, ng, row_index_name, col_index_name,
116
+ std::get<N>(assignment_pairs).first );
113
117
parts += parts0;
114
118
return parts;
115
119
}
@@ -122,18 +126,19 @@ struct multi_result_kernel_internal {
122
126
* @param assignment_pairs pairs of result and expression; in each pair
123
127
* `first` is the result and `second` is the expression assigned to it
124
128
*/
125
- static void set_args (std::set<const operation_cl_base*>& generated,
126
- cl::Kernel& kernel, int & arg_num,
127
- const std::tuple<T_results...>& results,
128
- const std::tuple<T_expressions...>& expressions) {
129
- next::set_args (generated, kernel, arg_num, results, expressions);
129
+ static void set_args (
130
+ std::set<const operation_cl_base*>& generated, cl::Kernel& kernel,
131
+ int & arg_num,
132
+ const std::tuple<std::pair<T_results, T_expressions>...>&
133
+ assignment_pairs) {
134
+ next::set_args (generated, kernel, arg_num, assignment_pairs);
130
135
131
136
if (is_without_output<T_current_expression>::value) {
132
137
return ;
133
138
}
134
139
135
- std::get<N>(expressions) .set_args (generated, kernel, arg_num);
136
- std::get<N>(results) .set_args (generated, kernel, arg_num);
140
+ std::get<N>(assignment_pairs). second .set_args (generated, kernel, arg_num);
141
+ std::get<N>(assignment_pairs). first .set_args (generated, kernel, arg_num);
137
142
}
138
143
139
144
/* *
@@ -142,12 +147,13 @@ struct multi_result_kernel_internal {
142
147
* @param assignment_pairs pairs of result and expression; in each pair
143
148
* `first` is the result (written) and `second` is the expression (read)
144
149
*/
145
- static void add_event (cl::Event e, const std::tuple<T_results...>& results,
146
- const std::tuple<T_expressions...>& expressions) {
147
- next::add_event (e, results, expressions);
150
+ static void add_event (
151
+ cl::Event e, const std::tuple<std::pair<T_results, T_expressions>...>&
152
+ assignment_pairs) {
153
+ next::add_event (e, assignment_pairs);
148
154
149
- std::get<N>(expressions) .add_read_event (e);
150
- std::get<N>(results) .add_write_event (e);
155
+ std::get<N>(assignment_pairs). second .add_read_event (e);
156
+ std::get<N>(assignment_pairs). first .add_write_event (e);
151
157
}
152
158
};
153
159
};
@@ -158,32 +164,36 @@ struct multi_result_kernel_internal<-1, T_results...> {
158
164
template <typename ... T_expressions>
159
165
struct inner {
160
166
static void get_clear_events (
161
- std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
162
- const std::tuple<T_expressions...>& expressions) {}
167
+ std::vector<cl::Event>& events,
168
+ const std::tuple<std::pair<T_results, T_expressions>...>&
169
+ assignment_pairs) {}
163
170
164
171
static void check_assign_dimensions (
165
- int n_rows, int n_cols, const std::tuple<T_results...>& results,
166
- const std::tuple<T_expressions...>& expressions) {
172
+ int n_rows, int n_cols,
173
+ const std::tuple<std::pair<T_results, T_expressions>...>&
174
+ assignment_pairs) {
167
175
return ;
168
176
}
169
177
170
178
static kernel_parts generate (
171
179
std::set<const operation_cl_base*>& generated, name_generator& ng,
172
180
const std::string& row_index_name, const std::string& col_index_name,
173
- const std::tuple<T_results...>& results,
174
- const std::tuple<T_expressions...>& expressions ) {
181
+ const std::tuple<std::pair< T_results, T_expressions> ...>&
182
+ assignment_pairs ) {
175
183
return {};
176
184
}
177
185
178
- static void set_args (std::set<const operation_cl_base*>& generated,
179
- cl::Kernel& kernel, int & arg_num,
180
- const std::tuple<T_results...>& results,
181
- const std::tuple<T_expressions...>& expressions) {
186
+ static void set_args (
187
+ std::set<const operation_cl_base*>& generated, cl::Kernel& kernel,
188
+ int & arg_num,
189
+ const std::tuple<std::pair<T_results, T_expressions>...>&
190
+ assignment_pairs) {
182
191
return ;
183
192
}
184
193
185
- static void add_event (cl::Event e, const std::tuple<T_results...>& results,
186
- const std::tuple<T_expressions...>& expressions) {
194
+ static void add_event (
195
+ cl::Event e, const std::tuple<std::pair<T_results, T_expressions>...>&
196
+ assignment_pairs) {
187
197
return ;
188
198
}
189
199
};
@@ -286,10 +296,9 @@ class results_cl {
286
296
std::string get_kernel_source_for_evaluating_impl (
287
297
const expressions_cl<T_expressions...>& exprs,
288
298
std::index_sequence<Is...>) {
289
- return get_kernel_source_impl (
290
- std::forward_as_tuple (as_operation_cl (std::get<Is>(results_))...),
291
- std::forward_as_tuple (
292
- as_operation_cl (std::get<Is>(exprs.expressions_ ))...));
299
+ return get_kernel_source_impl (std::tuple_cat (make_assignment_pair (
300
+ as_operation_cl (std::get<Is>(results_)),
301
+ as_operation_cl (std::get<Is>(exprs.expressions_ )))...));
293
302
}
294
303
295
304
/* *
@@ -301,8 +310,7 @@ class results_cl {
301
310
*/
302
311
template <typename ... T_res, typename ... T_expressions>
303
312
static std::string get_kernel_source_impl (
304
- const std::tuple<T_res...>& results,
305
- const std::tuple<T_expressions...>& expressions) {
313
+ const std::tuple<std::pair<T_res, T_expressions>...>& assignment_pairs) {
306
314
using impl = typename internal::multi_result_kernel_internal<
307
315
std::tuple_size<std::tuple<T_expressions...>>::value - 1 ,
308
316
T_res...>::template inner<T_expressions...>;
@@ -312,7 +320,7 @@ class results_cl {
312
320
name_generator ng;
313
321
std::set<const operation_cl_base*> generated;
314
322
kernel_parts parts
315
- = impl::generate (generated, ng, " i" , " j" , results, expressions );
323
+ = impl::generate (generated, ng, " i" , " j" , assignment_pairs );
316
324
std::string src;
317
325
if (require_specific_local_size) {
318
326
src =
@@ -364,11 +372,10 @@ class results_cl {
364
372
*/
365
373
template <typename ... T_expressions, size_t ... Is>
366
374
void assignment (const expressions_cl<T_expressions...>& exprs,
367
- std::index_sequence<Is...>) {
368
- assignment_impl (
369
- std::forward_as_tuple (as_operation_cl (std::get<Is>(results_))...),
370
- std::forward_as_tuple (
371
- as_operation_cl (std::get<Is>(exprs.expressions_ ))...));
375
+ std::index_sequence<Is...>) {;
376
+ assignment_impl (std::tuple_cat (make_assignment_pair (
377
+ as_operation_cl (std::get<Is>(results_)),
378
+ as_operation_cl (std::get<Is>(exprs.expressions_ )))...));
372
379
}
373
380
374
381
/* *
@@ -379,8 +386,8 @@ class results_cl {
379
386
* @param expressions expressions
380
387
*/
381
388
template <typename ... T_res, typename ... T_expressions>
382
- static void assignment_impl (const std::tuple<T_res...>& results,
383
- const std::tuple<T_expressions...>& expressions ) {
389
+ static void assignment_impl (
390
+ const std::tuple<std::pair<T_res, T_expressions> ...>& assignment_pairs ) {
384
391
using T_First_Expr = typename std::remove_reference_t <
385
392
std::tuple_element_t <0 , std::tuple<T_expressions...>>>;
386
393
using impl = typename internal::multi_result_kernel_internal<
@@ -396,19 +403,27 @@ class results_cl {
396
403
static const bool require_specific_local_size = std::max (
397
404
{std::decay_t <T_expressions>::Deriv::require_specific_local_size...});
398
405
399
- int n_rows = std::get<0 >(expressions) .thread_rows ();
400
- int n_cols = std::get<0 >(expressions) .thread_cols ();
406
+ int n_rows = std::get<0 >(assignment_pairs). second .thread_rows ();
407
+ int n_cols = std::get<0 >(assignment_pairs). second .thread_cols ();
401
408
const char * function = " results_cl.assignment" ;
402
- impl::check_assign_dimensions (n_rows, n_cols, results, expressions );
409
+ impl::check_assign_dimensions (n_rows, n_cols, assignment_pairs );
403
410
if (n_rows * n_cols == 0 ) {
404
411
return ;
405
412
}
406
- check_nonnegative (function, " expr.rows()" , n_rows);
407
- check_nonnegative (function, " expr.cols()" , n_cols);
413
+ if (n_rows < 0 ) {
414
+ invalid_argument (function, " Number of rows of expression" , n_rows,
415
+ " must be nonnegative, but is " ,
416
+ " (broadcasted expressions can not be evaluated)" );
417
+ }
418
+ if (n_cols < 0 ) {
419
+ invalid_argument (function, " Number of columns of expression" , n_cols,
420
+ " must be nonnegative, but is " ,
421
+ " (broadcasted expressions can not be evaluated)" );
422
+ }
408
423
409
424
try {
410
425
if (impl::kernel_ () == NULL ) {
411
- std::string src = get_kernel_source_impl (results, expressions );
426
+ std::string src = get_kernel_source_impl (assignment_pairs );
412
427
auto opts = opencl_context.base_opts ();
413
428
impl::kernel_ = opencl_kernels::compile_kernel (
414
429
" calculate" , {view_kernel_helpers, src}, opts);
@@ -417,10 +432,10 @@ class results_cl {
417
432
int arg_num = 0 ;
418
433
419
434
std::set<const operation_cl_base*> generated;
420
- impl::set_args (generated, kernel, arg_num, results, expressions );
435
+ impl::set_args (generated, kernel, arg_num, assignment_pairs );
421
436
422
437
std::vector<cl::Event> events;
423
- impl::get_clear_events (events, results, expressions );
438
+ impl::get_clear_events (events, assignment_pairs );
424
439
cl::Event e;
425
440
if (require_specific_local_size) {
426
441
kernel.setArg (arg_num++, n_rows);
@@ -437,16 +452,50 @@ class results_cl {
437
452
cl::NDRange (n_rows, n_cols),
438
453
cl::NullRange, &events, &e);
439
454
}
440
- impl::add_event (e, results, expressions );
455
+ impl::add_event (e, assignment_pairs );
441
456
} catch (const cl::Error& e) {
442
457
check_opencl_error (function, e);
443
458
}
444
459
}
460
+
445
461
/* *
446
462
* Implementation of assignments of no expressions to no results
447
463
*/
448
- static void assignment_impl (const std::tuple<>& /* results*/ ,
449
- const std::tuple<>& /* expressions*/ ) {}
464
+ static void assignment_impl (const std::tuple<>& /* assignment_pairs*/ ) {}
465
+
466
+ /* *
467
+ * Makes a std::pair of one result and one expression and wraps it into a
468
+ * tuple.
469
+ * @param result result
470
+ * @param expression expression
471
+ * @return a tuple of pair of result and expression
472
+ */
473
+ template <typename T_result, typename T_expression>
474
+ static auto make_assignment_pair (T_result&& result,
475
+ T_expression&& expression) {
476
+ return std::make_tuple (std::pair<T_result&&, T_expression&&>(
477
+ std::forward<T_result>(result),
478
+ std::forward<T_expression>(expression)));
479
+ }
480
+
481
+ /* *
482
+ * Checks on scalars are done separately in this overload instead of in
483
+ * the kernel.
484
+ * @param result result - check
485
+ * @param expression expression - bool scalar
486
+ * @return an empty tuple
487
+ */
488
+ template <typename Scal>
489
+ static std::tuple<> make_assignment_pair (
490
+ check_cl_<scalar_<Scal>>& result, scalar_<char > expression) {
491
+ if (!expression.a_ ) {
492
+ std::stringstream s;
493
+ s << result.function_ << " : " << result.err_variable_ << " = "
494
+ << result.arg_ .a_ << " , but it must be " << result.must_be_ << " !" ;
495
+ throw std::domain_error (s.str ());
496
+ }
497
+ return std::make_tuple ();
498
+ }
450
499
};
451
500
452
501
/* *
0 commit comments