Skip to content

Commit 8123af7

Browse files
committed
remove wrapper
1 parent 0591065 commit 8123af7

File tree

5 files changed

+60
-113
lines changed

5 files changed

+60
-113
lines changed

stan/math/opencl/kernel_generator/holder_cl.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <stan/math/prim/functor.hpp>
77
#include <stan/math/opencl/err.hpp>
88
#include <stan/math/opencl/matrix_cl_view.hpp>
9-
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
109
#include <stan/math/opencl/kernel_generator/type_str.hpp>
1110
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
1211
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>

stan/math/opencl/kernel_generator/multi_result_kernel.hpp

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#ifdef STAN_OPENCL
44

55
#include <stan/math/prim/err.hpp>
6-
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
76
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
87
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
98
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
@@ -46,12 +45,11 @@ struct multi_result_kernel_internal {
4645
* @param expressions expressions
4746
*/
4847
static void get_clear_events(
49-
std::vector<cl::Event>& events,
50-
const std::tuple<wrapper<T_results>...>& results,
51-
const std::tuple<wrapper<T_expressions>...>& expressions) {
48+
std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
49+
const std::tuple<T_expressions...>& expressions) {
5250
next::get_clear_events(events, results, expressions);
53-
std::get<N>(expressions).x.get_clear_write_events(events);
54-
std::get<N>(results).x.get_clear_read_write_events(events);
51+
std::get<N>(expressions).get_clear_write_events(events);
52+
std::get<N>(results).get_clear_read_write_events(events);
5553
}
5654
/**
5755
* Assigns the dimensions of expressions to matching results if possible.
@@ -65,12 +63,11 @@ struct multi_result_kernel_internal {
6563
* @param expressions expressions
6664
*/
6765
static void check_assign_dimensions(
68-
int n_rows, int n_cols,
69-
const std::tuple<wrapper<T_results>...>& results,
70-
const std::tuple<wrapper<T_expressions>...>& expressions) {
66+
int n_rows, int n_cols, const std::tuple<T_results...>& results,
67+
const std::tuple<T_expressions...>& expressions) {
7168
next::check_assign_dimensions(n_rows, n_cols, results, expressions);
72-
const auto& expression = std::get<N>(expressions).x;
73-
const auto& result = std::get<N>(results).x;
69+
const auto& expression = std::get<N>(expressions);
70+
const auto& result = std::get<N>(results);
7471
const char* function = "results.operator=";
7572
if (!is_without_output<T_current_expression>::value) {
7673
check_size_match(function, "Rows of ", "expression",
@@ -102,17 +99,17 @@ struct multi_result_kernel_internal {
10299
static kernel_parts generate(
103100
std::set<const operation_cl_base*>& generated, name_generator& ng,
104101
const std::string& row_index_name, const std::string& col_index_name,
105-
const std::tuple<wrapper<T_results>...>& results,
106-
const std::tuple<wrapper<T_expressions>...>& expressions) {
102+
const std::tuple<T_results...>& results,
103+
const std::tuple<T_expressions...>& expressions) {
107104
kernel_parts parts = next::generate(generated, ng, row_index_name,
108105
col_index_name, results, expressions);
109106
if (is_without_output<T_current_expression>::value) {
110107
return parts;
111108
}
112-
kernel_parts parts0 = std::get<N>(expressions)
113-
.x.get_whole_kernel_parts(
114-
generated, ng, row_index_name,
115-
col_index_name, std::get<N>(results).x);
109+
kernel_parts parts0
110+
= std::get<N>(expressions)
111+
.get_whole_kernel_parts(generated, ng, row_index_name,
112+
col_index_name, std::get<N>(results));
116113
parts += parts0;
117114
return parts;
118115
}
@@ -125,18 +122,18 @@ struct multi_result_kernel_internal {
125122
* @param results results
126123
* @param expressions expressions
127124
*/
128-
static void set_args(
129-
std::set<const operation_cl_base*>& generated, cl::Kernel& kernel,
130-
int& arg_num, const std::tuple<wrapper<T_results>...>& results,
131-
const std::tuple<wrapper<T_expressions>...>& expressions) {
125+
static void set_args(std::set<const operation_cl_base*>& generated,
126+
cl::Kernel& kernel, int& arg_num,
127+
const std::tuple<T_results...>& results,
128+
const std::tuple<T_expressions...>& expressions) {
132129
next::set_args(generated, kernel, arg_num, results, expressions);
133130

134131
if (is_without_output<T_current_expression>::value) {
135132
return;
136133
}
137134

138-
std::get<N>(expressions).x.set_args(generated, kernel, arg_num);
139-
std::get<N>(results).x.set_args(generated, kernel, arg_num);
135+
std::get<N>(expressions).set_args(generated, kernel, arg_num);
136+
std::get<N>(results).set_args(generated, kernel, arg_num);
140137
}
141138

142139
/**
@@ -145,13 +142,12 @@ struct multi_result_kernel_internal {
145142
* @param results results
146143
* @param expressions expressions
147144
*/
148-
static void add_event(
149-
cl::Event e, const std::tuple<wrapper<T_results>...>& results,
150-
const std::tuple<wrapper<T_expressions>...>& expressions) {
145+
static void add_event(cl::Event e, const std::tuple<T_results...>& results,
146+
const std::tuple<T_expressions...>& expressions) {
151147
next::add_event(e, results, expressions);
152148

153-
std::get<N>(expressions).x.add_read_event(e);
154-
std::get<N>(results).x.add_write_event(e);
149+
std::get<N>(expressions).add_read_event(e);
150+
std::get<N>(results).add_write_event(e);
155151
}
156152
};
157153
};
@@ -162,35 +158,32 @@ struct multi_result_kernel_internal<-1, T_results...> {
162158
template <typename... T_expressions>
163159
struct inner {
164160
static void get_clear_events(
165-
std::vector<cl::Event>& events,
166-
const std::tuple<wrapper<T_results>...>& results,
167-
const std::tuple<wrapper<T_expressions>...>& expressions) {}
161+
std::vector<cl::Event>& events, const std::tuple<T_results...>& results,
162+
const std::tuple<T_expressions...>& expressions) {}
168163

169164
static void check_assign_dimensions(
170-
int n_rows, int n_cols,
171-
const std::tuple<wrapper<T_results>...>& results,
172-
const std::tuple<wrapper<T_expressions>...>& expressions) {
165+
int n_rows, int n_cols, const std::tuple<T_results...>& results,
166+
const std::tuple<T_expressions...>& expressions) {
173167
return;
174168
}
175169

176170
static kernel_parts generate(
177171
std::set<const operation_cl_base*>& generated, name_generator& ng,
178172
const std::string& row_index_name, const std::string& col_index_name,
179-
const std::tuple<wrapper<T_results>...>& results,
180-
const std::tuple<wrapper<T_expressions>...>& expressions) {
173+
const std::tuple<T_results...>& results,
174+
const std::tuple<T_expressions...>& expressions) {
181175
return {};
182176
}
183177

184-
static void set_args(
185-
std::set<const operation_cl_base*>& generated, cl::Kernel& kernel,
186-
int& arg_num, const std::tuple<wrapper<T_results>...>& results,
187-
const std::tuple<wrapper<T_expressions>...>& expressions) {
178+
static void set_args(std::set<const operation_cl_base*>& generated,
179+
cl::Kernel& kernel, int& arg_num,
180+
const std::tuple<T_results...>& results,
181+
const std::tuple<T_expressions...>& expressions) {
188182
return;
189183
}
190184

191-
static void add_event(
192-
cl::Event e, const std::tuple<wrapper<T_results>...>& results,
193-
const std::tuple<wrapper<T_expressions>...>& expressions) {
185+
static void add_event(cl::Event e, const std::tuple<T_results...>& results,
186+
const std::tuple<T_expressions...>& expressions) {
194187
return;
195188
}
196189
};
@@ -215,11 +208,11 @@ class expressions_cl {
215208
* @param expressions expressions that will be calculated in same kernel.
216209
*/
217210
explicit expressions_cl(T_expressions&&... expressions)
218-
: expressions_(internal::wrapper<T_expressions>(
219-
std::forward<T_expressions>(expressions))...) {}
211+
: expressions_(
212+
T_expressions(std::forward<T_expressions>(expressions))...) {}
220213

221214
private:
222-
std::tuple<internal::wrapper<T_expressions>...> expressions_;
215+
std::tuple<T_expressions...> expressions_;
223216
template <typename... T_results>
224217
friend class results_cl;
225218
};
@@ -241,17 +234,15 @@ expressions_cl<T_expressions...> expressions(T_expressions&&... expressions) {
241234
*/
242235
template <typename... T_results>
243236
class results_cl {
244-
std::tuple<internal::wrapper<T_results>...> results_;
237+
std::tuple<T_results...> results_;
245238

246239
public:
247240
/**
248241
* Constructor.
249242
* @param results results that will be calculated in same kernel
250243
*/
251244
explicit results_cl(T_results&&... results)
252-
: results_(
253-
internal::wrapper<T_results>(std::forward<T_results>(results))...) {
254-
}
245+
: results_(std::forward<T_results>(results)...) {}
255246

256247
/**
257248
* Assigning \c expressions_cl object to \c results_ object generates and
@@ -295,14 +286,10 @@ class results_cl {
295286
std::string get_kernel_source_for_evaluating_impl(
296287
const expressions_cl<T_expressions...>& exprs,
297288
std::index_sequence<Is...>) {
298-
auto expressions = std::make_tuple(
299-
internal::make_wrapper(std::forward<decltype(as_operation_cl(
300-
std::get<Is>(exprs.expressions_).x))>(
301-
as_operation_cl(std::get<Is>(exprs.expressions_).x)))...);
302-
auto results = std::make_tuple(internal::make_wrapper(
303-
std::forward<decltype(as_operation_cl(std::get<Is>(results_).x))>(
304-
as_operation_cl(std::get<Is>(results_).x)))...);
305-
return get_kernel_source_impl(results, expressions);
289+
return get_kernel_source_impl(
290+
std::forward_as_tuple(as_operation_cl(std::get<Is>(results_))...),
291+
std::forward_as_tuple(
292+
as_operation_cl(std::get<Is>(exprs.expressions_))...));
306293
}
307294

308295
/**
@@ -314,8 +301,8 @@ class results_cl {
314301
*/
315302
template <typename... T_res, typename... T_expressions>
316303
static std::string get_kernel_source_impl(
317-
const std::tuple<internal::wrapper<T_res>...>& results,
318-
const std::tuple<internal::wrapper<T_expressions>...>& expressions) {
304+
const std::tuple<T_res...>& results,
305+
const std::tuple<T_expressions...>& expressions) {
319306
using impl = typename internal::multi_result_kernel_internal<
320307
std::tuple_size<std::tuple<T_expressions...>>::value - 1,
321308
T_res...>::template inner<T_expressions...>;
@@ -378,14 +365,10 @@ class results_cl {
378365
template <typename... T_expressions, size_t... Is>
379366
void assignment(const expressions_cl<T_expressions...>& exprs,
380367
std::index_sequence<Is...>) {
381-
auto expressions = std::make_tuple(
382-
internal::make_wrapper(std::forward<decltype(as_operation_cl(
383-
std::get<Is>(exprs.expressions_).x))>(
384-
as_operation_cl(std::get<Is>(exprs.expressions_).x)))...);
385-
auto results = std::make_tuple(internal::make_wrapper(
386-
std::forward<decltype(as_operation_cl(std::get<Is>(results_).x))>(
387-
as_operation_cl(std::get<Is>(results_).x)))...);
388-
assignment_impl(results, expressions);
368+
assignment_impl(
369+
std::forward_as_tuple(as_operation_cl(std::get<Is>(results_))...),
370+
std::forward_as_tuple(
371+
as_operation_cl(std::get<Is>(exprs.expressions_))...));
389372
}
390373

391374
/**
@@ -396,9 +379,8 @@ class results_cl {
396379
* @param expressions expressions
397380
*/
398381
template <typename... T_res, typename... T_expressions>
399-
static void assignment_impl(
400-
const std::tuple<internal::wrapper<T_res>...>& results,
401-
const std::tuple<internal::wrapper<T_expressions>...>& expressions) {
382+
static void assignment_impl(const std::tuple<T_res...>& results,
383+
const std::tuple<T_expressions...>& expressions) {
402384
using T_First_Expr = typename std::remove_reference_t<
403385
std::tuple_element_t<0, std::tuple<T_expressions...>>>;
404386
using impl = typename internal::multi_result_kernel_internal<
@@ -414,8 +396,8 @@ class results_cl {
414396
static const bool require_specific_local_size = std::max(
415397
{std::decay_t<T_expressions>::Deriv::require_specific_local_size...});
416398

417-
int n_rows = std::get<0>(expressions).x.thread_rows();
418-
int n_cols = std::get<0>(expressions).x.thread_cols();
399+
int n_rows = std::get<0>(expressions).thread_rows();
400+
int n_cols = std::get<0>(expressions).thread_cols();
419401
const char* function = "results_cl.assignment";
420402
impl::check_assign_dimensions(n_rows, n_cols, results, expressions);
421403
if (n_rows * n_cols == 0) {

stan/math/opencl/kernel_generator/operation_cl.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
#include <stan/math/prim/meta.hpp>
66
#include <stan/math/prim/err.hpp>
7-
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
87
#include <stan/math/opencl/kernel_generator/type_str.hpp>
98
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
109
#include <stan/math/opencl/kernel_generator/is_kernel_expression.hpp>
@@ -79,7 +78,7 @@ class operation_cl : public operation_cl_base {
7978
"operation_cl: all arguments to operation must be operations!");
8079

8180
protected:
82-
std::tuple<internal::wrapper<Args>...> arguments_;
81+
std::tuple<Args...> arguments_;
8382
mutable std::string var_name_; // name of the variable that holds result of
8483
// this operation in the kernel
8584

@@ -113,7 +112,7 @@ class operation_cl : public operation_cl_base {
113112
*/
114113
template <size_t N>
115114
const auto& get_arg() const {
116-
return std::get<N>(arguments_).x;
115+
return std::get<N>(arguments_);
117116
}
118117

119118
/**
@@ -122,7 +121,7 @@ class operation_cl : public operation_cl_base {
122121
* expressions
123122
*/
124123
explicit operation_cl(Args&&... arguments)
125-
: arguments_(internal::wrapper<Args>(std::forward<Args>(arguments))...) {}
124+
: arguments_(std::forward<Args>(arguments)...) {}
126125

127126
/**
128127
* Evaluates the expression.

stan/math/opencl/kernel_generator/wrapper.hpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

test/unit/math/opencl/kernel_generator/operation_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ TEST(KernelGenerator, kernel_caching) {
3030
matrix_cl<double> m2_cl(m2);
3131
auto tmp = m1_cl + 0.1234 * m2_cl;
3232
using cache = stan::math::internal::multi_result_kernel_internal<
33-
0, stan::math::load_<matrix_cl<double>&>>::inner<const decltype(tmp)&>;
33+
0, stan::math::load_<matrix_cl<double>&>&&>::inner<const decltype(tmp)&>;
3434
using unused_cache = stan::math::internal::multi_result_kernel_internal<
35-
0, stan::math::load_<matrix_cl<int>&>>::inner<const decltype(tmp)&>;
35+
0, stan::math::load_<matrix_cl<int>&>&&>::inner<const decltype(tmp)&>;
3636
EXPECT_EQ(cache::kernel_(), nullptr);
3737
matrix_cl<double> res_cl = tmp;
3838
cl_kernel cached_kernel = cache::kernel_();

0 commit comments

Comments
 (0)