Skip to content

Commit 1decede

Browse files
authored
Merge pull request #1726 from bstatcomp/cl_kernel_generator_bugfix_common_subexprs
Bugfix common subexpression elimination in kernel generator
2 parents bb3a5ca + 496161e commit 1decede

File tree

9 files changed

+173
-37
lines changed

9 files changed

+173
-37
lines changed

stan/math/opencl/kernel_generator/binary_operation.hpp

+24-4
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,20 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
109109
template <typename T_a, typename T_b> \
110110
class class_name : public binary_operation<class_name<T_a, T_b>, \
111111
scalar_type_expr, T_a, T_b> { \
112+
using base \
113+
= binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>; \
114+
using base::arguments_; \
115+
\
112116
public: \
113117
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
114-
: binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>( \
115-
std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
118+
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
119+
inline auto deep_copy() { \
120+
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
121+
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
122+
return class_name<std::remove_reference_t<decltype(a_copy)>, \
123+
std::remove_reference_t<decltype(b_copy)>>( \
124+
std::move(a_copy), std::move(b_copy)); \
125+
} \
116126
}; \
117127
\
118128
template <typename T_a, typename T_b, \
@@ -146,10 +156,20 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
146156
template <typename T_a, typename T_b> \
147157
class class_name : public binary_operation<class_name<T_a, T_b>, \
148158
scalar_type_expr, T_a, T_b> { \
159+
using base \
160+
= binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>; \
161+
using base::arguments_; \
162+
\
149163
public: \
150164
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
151-
: binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>( \
152-
std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
165+
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
166+
inline auto deep_copy() { \
167+
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
168+
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
169+
return class_name<std::remove_reference_t<decltype(a_copy)>, \
170+
std::remove_reference_t<decltype(b_copy)>>( \
171+
std::move(a_copy), std::move(b_copy)); \
172+
} \
153173
inline matrix_cl_view view() const { __VA_ARGS__; } \
154174
}; \
155175
\

stan/math/opencl/kernel_generator/block.hpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ class block_
5757
}
5858
}
5959

60+
/**
61+
* Creates a deep copy of this expression.
62+
* @return copy of \c *this
63+
*/
64+
inline auto deep_copy() {
65+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
66+
return block_<std::remove_reference_t<decltype(arg_copy)>>{
67+
std::move(arg_copy), start_row_, start_col_, rows_, cols_};
68+
}
69+
6070
/**
6171
* Generates kernel code for this expression.
6272
* @param i row index variable name
@@ -227,10 +237,10 @@ class block_
227237
*/
228238
template <typename T,
229239
typename = require_all_valid_expressions_and_none_scalar_t<T>>
230-
inline block_<as_operation_cl_t<T>> block(T&& a, int start_row, int start_col,
231-
int rows, int cols) {
232-
return block_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a)),
233-
start_row, start_col, rows, cols);
240+
inline auto block(T&& a, int start_row, int start_col, int rows, int cols) {
241+
auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
242+
return block_<std::remove_reference_t<decltype(a_operation)>>(
243+
std::move(a_operation), start_row, start_col, rows, cols);
234244
}
235245

236246
} // namespace math

stan/math/opencl/kernel_generator/load.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ class load_
4444
*/
4545
explicit load_(T&& a) : a_(std::forward<T>(a)) {}
4646

47+
/**
48+
* Creates a deep copy of this expression.
49+
* @return copy of \c *this
50+
*/
51+
inline load_<T&> deep_copy() { return load_<T&>(a_); }
52+
4753
/**
4854
* generates kernel code for this expression.
4955
* @param i row index variable name

stan/math/opencl/kernel_generator/operation_cl.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ class operation_cl : public operation_cl_base {
245245
*/
246246
inline int bottom_diagonal() const {
247247
return index_apply<N>([&](auto... Is) {
248-
return std::min({std::get<Is>(arguments_).bottom_diagonal()...});
248+
return std::min(std::initializer_list<int>(
249+
{std::get<Is>(arguments_).bottom_diagonal()...}));
249250
});
250251
}
251252

