Skip to content

Commit 3001a22

Browse files
authored
Merge pull request #2044 from bstatcomp/check_cl_scalar_support
Check cl scalar support
2 parents 84f077b + d18343c commit 3001a22

File tree

10 files changed

+161
-148
lines changed

10 files changed

+161
-148
lines changed

stan/math/opencl/kernel_generator/check_cl.hpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/opencl/value_type.hpp>
88
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
99
#include <stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
10+
#include <stan/math/opencl/kernel_generator/scalar.hpp>
1011

1112
namespace stan {
1213
namespace math {
@@ -34,12 +35,13 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
3435
// buffer[1,2] are problematic indices
3536
matrix_cl<int> buffer_;
3637
matrix_cl<value_type_t<T>> value_;
38+
39+
public:
3740
T arg_;
3841
const char* function_;
3942
const char* err_variable_;
4043
const char* must_be_;
4144

42-
public:
4345
/**
4446
* Constructor.
4547
* @param function function name (for error messages)
@@ -157,6 +159,13 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
157159
* @return number of columns
158160
*/
159161
inline int cols() const { return arg_.cols(); }
162+
163+
/**
164+
* Assignment of a scalar bool triggers the scalar check.
165+
* @param condition whether the state is ok.
166+
* @throws std::domain_error condition is false (chack failed).
167+
*/
168+
void operator=(bool condition) { *this = as_operation_cl(condition); }
160169
};
161170

162171
/**
@@ -169,8 +178,7 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
169178
* @param y variable to check (for error messages)
170179
* @param must_be description of what the value must be (for error messages)
171180
*/
172-
template <typename T,
173-
typename = require_all_kernel_expressions_and_none_scalar_t<T>>
181+
template <typename T, typename = require_all_kernel_expressions_t<T>>
174182
inline auto check_cl(const char* function, const char* var_name, T&& y,
175183
const char* must_be) {
176184
return check_cl_<as_operation_cl_t<T>>(

stan/math/opencl/kernel_generator/holder_cl.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <stan/math/prim/functor.hpp>
77
#include <stan/math/opencl/err.hpp>
88
#include <stan/math/opencl/matrix_cl_view.hpp>
9-
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
109
#include <stan/math/opencl/kernel_generator/type_str.hpp>
1110
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
1211
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>

0 commit comments

Comments
 (0)