7
7
#include < stan/math/opencl/value_type.hpp>
8
8
#include < stan/math/opencl/kernel_generator/as_operation_cl.hpp>
9
9
#include < stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
10
+ #include < stan/math/opencl/kernel_generator/scalar.hpp>
10
11
11
12
namespace stan {
12
13
namespace math {
@@ -34,12 +35,13 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
34
35
// buffer[1,2] are problematic indices
35
36
matrix_cl<int > buffer_;
36
37
matrix_cl<value_type_t <T>> value_;
38
+
39
+ public:
37
40
T arg_;
38
41
const char * function_;
39
42
const char * err_variable_;
40
43
const char * must_be_;
41
44
42
- public:
43
45
/* *
44
46
* Constructor.
45
47
* @param function function name (for error messages)
@@ -157,6 +159,13 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
157
159
* @return number of columns
158
160
*/
159
161
inline int cols () const { return arg_.cols (); }
162
+
163
+ /* *
164
+ * Assignment of a scalar bool triggers the scalar check.
165
+ * @param condition whether the state is ok.
166
+ * @throws std::domain_error condition is false (chack failed).
167
+ */
168
+ void operator =(bool condition) { *this = as_operation_cl (condition); }
160
169
};
161
170
162
171
/* *
@@ -169,8 +178,7 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
169
178
* @param y variable to check (for error messages)
170
179
* @param must_be description of what the value must be (for error messages)
171
180
*/
172
- template <typename T,
173
- typename = require_all_kernel_expressions_and_none_scalar_t <T>>
181
+ template <typename T, typename = require_all_kernel_expressions_t <T>>
174
182
inline auto check_cl (const char * function, const char * var_name, T&& y,
175
183
const char * must_be) {
176
184
return check_cl_<as_operation_cl_t <T>>(
0 commit comments