@@ -256,7 +257,8 @@ class operation_cl : public operation_cl_base {
256257
*/
257258
inline int top_diagonal() const {
258259
return index_apply<N>([&](auto... Is) {
259-
return std::max({std::get<Is>(arguments_).top_diagonal()...});
260+
return std::max(std::initializer_list<int>(
261+
{std::get<Is>(arguments_).top_diagonal()...}));
260262
});
261263
}
262264
};

stan/math/opencl/kernel_generator/rowwise_reduction.hpp

+56-23
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class rowwise_reduction
4646
* @param a the expression to reduce
4747
* @param init OpenCL source code of initialization value for reduction
4848
*/
49-
rowwise_reduction(T&& a, const std::string& init)
49+
explicit rowwise_reduction(T&& a, const std::string& init)
5050
: base(std::forward<T>(a)), init_(init) {}
5151

5252
/**
@@ -121,7 +121,7 @@ class rowwise_reduction
121121

122122
/**
123123
* Determine index of top diagonal written.
124-
* @return number of columns
124+
* @return top diagonal
125125
*/
126126
inline int top_diagonal() const { return 1; }
127127
};
@@ -149,10 +149,21 @@ struct sum_op {
149149
template <typename T>
150150
class rowwise_sum_
151151
: public rowwise_reduction<rowwise_sum_<T>, T, sum_op, true> {
152+
using base = rowwise_reduction<rowwise_sum_<T>, T, sum_op, true>;
153+
using base::arguments_;
154+
152155
public:
153-
explicit rowwise_sum_(T&& a)
154-
: rowwise_reduction<rowwise_sum_<T>, T, sum_op, true>(std::forward<T>(a),
155-
"0") {}
156+
explicit rowwise_sum_(T&& a) : base(std::forward<T>(a), "0") {}
157+
158+
/**
159+
* Creates a deep copy of this expression by deep-copying its single argument.
160+
* @return copy of \c *this; the return type is deduced because the copied argument's type may differ from \c T
161+
*/
162+
inline auto deep_copy() {
163+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
164+
return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
165+
std::move(arg_copy));
166+
}
156167
};
157168

158169
/**
@@ -163,9 +174,10 @@ class rowwise_sum_
163174
*/
164175
template <typename T,
165176
typename = require_all_valid_expressions_and_none_scalar_t<T>>
166-
inline rowwise_sum_<as_operation_cl_t<T>> rowwise_sum(T&& a) {
167-
return rowwise_sum_<as_operation_cl_t<T>>(
168-
as_operation_cl(std::forward<T>(a)));
177+
inline auto rowwise_sum(T&& a) {
178+
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
179+
return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
180+
std::move(arg_copy));
169181
}
170182

171183
/**
@@ -205,11 +217,21 @@ class rowwise_max_
205217
: public rowwise_reduction<
206218
rowwise_max_<T>, T,
207219
max_op<typename std::remove_reference_t<T>::Scalar>, false> {
208-
public:
209220
using op = max_op<typename std::remove_reference_t<T>::Scalar>;
210-
explicit rowwise_max_(T&& a)
211-
: rowwise_reduction<rowwise_max_<T>, T, op, false>(std::forward<T>(a),
212-
op::init()) {}
221+
using base = rowwise_reduction<rowwise_max_<T>, T, op, false>;
222+
using base::arguments_;
223+
224+
public:
225+
explicit rowwise_max_(T&& a) : base(std::forward<T>(a), op::init()) {}
226+
/**
227+
* Creates a deep copy of this expression.
228+
* @return copy of \c *this
229+
*/
230+
inline auto deep_copy() {
231+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
232+
return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
233+
std::move(arg_copy));
234+
}
213235
};
214236

