Skip to content

Commit 496161e

Browse files
committed
fix rowwise reductions constructors
1 parent 4de2ae7 commit 496161e

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

stan/math/opencl/kernel_generator/rowwise_reduction.hpp

+12-10
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,10 @@ class rowwise_sum_
174174
*/
175175
template <typename T,
176176
typename = require_all_valid_expressions_and_none_scalar_t<T>>
177-
inline rowwise_sum_<as_operation_cl_t<T>> rowwise_sum(T&& a) {
178-
return rowwise_sum_<as_operation_cl_t<T>>(
179-
as_operation_cl(std::forward<T>(a)).deep_copy());
177+
inline auto rowwise_sum(T&& a) {
178+
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
179+
return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
180+
std::move(arg_copy));
180181
}
181182

182183
/**
@@ -241,11 +242,11 @@ class rowwise_max_
241242
*/
242243
template <typename T,
243244
typename = require_all_valid_expressions_and_none_scalar_t<T>>
244-
inline rowwise_max_<as_operation_cl_t<T>> rowwise_max(T&& a) {
245-
return rowwise_max_<as_operation_cl_t<T>>(
246-
as_operation_cl(std::forward<T>(a)).deep_copy());
245+
inline auto rowwise_max(T&& a) {
246+
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
247+
return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
248+
std::move(arg_copy));
247249
}
248-
249250
/**
250251
* Operation for min reduction.
251252
* @tparam T type to reduce
@@ -308,9 +309,10 @@ class rowwise_min_
308309
*/
309310
template <typename T,
310311
typename = require_all_valid_expressions_and_none_scalar_t<T>>
311-
inline rowwise_min_<as_operation_cl_t<T>> rowwise_min(T&& a) {
312-
return rowwise_min_<as_operation_cl_t<T>>(
313-
as_operation_cl(std::forward<T>(a)).deep_copy());
312+
inline auto rowwise_min(T&& a) {
313+
auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
314+
return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
315+
std::move(arg_copy));
314316
}
315317

316318
} // namespace math

0 commit comments

Comments
 (0)