Skip to content

Commit d24fe7e

Browse files
committed
make check_cl support scalars
1 parent 8123af7 commit d24fe7e

File tree

8 files changed

+161
-83
lines changed

8 files changed

+161
-83
lines changed

stan/math/opencl/kernel_generator/check_cl.hpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/opencl/value_type.hpp>
88
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
99
#include <stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
10+
#include <stan/math/opencl/kernel_generator/scalar.hpp>
1011

1112
namespace stan {
1213
namespace math {
@@ -34,12 +35,13 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
3435
// buffer[1,2] are problematic indices
3536
matrix_cl<int> buffer_;
3637
matrix_cl<value_type_t<T>> value_;
38+
39+
public:
3740
T arg_;
3841
const char* function_;
3942
const char* err_variable_;
4043
const char* must_be_;
4144

42-
public:
4345
/**
4446
* Constructor.
4547
* @param function function name (for error messages)
@@ -157,6 +159,13 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
157159
* @return number of columns
158160
*/
159161
inline int cols() const { return arg_.cols(); }
162+
163+
/**
164+
* Assignment of a scalar bool triggers the scalar check.
165+
* @param condition whether the state is ok.
166+
* @throws std::domain_error if condition is false (check failed).
167+
*/
168+
void operator=(bool condition) { *this = as_operation_cl(condition); }
160169
};
161170

162171
/**
@@ -169,8 +178,7 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
169178
* @param y variable to check (for error messages)
170179
* @param must_be description of what the value must be (for error messages)
171180
*/
172-
template <typename T,
173-
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
181+
template <typename T, typename = require_all_kernel_expressions_t<T>>
174182
inline auto check_cl(const char* function, const char* var_name, T&& y,
175183
const char* must_be) {
176184
return check_cl_<as_operation_cl_t<T>>(

stan/math/opencl/kernel_generator/multi_result_kernel.hpp

Lines changed: 114 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
88
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
99
#include <stan/math/opencl/kernel_generator/calc_if.hpp>
10+
#include <stan/math/opencl/kernel_generator/check_cl.hpp>
1011
#include <stan/math/opencl/kernel_generator/load.hpp>
1112
#include <stan/math/opencl/opencl_context.hpp>
1213
#include <algorithm>
@@ -45,11 +46,12 @@ struct multi_result_kernel_internal {
4546
* @param expressions expressions
4647
*/
4748
static void get_clear_events(
48-
std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
49-
const std::tuple<T_expressions...>& expressions) {
50-
next::get_clear_events(events, results, expressions);
51-
std::get<N>(expressions).get_clear_write_events(events);
52-
std::get<N>(results).get_clear_read_write_events(events);
49+
std::vector<cl::Event>& events,
50+
const std::tuple<std::pair<T_results, T_expressions>...>&
51+
assignment_pairs) {
52+
next::get_clear_events(events, assignment_pairs);
53+
std::get<N>(assignment_pairs).second.get_clear_write_events(events);
54+
std::get<N>(assignment_pairs).first.get_clear_read_write_events(events);
5355
}
5456
/**
5557
* Assigns the dimensions of expressions to matching results if possible.
@@ -63,11 +65,12 @@ struct multi_result_kernel_internal {
6365
* @param expressions expressions
6466
*/
6567
static void check_assign_dimensions(
66-
int n_rows, int n_cols, const std::tuple<T_results...>& results,
67-
const std::tuple<T_expressions...>& expressions) {
68-
next::check_assign_dimensions(n_rows, n_cols, results, expressions);
69-
const auto& expression = std::get<N>(expressions);
70-
const auto& result = std::get<N>(results);
68+
int n_rows, int n_cols,
69+
const std::tuple<std::pair<T_results, T_expressions>...>&
70+
assignment_pairs) {
71+
next::check_assign_dimensions(n_rows, n_cols, assignment_pairs);
72+
const auto& expression = std::get<N>(assignment_pairs).second;
73+
const auto& result = std::get<N>(assignment_pairs).first;
7174
const char* function = "results.operator=";
7275
if (!is_without_output<T_current_expression>::value) {
7376
check_size_match(function, "Rows of ", "expression",
@@ -99,17 +102,18 @@ struct multi_result_kernel_internal {
99102
static kernel_parts generate(
100103
std::set<const operation_cl_base*>& generated, name_generator& ng,
101104
const std::string& row_index_name, const std::string& col_index_name,
102-
const std::tuple<T_results...>& results,
103-
const std::tuple<T_expressions...>& expressions) {
105+
const std::tuple<std::pair<T_results, T_expressions>...>&
106+
assignment_pairs) {
104107
kernel_parts parts = next::generate(generated, ng, row_index_name,
105-
col_index_name, results, expressions);
108+
col_index_name, assignment_pairs);
106109
if (is_without_output<T_current_expression>::value) {
107110
return parts;
108111
}
109112
kernel_parts parts0
110-
= std::get<N>(expressions)
111-
.get_whole_kernel_parts(generated, ng, row_index_name,
112-
col_index_name, std::get<N>(results));
113+
= std::get<N>(assignment_pairs)
114+
.second.get_whole_kernel_parts(
115+
generated, ng, row_index_name, col_index_name,
116+
std::get<N>(assignment_pairs).first);
113117
parts += parts0;
114118
return parts;
115119
}
@@ -122,18 +126,19 @@ struct multi_result_kernel_internal {
122126
* @param results results
123127
* @param expressions expressions
124128
*/
125-
static void set_args(std::set<const operation_cl_base*>& generated,
126-
cl::Kernel& kernel, int& arg_num,
127-
const std::tuple<T_results...>& results,
128-
const std::tuple<T_expressions...>& expressions) {
129-
next::set_args(generated, kernel, arg_num, results, expressions);
129+
static void set_args(
130+
std::set<const operation_cl_base*>& generated, cl::Kernel& kernel,
131+
int& arg_num,
132+
const std::tuple<std::pair<T_results, T_expressions>...>&
133+
assignment_pairs) {
134+
next::set_args(generated, kernel, arg_num, assignment_pairs);
130135

131136
if (is_without_output<T_current_expression>::value) {
132137
return;
133138
}
134139

135-
std::get<N>(expressions).set_args(generated, kernel, arg_num);
136-
std::get<N>(results).set_args(generated, kernel, arg_num);
140+
std::get<N>(assignment_pairs).second.set_args(generated, kernel, arg_num);
141+
std::get<N>(assignment_pairs).first.set_args(generated, kernel, arg_num);
137142
}
138143

139144
/**
@@ -142,12 +147,13 @@ struct multi_result_kernel_internal {
142147
* @param results results
143148
* @param expressions expressions
144149
*/
145-
static void add_event(cl::Event e, const std::tuple<T_results...>& results,
146-
const std::tuple<T_expressions...>& expressions) {
147-
next::add_event(e, results, expressions);
150+
static void add_event(
151+
cl::Event e, const std::tuple<std::pair<T_results, T_expressions>...>&
152+
assignment_pairs) {
153+
next::add_event(e, assignment_pairs);
148154

149-
std::get<N>(expressions).add_read_event(e);
150-
std::get<N>(results).add_write_event(e);
155+
std::get<N>(assignment_pairs).second.add_read_event(e);
156+
std::get<N>(assignment_pairs).first.add_write_event(e);
151157
}
152158
};
153159
};
@@ -158,32 +164,36 @@ struct multi_result_kernel_internal<-1, T_results...> {
158164
template <typename... T_expressions>
159165
struct inner {
160166
static void get_clear_events(
161-
std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
162-
const std::tuple<T_expressions...>& expressions) {}
167+
std::vector<cl::Event>& events,
168+
const std::tuple<std::pair<T_results, T_expressions>...>&
169+
assignment_pairs) {}
163170

164171
static void check_assign_dimensions(
165-
int n_rows, int n_cols, const std::tuple<T_results...>& results,
166-
const std::tuple<T_expressions...>& expressions) {
172+
int n_rows, int n_cols,
173+
const std::tuple<std::pair<T_results, T_expressions>...>&
174+
assignment_pairs) {
167175
return;
168176
}
169177

170178
static kernel_parts generate(
171179
std::set<const operation_cl_base*>& generated, name_generator& ng,
172180
const std::string& row_index_name, const std::string& col_index_name,
173-
const std::tuple<T_results...>& results,
174-
const std::tuple<T_expressions...>& expressions) {
181+
const std::tuple<std::pair<T_results, T_expressions>...>&
182+
assignment_pairs) {
175183
return {};
176184
}
177185

178-
static void set_args(std::set<const operation_cl_base*>& generated,
179-
cl::Kernel& kernel, int& arg_num,
180-
const std::tuple<T_results...>& results,
181-
const std::tuple<T_expressions...>& expressions) {
186+
static void set_args(
187+
std::set<const operation_cl_base*>& generated, cl::Kernel& kernel,
188+
int& arg_num,
189+
const std::tuple<std::pair<T_results, T_expressions>...>&
190+
assignment_pairs) {
182191
return;
183192
}
184193

185-
static void add_event(cl::Event e, const std::tuple<T_results...>& results,
186-
const std::tuple<T_expressions...>& expressions) {
194+
static void add_event(
195+
cl::Event e, const std::tuple<std::pair<T_results, T_expressions>...>&
196+
assignment_pairs) {
187197
return;
188198
}
189199
};
@@ -286,10 +296,9 @@ class results_cl {
286296
std::string get_kernel_source_for_evaluating_impl(
287297
const expressions_cl<T_expressions...>& exprs,
288298
std::index_sequence<Is...>) {
289-
return get_kernel_source_impl(
290-
std::forward_as_tuple(as_operation_cl(std::get<Is>(results_))...),
291-
std::forward_as_tuple(
292-
as_operation_cl(std::get<Is>(exprs.expressions_))...));
299+
return get_kernel_source_impl(std::tuple_cat(make_assignment_pair(
300+
as_operation_cl(std::get<Is>(results_)),
301+
as_operation_cl(std::get<Is>(exprs.expressions_)))...));
293302
}
294303

295304
/**
@@ -301,8 +310,7 @@ class results_cl {
301310
*/
302311
template <typename... T_res, typename... T_expressions>
303312
static std::string get_kernel_source_impl(
304-
const std::tuple<T_res...>& results,
305-
const std::tuple<T_expressions...>& expressions) {
313+
const std::tuple<std::pair<T_res, T_expressions>...>& assignment_pairs) {
306314
using impl = typename internal::multi_result_kernel_internal<
307315
std::tuple_size<std::tuple<T_expressions...>>::value - 1,
308316
T_res...>::template inner<T_expressions...>;
@@ -312,7 +320,7 @@ class results_cl {
312320
name_generator ng;
313321
std::set<const operation_cl_base*> generated;
314322
kernel_parts parts
315-
= impl::generate(generated, ng, "i", "j", results, expressions);
323+
= impl::generate(generated, ng, "i", "j", assignment_pairs);
316324
std::string src;
317325
if (require_specific_local_size) {
318326
src =
@@ -364,11 +372,10 @@ class results_cl {
364372
*/
365373
template <typename... T_expressions, size_t... Is>
366374
void assignment(const expressions_cl<T_expressions...>& exprs,
367-
std::index_sequence<Is...>) {
368-
assignment_impl(
369-
std::forward_as_tuple(as_operation_cl(std::get<Is>(results_))...),
370-
std::forward_as_tuple(
371-
as_operation_cl(std::get<Is>(exprs.expressions_))...));
375+
std::index_sequence<Is...>) {;
376+
assignment_impl(std::tuple_cat(make_assignment_pair(
377+
as_operation_cl(std::get<Is>(results_)),
378+
as_operation_cl(std::get<Is>(exprs.expressions_)))...));
372379
}
373380

374381
/**
@@ -379,8 +386,8 @@ class results_cl {
379386
* @param expressions expressions
380387
*/
381388
template <typename... T_res, typename... T_expressions>
382-
static void assignment_impl(const std::tuple<T_res...>& results,
383-
const std::tuple<T_expressions...>& expressions) {
389+
static void assignment_impl(
390+
const std::tuple<std::pair<T_res, T_expressions>...>& assignment_pairs) {
384391
using T_First_Expr = typename std::remove_reference_t<
385392
std::tuple_element_t<0, std::tuple<T_expressions...>>>;
386393
using impl = typename internal::multi_result_kernel_internal<
@@ -396,19 +403,27 @@ class results_cl {
396403
static const bool require_specific_local_size = std::max(
397404
{std::decay_t<T_expressions>::Deriv::require_specific_local_size...});
398405

399-
int n_rows = std::get<0>(expressions).thread_rows();
400-
int n_cols = std::get<0>(expressions).thread_cols();
406+
int n_rows = std::get<0>(assignment_pairs).second.thread_rows();
407+
int n_cols = std::get<0>(assignment_pairs).second.thread_cols();
401408
const char* function = "results_cl.assignment";
402-
impl::check_assign_dimensions(n_rows, n_cols, results, expressions);
409+
impl::check_assign_dimensions(n_rows, n_cols, assignment_pairs);
403410
if (n_rows * n_cols == 0) {
404411
return;
405412
}
406-
check_nonnegative(function, "expr.rows()", n_rows);
407-
check_nonnegative(function, "expr.cols()", n_cols);
413+
if (n_rows < 0) {
414+
invalid_argument(function, "Number of rows of expression", n_rows,
415+
" must be nonnegative, but is ",
416+
" (broadcasted expressions can not be evaluated)");
417+
}
418+
if (n_cols < 0) {
419+
invalid_argument(function, "Number of columns of expression", n_cols,
420+
" must be nonnegative, but is ",
421+
" (broadcasted expressions can not be evaluated)");
422+
}
408423

409424
try {
410425
if (impl::kernel_() == NULL) {
411-
std::string src = get_kernel_source_impl(results, expressions);
426+
std::string src = get_kernel_source_impl(assignment_pairs);
412427
auto opts = opencl_context.base_opts();
413428
impl::kernel_ = opencl_kernels::compile_kernel(
414429
"calculate", {view_kernel_helpers, src}, opts);
@@ -417,10 +432,10 @@ class results_cl {
417432
int arg_num = 0;
418433

419434
std::set<const operation_cl_base*> generated;
420-
impl::set_args(generated, kernel, arg_num, results, expressions);
435+
impl::set_args(generated, kernel, arg_num, assignment_pairs);
421436

422437
std::vector<cl::Event> events;
423-
impl::get_clear_events(events, results, expressions);
438+
impl::get_clear_events(events, assignment_pairs);
424439
cl::Event e;
425440
if (require_specific_local_size) {
426441
kernel.setArg(arg_num++, n_rows);
@@ -437,16 +452,50 @@ class results_cl {
437452
cl::NDRange(n_rows, n_cols),
438453
cl::NullRange, &events, &e);
439454
}
440-
impl::add_event(e, results, expressions);
455+
impl::add_event(e, assignment_pairs);
441456
} catch (const cl::Error& e) {
442457
check_opencl_error(function, e);
443458
}
444459
}
460+
445461
/**
446462
* Implementation of assignments of no expressions to no results
447463
*/
448-
static void assignment_impl(const std::tuple<>& /*results*/,
449-
const std::tuple<>& /*expressions*/) {}
464+
static void assignment_impl(const std::tuple<>& /*assignment_pairs*/) {}
465+
466+
/**
467+
* Makes a std::pair of one result and one expression and wraps it into a
468+
* tuple.
469+
* @param result result
470+
* @param expression expression
471+
* @return a tuple of pair of result and expression
472+
*/
473+
template <typename T_result, typename T_expression>
474+
static auto make_assignment_pair(T_result&& result,
475+
T_expression&& expression) {
476+
return std::make_tuple(std::pair<T_result&&, T_expression&&>(
477+
std::forward<T_result>(result),
478+
std::forward<T_expression>(expression)));
479+
}
480+
481+
/**
482+
* Checks on scalars are done separately in this overload instead of in
483+
* the kernel.
484+
* @param result result - check
485+
* @param expression expression - bool scalar
486+
* @return an empty tuple
487+
*/
488+
template <typename Scal>
489+
static std::tuple<> make_assignment_pair(
490+
check_cl_<scalar_<Scal>>& result, scalar_<char> expression) {
491+
if (!expression.a_) {
492+
std::stringstream s;
493+
s << result.function_ << ": " << result.err_variable_ << " = "
494+
<< result.arg_.a_ << ", but it must be " << result.must_be_ << "!";
495+
throw std::domain_error(s.str());
496+
}
497+
return std::make_tuple();
498+
}
450499
};
451500

452501
/**

0 commit comments

Comments
 (0)