215237
/**
@@ -220,11 +242,11 @@ class rowwise_max_
220242
*/
221243
template <typename T,
222244
typename = require_all_valid_expressions_and_none_scalar_t<T>>
223-
inline rowwise_max_<as_operation_cl_t<T>> rowwise_max(T&& a) {
224-
return rowwise_max_<as_operation_cl_t<T>>(
225-
as_operation_cl(std::forward<T>(a)));
245+
inline auto rowwise_max(T&& a) {
246+
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
247+
return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
248+
std::move(arg_copy));
226249
}
227-
228250
/**
229251
* Operation for min reduction.
230252
* @tparam T type to reduce
@@ -262,11 +284,21 @@ class rowwise_min_
262284
: public rowwise_reduction<
263285
rowwise_min_<T>, T,
264286
min_op<typename std::remove_reference_t<T>::Scalar>, false> {
265-
public:
266287
using op = min_op<typename std::remove_reference_t<T>::Scalar>;
267-
explicit rowwise_min_(T&& a)
268-
: rowwise_reduction<rowwise_min_<T>, T, op, false>(std::forward<T>(a),
269-
op::init()) {}
288+
using base = rowwise_reduction<rowwise_min_<T>, T, op, false>;
289+
using base::arguments_;
290+
291+
public:
292+
explicit rowwise_min_(T&& a) : base(std::forward<T>(a), op::init()) {}
293+
/**
294+
* Creates a deep copy of this expression.
295+
* @return copy of \c *this
296+
*/
297+
inline auto deep_copy() {
298+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
299+
return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
300+
std::move(arg_copy));
301+
}
270302
};
271303

272304
/**
@@ -277,9 +309,10 @@ class rowwise_min_
277309
*/
278310
template <typename T,
279311
typename = require_all_valid_expressions_and_none_scalar_t<T>>
280-
inline rowwise_min_<as_operation_cl_t<T>> rowwise_min(T&& a) {
281-
return rowwise_min_<as_operation_cl_t<T>>(
282-
as_operation_cl(std::forward<T>(a)));
312+
inline auto rowwise_min(T&& a) {
313+
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
314+
return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
315+
std::move(arg_copy));
283316
}
284317

285318
} // namespace math

stan/math/opencl/kernel_generator/scalar.hpp

+19
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/opencl/kernel_generator/type_str.hpp>
88
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
99
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
10+
#include <limits>
1011
#include <string>
1112
#include <type_traits>
1213
#include <set>
@@ -36,6 +37,12 @@ class scalar_ : public operation_cl<scalar_<T>, T> {
3637
*/
3738
explicit scalar_(const T a) : a_(a) {}
3839

40+
/**
41+
* Creates a deep copy of this expression.
42+
* @return copy of \c *this
43+
*/
44+
inline scalar_<T> deep_copy() { return scalar_<T>(a_); }
45+
3946
/**
4047
* generates kernel code for this expression.
4148
* @param i row index variable name
@@ -81,6 +88,18 @@ class scalar_ : public operation_cl<scalar_<T>, T> {
8188
* @return view
8289
*/
8390
inline matrix_cl_view view() const { return matrix_cl_view::Entire; }
91+
92+
/**
93+
* Determine index of bottom diagonal written.
94+
* @return smallest representable \c int (std::numeric_limits<int>::min())
95+
*/
96+
inline int bottom_diagonal() const { return std::numeric_limits<int>::min(); }
97+
98+
/**
99+
* Determine index of top diagonal written.
100+
* @return largest representable \c int (std::numeric_limits<int>::max())
101+
*/
102+
inline int top_diagonal() const { return std::numeric_limits<int>::max(); }
84103
};
85104

86105
} // namespace math

stan/math/opencl/kernel_generator/select.hpp

+14
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ class select_ : public operation_cl<select_<T_condition, T_then, T_else>,
6969
}
7070
}
7171

