Skip to content

Commit 47d8489

Browse files
AlcpzjoeatoddPietroGhg
authored
[SYCL][COMPAT] Added id_query and launch headers (#11018)
SYCLcompat primary goals and features are documented [here](https://intel.github.io/llvm-docs/syclcompat/README.html). The PR includes the helper code to launch device code through the `launch<F>` mechanism, and the wrappers to the free function queries. --------- Co-authored-by: Joe Todd <[email protected]> Co-authored-by: Pietro Ghiglio <[email protected]>
1 parent bceed65 commit 47d8489

File tree

12 files changed

+1107
-6
lines changed

12 files changed

+1107
-6
lines changed

sycl/include/syclcompat/id_query.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
6+
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*
15+
* SYCL compatibility extension
16+
*
17+
* id_query.hpp
18+
*
19+
* Description:
20+
* id_query functionality for the SYCL compatibility extension
21+
**************************************************************************/
22+
23+
#pragma once
24+
25+
#include <sycl/nd_item.hpp>
26+
27+
namespace syclcompat {
28+
29+
using sycl::ext::oneapi::experimental::this_nd_item;
30+
31+
inline void wg_barrier() { this_nd_item<3>().barrier(); }
32+
33+
namespace local_id {
34+
inline size_t x() { return this_nd_item<3>().get_local_id(2); }
35+
inline size_t y() { return this_nd_item<3>().get_local_id(1); }
36+
inline size_t z() { return this_nd_item<3>().get_local_id(0); }
37+
} // namespace local_id
38+
39+
namespace local_range {
40+
inline size_t x() { return this_nd_item<3>().get_local_range(2); }
41+
inline size_t y() { return this_nd_item<3>().get_local_range(1); }
42+
inline size_t z() { return this_nd_item<3>().get_local_range(0); }
43+
} // namespace local_range
44+
45+
namespace work_group_id {
46+
inline size_t x() { return this_nd_item<3>().get_group(2); }
47+
inline size_t y() { return this_nd_item<3>().get_group(1); }
48+
inline size_t z() { return this_nd_item<3>().get_group(0); }
49+
} // namespace work_group_id
50+
51+
namespace work_group_range {
52+
inline size_t x() { return this_nd_item<3>().get_group_range(2); }
53+
inline size_t y() { return this_nd_item<3>().get_group_range(1); }
54+
inline size_t z() { return this_nd_item<3>().get_group_range(0); }
55+
} // namespace work_group_range
56+
57+
namespace global_range {
58+
inline size_t x() { return this_nd_item<3>().get_global_range(2); }
59+
inline size_t y() { return this_nd_item<3>().get_global_range(1); }
60+
inline size_t z() { return this_nd_item<3>().get_global_range(0); }
61+
} // namespace global_range
62+
63+
namespace global_id {
64+
inline size_t x() { return this_nd_item<3>().get_global_id(2); }
65+
inline size_t y() { return this_nd_item<3>().get_global_id(1); }
66+
inline size_t z() { return this_nd_item<3>().get_global_id(0); }
67+
} // namespace global_id
68+
69+
} // namespace syclcompat

sycl/include/syclcompat/launch.hpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
6+
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*
15+
* SYCL compatibility extension
16+
*
17+
* launch.hpp
18+
*
19+
* Description:
20+
* launch functionality for the SYCL compatibility extension
21+
**************************************************************************/
22+
23+
#pragma once
24+
25+
#include <sycl/accessor.hpp>
26+
#include <sycl/event.hpp>
27+
#include <sycl/nd_range.hpp>
28+
#include <sycl/queue.hpp>
29+
#include <sycl/range.hpp>
30+
#include <sycl/reduction.hpp>
31+
32+
#include <syclcompat/device.hpp>
33+
#include <syclcompat/dims.hpp>
34+
35+
namespace syclcompat {
36+
37+
namespace detail {
38+
39+
/// Returns the number of parameters of the function whose pointer is passed.
/// Only the pointer's type is inspected, so the parameter is left unnamed
/// (avoids an unused-parameter warning in user builds of this header).
/// @tparam R Return type of the function.
/// @tparam Types Parameter types of the function.
/// @return sizeof...(Types).
template <typename R, typename... Types>
constexpr size_t getArgumentCount(R (*)(Types...)) noexcept {
  return sizeof...(Types);
}
43+
44+
template <int Dim>
45+
sycl::nd_range<3> transform_nd_range(const sycl::nd_range<Dim> &range) {
46+
sycl::range<Dim> global_range = range.get_global_range();
47+
sycl::range<Dim> local_range = range.get_local_range();
48+
if constexpr (Dim == 3) {
49+
return range;
50+
} else if constexpr (Dim == 2) {
51+
return sycl::nd_range<3>{{1, global_range[0], global_range[1]},
52+
{1, local_range[0], local_range[1]}};
53+
}
54+
return sycl::nd_range<3>{{1, 1, global_range[0]}, {1, 1, local_range[0]}};
55+
}
56+
57+
template <auto F, typename... Args>
58+
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
59+
launch(const sycl::nd_range<3> &range, sycl::queue q, Args... args) {
60+
static_assert(detail::getArgumentCount(F) == sizeof...(args),
61+
"Wrong number of arguments to SYCL kernel");
62+
static_assert(
63+
std::is_same<std::invoke_result_t<decltype(F), Args...>, void>::value,
64+
"SYCL kernels should return void");
65+
66+
return q.parallel_for(range, [=](sycl::nd_item<3>) { F(args...); });
67+
}
68+
69+
template <auto F, typename... Args>
70+
sycl::event launch(const sycl::nd_range<3> &range, size_t mem_size,
71+
sycl::queue q, Args... args) {
72+
static_assert(detail::getArgumentCount(F) == sizeof...(args) + 1,
73+
"Wrong number of arguments to SYCL kernel");
74+
75+
using F_t = decltype(F);
76+
using f_return_t = typename std::invoke_result_t<F_t, Args..., char *>;
77+
static_assert(std::is_same<f_return_t, void>::value,
78+
"SYCL kernels should return void");
79+
80+
return q.submit([&](sycl::handler &cgh) {
81+
auto local_acc = sycl::local_accessor<char, 1>(mem_size, cgh);
82+
cgh.parallel_for(range, [=](sycl::nd_item<3>) {
83+
auto local_mem = local_acc.get_pointer();
84+
F(args..., local_mem);
85+
});
86+
});
87+
}
88+
89+
} // namespace detail
90+
91+
template <int Dim>
92+
sycl::nd_range<Dim> compute_nd_range(sycl::range<Dim> global_size_in,
93+
sycl::range<Dim> work_group_size) {
94+
95+
if (global_size_in.size() == 0 || work_group_size.size() == 0) {
96+
throw std::invalid_argument("Global or local size is zero!");
97+
}
98+
for (size_t i = 0; i < Dim; ++i) {
99+
if (global_size_in[i] < work_group_size[i])
100+
throw std::invalid_argument("Work group size larger than global size");
101+
}
102+
103+
auto global_size =
104+
((global_size_in + work_group_size - 1) / work_group_size) *
105+
work_group_size;
106+
return {global_size, work_group_size};
107+
}
108+
109+
/// One-dimensional convenience overload of compute_nd_range taking plain ints.
/// Negative values are rejected up front: converting them to the unsigned
/// extents used by sycl::range would silently turn them into huge sizes.
/// @param global_size_in Requested global size.
/// @param work_group_size Work-group (local) size.
/// @return Result of compute_nd_range<1> for the given sizes.
/// @throws std::invalid_argument if either size is negative, plus whatever
/// compute_nd_range<1> throws for the converted sizes.
inline sycl::nd_range<1> compute_nd_range(int global_size_in,
                                          int work_group_size) {
  if (global_size_in < 0 || work_group_size < 0) {
    throw std::invalid_argument("Global or local size is negative!");
  }
  return compute_nd_range<1>(static_cast<size_t>(global_size_in),
                             static_cast<size_t>(work_group_size));
}
113+
114+
template <auto F, int Dim, typename... Args>
115+
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
116+
launch(const sycl::nd_range<Dim> &range, sycl::queue q, Args... args) {
117+
return detail::launch<F>(detail::transform_nd_range<Dim>(range), q, args...);
118+
}
119+
120+
template <auto F, int Dim, typename... Args>
121+
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
122+
launch(const sycl::nd_range<Dim> &range, Args... args) {
123+
return launch<F>(range, get_default_queue(), args...);
124+
}
125+
126+
// Alternative launch through dim3 objects
127+
template <auto F, typename... Args>
128+
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
129+
launch(const dim3 &grid, const dim3 &threads, sycl::queue q, Args... args) {
130+
return launch<F>(sycl::nd_range<3>{grid * threads, threads}, q, args...);
131+
}
132+
133+
template <auto F, typename... Args>
134+
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
135+
launch(const dim3 &grid, const dim3 &threads, Args... args) {
136+
return launch<F>(grid, threads, get_default_queue(), args...);
137+
}
138+
139+
/// Launches a kernel with the templated F param and arguments on a
140+
/// device specified by the given nd_range and SYCL queue.
141+
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
142+
/// Args... args).
143+
/// @tparam Dim nd_range dimension number.
144+
/// @tparam Args Types of the arguments to be passed to the kernel.
145+
/// @param range Nd_range specifying the work group and global sizes for the
146+
/// kernel.
147+
/// @param q The SYCL queue on which to execute the kernel.
148+
/// @param mem_size The size, in number of bytes, of the local
149+
/// memory to be allocated for kernel.
150+
/// @param args The arguments to be passed to the kernel.
151+
/// @return A SYCL event object that can be used to synchronize with the
152+
/// kernel's execution.
153+
template <auto F, int Dim, typename... Args>
154+
sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
155+
sycl::queue q, Args... args) {
156+
return detail::launch<F>(detail::transform_nd_range<Dim>(range), mem_size, q,
157+
args...);
158+
}
159+
160+
/// Launches a kernel with the templated F param and arguments on a
161+
/// device specified by the given nd_range using theSYCL default queue.
162+
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
163+
/// Args... args).
164+
/// @tparam Dim nd_range dimension number.
165+
/// @tparam Args Types of the arguments to be passed to the kernel.
166+
/// @param range Nd_range specifying the work group and global sizes for the
167+
/// kernel.
168+
/// @param mem_size The size, in number of bytes, of the local
169+
/// memory to be allocated for kernel.
170+
/// @param args The arguments to be passed to the kernel.
171+
/// @return A SYCL event object that can be used to synchronize with the
172+
/// kernel's execution.
173+
template <auto F, int Dim, typename... Args>
174+
sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
175+
Args... args) {
176+
return launch<F>(range, mem_size, get_default_queue(), args...);
177+
}
178+
179+
/// Launches a kernel with the templated F param and arguments on a
180+
/// device with a user-specified grid and block dimensions following the
181+
/// standard of other programming models using a user-defined SYCL queue.
182+
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
183+
/// Args... args).
184+
/// @tparam Dim nd_range dimension number.
185+
/// @tparam Args Types of the arguments to be passed to the kernel.
186+
/// @param grid Grid dimensions represented with an (x, y, z) iteration space.
187+
/// @param threads Block dimensions represented with an (x, y, z) iteration
188+
/// space.
189+
/// @param mem_size The size, in number of bytes, of the local
190+
/// memory to be allocated for kernel.
191+
/// @param args The arguments to be passed to the kernel.
192+
/// @return A SYCL event object that can be used to synchronize with the
193+
/// kernel's execution.
194+
template <auto F, typename... Args>
195+
sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
196+
sycl::queue q, Args... args) {
197+
return launch<F>(sycl::nd_range<3>{grid * threads, threads}, mem_size, q,
198+
args...);
199+
}
200+
201+
/// Launches a kernel with the templated F param and arguments on a
202+
/// device with a user-specified grid and block dimensions following the
203+
/// standard of other programming models using the default SYCL queue.
204+
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
205+
/// Args... args).
206+
/// @tparam Dim nd_range dimension number.
207+
/// @tparam Args Types of the arguments to be passed to the kernel.
208+
/// @param grid Grid dimensions represented with an (x, y, z) iteration space.
209+
/// @param threads Block dimensions represented with an (x, y, z) iteration
210+
/// space.
211+
/// @param mem_size The size, in number of bytes, of the
212+
/// local memory to be allocated.
213+
/// @param args The arguments to be passed to the kernel.
214+
/// @return A SYCL event object that can be used to synchronize with the
215+
/// kernel's execution.
216+
template <auto F, typename... Args>
217+
sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
218+
Args... args) {
219+
return launch<F>(grid, threads, mem_size, get_default_queue(), args...);
220+
}
221+
222+
} // namespace syclcompat

