Skip to content

Commit 4de2ae7

Browse files
committed
fix deep_copy return types
1 parent 5d7fd44 commit 4de2ae7

File tree

5 files changed

+43
-31
lines changed

5 files changed

+43
-31
lines changed

stan/math/opencl/kernel_generator/binary_operation.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,12 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
116116
public: \
117117
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
118118
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
119-
inline class_name<std::remove_reference_t<T_a>, \
120-
std::remove_reference_t<T_b>> \
121-
deep_copy() { \
122-
return {std::get<0>(arguments_).deep_copy(), \
123-
std::get<1>(arguments_).deep_copy()}; \
119+
inline auto deep_copy() { \
120+
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
121+
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
122+
return class_name<std::remove_reference_t<decltype(a_copy)>, \
123+
std::remove_reference_t<decltype(b_copy)>>( \
124+
std::move(a_copy), std::move(b_copy)); \
124125
} \
125126
}; \
126127
\
@@ -162,11 +163,12 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
162163
public: \
163164
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
164165
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
165-
inline class_name<std::remove_reference_t<T_a>, \
166-
std::remove_reference_t<T_b>> \
167-
deep_copy() { \
168-
return {std::get<0>(arguments_).deep_copy(), \
169-
std::get<1>(arguments_).deep_copy()}; \
166+
inline auto deep_copy() { \
167+
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
168+
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
169+
return class_name<std::remove_reference_t<decltype(a_copy)>, \
170+
std::remove_reference_t<decltype(b_copy)>>( \
171+
std::move(a_copy), std::move(b_copy)); \
170172
} \
171173
inline matrix_cl_view view() const { __VA_ARGS__; } \
172174
}; \

stan/math/opencl/kernel_generator/block.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ class block_
6161
* Creates a deep copy of this expression.
6262
* @return copy of \c *this
6363
*/
64-
inline block_<std::remove_reference_t<T>> deep_copy() {
65-
return {std::get<0>(arguments_).deep_copy(), start_row_, start_col_, rows_,
66-
cols_};
64+
inline auto deep_copy() {
65+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
66+
return block_<std::remove_reference_t<decltype(arg_copy)>>{
67+
std::move(arg_copy), start_row_, start_col_, rows_, cols_};
6768
}
6869

6970
/**

stan/math/opencl/kernel_generator/rowwise_reduction.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ class rowwise_sum_
160160
* @return copy of \c *this
161161
*/
162162
inline rowwise_sum_<std::remove_reference_t<T>> deep_copy() {
163-
return rowwise_sum_<std::remove_reference_t<T>>{std::get<0>(arguments_)};
163+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
164+
return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
165+
std::move(arg_copy));
164166
}
165167
};
166168

@@ -224,8 +226,10 @@ class rowwise_max_
224226
* Creates a deep copy of this expression.
225227
* @return copy of \c *this
226228
*/
227-
inline rowwise_max_<std::remove_reference_t<T>> deep_copy() {
228-
return rowwise_max_<std::remove_reference_t<T>>{std::get<0>(arguments_)};
229+
inline auto deep_copy() {
230+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
231+
return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
232+
std::move(arg_copy));
229233
}
230234
};
231235

@@ -289,8 +293,10 @@ class rowwise_min_
289293
* Creates a deep copy of this expression.
290294
* @return copy of \c *this
291295
*/
292-
inline rowwise_min_<std::remove_reference_t<T>> deep_copy() {
293-
return rowwise_min_<std::remove_reference_t<T>>{std::get<0>(arguments_)};
296+
inline auto deep_copy() {
297+
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
298+
return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
299+
std::move(arg_copy));
294300
}
295301
};
296302

stan/math/opencl/kernel_generator/select.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,14 @@ class select_ : public operation_cl<select_<T_condition, T_then, T_else>,
7373
* Creates a deep copy of this expression.
7474
* @return copy of \c *this
7575
*/
76-
inline select_<std::remove_reference_t<T_condition>,
77-
std::remove_reference_t<T_then>,
78-
std::remove_reference_t<T_else>>
79-
deep_copy() {
80-
return {std::get<0>(arguments_).deep_copy(),
81-
std::get<1>(arguments_).deep_copy(),
82-
std::get<2>(arguments_).deep_copy()};
76+
inline auto deep_copy() {
77+
auto&& condition_copy = std::get<0>(arguments_).deep_copy();
78+
auto&& then_copy = std::get<0>(arguments_).deep_copy();
79+
auto&& else_copy = std::get<0>(arguments_).deep_copy();
80+
return select_<std::remove_reference_t<decltype(condition_copy)>,
81+
std::remove_reference_t<decltype(then_copy)>,
82+
std::remove_reference_t<decltype(else_copy)>>(
83+
std::move(condition_copy), std::move(then_copy), std::move(else_copy));
8384
}
8485

8586
/**

stan/math/opencl/kernel_generator/unary_function_cl.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ class unary_function_cl
8383
\
8484
public: \
8585
explicit fun##_(T&& a) : base(std::forward<T>(a), #fun) {} \
86-
inline fun##_<std::remove_reference_t<T>> deep_copy() { \
87-
return fun##_<std::remove_reference_t<T>>{ \
88-
std::get<0>(arguments_).deep_copy()}; \
86+
inline auto deep_copy() { \
87+
auto&& arg_copy = std::get<0>(arguments_).deep_copy(); \
88+
return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
89+
std::move(arg_copy)}; \
8990
} \
9091
inline matrix_cl_view view() const { return matrix_cl_view::Entire; } \
9192
}; \
@@ -110,9 +111,10 @@ class unary_function_cl
110111
\
111112
public: \
112113
explicit fun##_(T&& a) : base(std::forward<T>(a), #fun) {} \
113-
inline fun##_<std::remove_reference_t<T>> deep_copy() { \
114-
return fun##_<std::remove_reference_t<T>>{ \
115-
std::get<0>(arguments_).deep_copy()}; \
114+
inline auto deep_copy() { \
115+
auto&& arg_copy = std::get<0>(arguments_).deep_copy(); \
116+
return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
117+
std::move(arg_copy)}; \
116118
} \
117119
}; \
118120
\

0 commit comments

Comments
 (0)