72+
/**
73+
* Creates a deep copy of this expression by deep-copying each of the three stored arguments separately.
74+
* @return copy of \c *this, built from the copies of argument 0, 1 and 2 in that order
75+
*/
76+
inline auto deep_copy() {
77+
auto&& condition_copy = std::get<0>(arguments_).deep_copy();
78+
auto&& then_copy = std::get<1>(arguments_).deep_copy();
79+
auto&& else_copy = std::get<2>(arguments_).deep_copy();
80+
return select_<std::remove_reference_t<decltype(condition_copy)>,
81+
std::remove_reference_t<decltype(then_copy)>,
82+
std::remove_reference_t<decltype(else_copy)>>(
83+
std::move(condition_copy), std::move(then_copy), std::move(else_copy));
84+
}
85+
7286
/**
7387
* generates kernel code for this (select) operation.
7488
* @param i row index variable name

stan/math/opencl/kernel_generator/unary_function_cl.hpp

+18-4
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,16 @@ class unary_function_cl
7878
#define ADD_UNARY_FUNCTION(fun) \
7979
template <typename T> \
8080
class fun##_ : public unary_function_cl<fun##_<T>, T> { \
81+
using base = unary_function_cl<fun##_<T>, T>; \
82+
using base::arguments_; \
83+
\
8184
public: \
82-
explicit fun##_(T&& a) \
83-
: unary_function_cl<fun##_<T>, T>(std::forward<T>(a), #fun) {} \
85+
explicit fun##_(T&& a) : base(std::forward<T>(a), #fun) {} \
86+
inline auto deep_copy() { \
87+
auto&& arg_copy = std::get<0>(arguments_).deep_copy(); \
88+
return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
89+
std::move(arg_copy)}; \
90+
} \
8491
inline matrix_cl_view view() const { return matrix_cl_view::Entire; } \
8592
}; \
8693
\
@@ -99,9 +106,16 @@ class unary_function_cl
99106
#define ADD_UNARY_FUNCTION_PASS_ZERO(fun) \
100107
template <typename T> \
101108
class fun##_ : public unary_function_cl<fun##_<T>, T> { \
109+
using base = unary_function_cl<fun##_<T>, T>; \
110+
using base::arguments_; \
111+
\
102112
public: \
103-
explicit fun##_(T&& a) \
104-
: unary_function_cl<fun##_<T>, T>(std::forward<T>(a), #fun) {} \
113+
explicit fun##_(T&& a) : base(std::forward<T>(a), #fun) {} \
114+
inline auto deep_copy() { \
115+
auto&& arg_copy = std::get<0>(arguments_).deep_copy(); \
116+
return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
117+
std::move(arg_copy)}; \
118+
} \
105119
}; \
106120
\
107121
template <typename T, typename Cond \

test/unit/math/opencl/kernel_generator/block_test.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,22 @@ TEST(MathMatrixCL, lhs_block_test) {
108108
EXPECT_MATRIX_NEAR(res, correct, 1e-9);
109109
}
110110

111+
TEST(MathMatrixCL, two_blocks_of_same_expression) {
112+
using stan::math::block;
113+
MatrixXd m(2, 3);
114+
m << 1, 2, 3, 4, 5, 6;
115+
116+
matrix_cl<double> m_cl(m);
117+
118+
auto tmp = m_cl + 1;
119+
auto tmp2 = block(tmp, 0, 0, 2, 2) + block(tmp, 0, 1, 2, 2);
120+
121+
matrix_cl<double> res_cl = tmp2;
122+
123+
MatrixXd res = stan::math::from_matrix_cl(res_cl);
124+
MatrixXd correct = (m.block(0, 0, 2, 2) + m.block(0, 1, 2, 2)).array() + 2;
125+
126+
EXPECT_MATRIX_NEAR(res, correct, 1e-9);
127+
}
128+
111129
#endif

0 commit comments

Comments
 (0)