sycl/include/syclcompat/memory.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <utility>
4343

4444
#include <sycl/builtins.hpp>
45+
#include <sycl/ext/oneapi/group_local_memory.hpp>
4546
#include <sycl/usm.hpp>
4647

4748
#include <syclcompat/device.hpp>
@@ -57,6 +58,14 @@
5758

5859
namespace syclcompat {
5960

61+
template <typename AllocT> auto *local_mem() {
62+
sycl::multi_ptr<AllocT, sycl::access::address_space::local_space>
63+
As_multi_ptr = sycl::ext::oneapi::group_local_memory<AllocT>(
64+
sycl::ext::oneapi::experimental::this_nd_item<3>().get_group());
65+
auto *As = *As_multi_ptr;
66+
return As;
67+
}
68+
6069
namespace detail {
6170
enum memcpy_direction {
6271
host_to_host,

sycl/include/syclcompat/syclcompat.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,7 @@
2525
#include <syclcompat/defs.hpp>
2626
#include <syclcompat/device.hpp>
2727
#include <syclcompat/dims.hpp>
28+
#include <syclcompat/id_query.hpp>
2829
#include <syclcompat/kernel.hpp>
30+
#include <syclcompat/launch.hpp>
2931
#include <syclcompat/memory.hpp>

sycl/test-e2e/syclcompat/common.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
#pragma once
2424

25+
#include <sycl/half_type.hpp>
26+
#include <tuple>
27+
2528
// Typed call helper
2629
// Iterates over all types and calls Functor f for each of them
2730
template <typename tuple, typename Functor>
@@ -34,3 +37,7 @@ void instantiate_all_types(Functor &&f) {
3437

3538
#define INSTANTIATE_ALL_TYPES(tuple, f) \
3639
instantiate_all_types<tuple>([]<typename T>() { f<T>(); });
40+
41+
// Tuple listing the arithmetic element types (signed/unsigned integers of
// several widths, float, double and sycl::half) shared by the tests.
// NOTE(review): presumably passed as the `tuple` argument of
// INSTANTIATE_ALL_TYPES; confirm against the test sources.
using value_type_list =
42+
std::tuple<int, unsigned int, short, unsigned short, long, unsigned long,
43+
long long, unsigned long long, float, double, sycl::half>;

0 commit comments

Comments
 (0)