|
| 1 | +/*************************************************************************** |
| 2 | + * |
| 3 | + * Copyright (C) Codeplay Software Ltd. |
| 4 | + * |
| 5 | + * Part of the LLVM Project, under the Apache License v2.0 with LLVM |
| 6 | + * Exceptions. See https://llvm.org/LICENSE.txt for license information. |
| 7 | + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 8 | + * |
| 9 | + * Unless required by applicable law or agreed to in writing, software |
| 10 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | + * See the License for the specific language governing permissions and |
| 13 | + * limitations under the License. |
| 14 | + * |
| 15 | + * SYCL compatibility extension |
| 16 | + * |
| 17 | + * launch.hpp |
| 18 | + * |
| 19 | + * Description: |
| 20 | + * launch functionality for the SYCL compatibility extension |
| 21 | + **************************************************************************/ |
| 22 | + |
| 23 | +#pragma once |
| 24 | + |
| 25 | +#include <sycl/accessor.hpp> |
| 26 | +#include <sycl/event.hpp> |
| 27 | +#include <sycl/nd_range.hpp> |
| 28 | +#include <sycl/queue.hpp> |
| 29 | +#include <sycl/range.hpp> |
| 30 | +#include <sycl/reduction.hpp> |
| 31 | + |
| 32 | +#include <syclcompat/device.hpp> |
| 33 | +#include <syclcompat/dims.hpp> |
| 34 | + |
| 35 | +namespace syclcompat { |
| 36 | + |
| 37 | +namespace detail { |
| 38 | + |
/// Counts the parameters of a free function.
/// @tparam R Return type of the inspected function.
/// @tparam Types Parameter types of the inspected function.
/// @return The number of parameters declared by the function whose pointer
/// is passed in. Only the pointer's type is inspected; the function is never
/// called, so the parameter is left unnamed.
template <typename R, typename... Types>
constexpr size_t getArgumentCount(R (*)(Types...)) {
  return sizeof...(Types);
}
| 43 | + |
| 44 | +template <int Dim> |
| 45 | +sycl::nd_range<3> transform_nd_range(const sycl::nd_range<Dim> &range) { |
| 46 | + sycl::range<Dim> global_range = range.get_global_range(); |
| 47 | + sycl::range<Dim> local_range = range.get_local_range(); |
| 48 | + if constexpr (Dim == 3) { |
| 49 | + return range; |
| 50 | + } else if constexpr (Dim == 2) { |
| 51 | + return sycl::nd_range<3>{{1, global_range[0], global_range[1]}, |
| 52 | + {1, local_range[0], local_range[1]}}; |
| 53 | + } |
| 54 | + return sycl::nd_range<3>{{1, 1, global_range[0]}, {1, 1, local_range[0]}}; |
| 55 | +} |
| 56 | + |
| 57 | +template <auto F, typename... Args> |
| 58 | +std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event> |
| 59 | +launch(const sycl::nd_range<3> &range, sycl::queue q, Args... args) { |
| 60 | + static_assert(detail::getArgumentCount(F) == sizeof...(args), |
| 61 | + "Wrong number of arguments to SYCL kernel"); |
| 62 | + static_assert( |
| 63 | + std::is_same<std::invoke_result_t<decltype(F), Args...>, void>::value, |
| 64 | + "SYCL kernels should return void"); |
| 65 | + |
| 66 | + return q.parallel_for(range, [=](sycl::nd_item<3>) { F(args...); }); |
| 67 | +} |
| 68 | + |
| 69 | +template <auto F, typename... Args> |
| 70 | +sycl::event launch(const sycl::nd_range<3> &range, size_t mem_size, |
| 71 | + sycl::queue q, Args... args) { |
| 72 | + static_assert(detail::getArgumentCount(F) == sizeof...(args) + 1, |
| 73 | + "Wrong number of arguments to SYCL kernel"); |
| 74 | + |
| 75 | + using F_t = decltype(F); |
| 76 | + using f_return_t = typename std::invoke_result_t<F_t, Args..., char *>; |
| 77 | + static_assert(std::is_same<f_return_t, void>::value, |
| 78 | + "SYCL kernels should return void"); |
| 79 | + |
| 80 | + return q.submit([&](sycl::handler &cgh) { |
| 81 | + auto local_acc = sycl::local_accessor<char, 1>(mem_size, cgh); |
| 82 | + cgh.parallel_for(range, [=](sycl::nd_item<3>) { |
| 83 | + auto local_mem = local_acc.get_pointer(); |
| 84 | + F(args..., local_mem); |
| 85 | + }); |
| 86 | + }); |
| 87 | +} |
| 88 | + |
| 89 | +} // namespace detail |
| 90 | + |
| 91 | +template <int Dim> |
| 92 | +sycl::nd_range<Dim> compute_nd_range(sycl::range<Dim> global_size_in, |
| 93 | + sycl::range<Dim> work_group_size) { |
| 94 | + |
| 95 | + if (global_size_in.size() == 0 || work_group_size.size() == 0) { |
| 96 | + throw std::invalid_argument("Global or local size is zero!"); |
| 97 | + } |
| 98 | + for (size_t i = 0; i < Dim; ++i) { |
| 99 | + if (global_size_in[i] < work_group_size[i]) |
| 100 | + throw std::invalid_argument("Work group size larger than global size"); |
| 101 | + } |
| 102 | + |
| 103 | + auto global_size = |
| 104 | + ((global_size_in + work_group_size - 1) / work_group_size) * |
| 105 | + work_group_size; |
| 106 | + return {global_size, work_group_size}; |
| 107 | +} |
| 108 | + |
| 109 | +inline sycl::nd_range<1> compute_nd_range(int global_size_in, |
| 110 | + int work_group_size) { |
| 111 | + return compute_nd_range<1>(global_size_in, work_group_size); |
| 112 | +} |
| 113 | + |
| 114 | +template <auto F, int Dim, typename... Args> |
| 115 | +std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event> |
| 116 | +launch(const sycl::nd_range<Dim> &range, sycl::queue q, Args... args) { |
| 117 | + return detail::launch<F>(detail::transform_nd_range<Dim>(range), q, args...); |
| 118 | +} |
| 119 | + |
| 120 | +template <auto F, int Dim, typename... Args> |
| 121 | +std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event> |
| 122 | +launch(const sycl::nd_range<Dim> &range, Args... args) { |
| 123 | + return launch<F>(range, get_default_queue(), args...); |
| 124 | +} |
| 125 | + |
| 126 | +// Alternative launch through dim3 objects |
| 127 | +template <auto F, typename... Args> |
| 128 | +std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event> |
| 129 | +launch(const dim3 &grid, const dim3 &threads, sycl::queue q, Args... args) { |
| 130 | + return launch<F>(sycl::nd_range<3>{grid * threads, threads}, q, args...); |
| 131 | +} |
| 132 | + |
| 133 | +template <auto F, typename... Args> |
| 134 | +std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event> |
| 135 | +launch(const dim3 &grid, const dim3 &threads, Args... args) { |
| 136 | + return launch<F>(grid, threads, get_default_queue(), args...); |
| 137 | +} |
| 138 | + |
| 139 | +/// Launches a kernel with the templated F param and arguments on a |
| 140 | +/// device specified by the given nd_range and SYCL queue. |
| 141 | +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, |
| 142 | +/// Args... args). |
| 143 | +/// @tparam Dim nd_range dimension number. |
| 144 | +/// @tparam Args Types of the arguments to be passed to the kernel. |
| 145 | +/// @param range Nd_range specifying the work group and global sizes for the |
| 146 | +/// kernel. |
| 147 | +/// @param q The SYCL queue on which to execute the kernel. |
| 148 | +/// @param mem_size The size, in number of bytes, of the local |
| 149 | +/// memory to be allocated for kernel. |
| 150 | +/// @param args The arguments to be passed to the kernel. |
| 151 | +/// @return A SYCL event object that can be used to synchronize with the |
| 152 | +/// kernel's execution. |
| 153 | +template <auto F, int Dim, typename... Args> |
| 154 | +sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size, |
| 155 | + sycl::queue q, Args... args) { |
| 156 | + return detail::launch<F>(detail::transform_nd_range<Dim>(range), mem_size, q, |
| 157 | + args...); |
| 158 | +} |
| 159 | + |
| 160 | +/// Launches a kernel with the templated F param and arguments on a |
| 161 | +/// device specified by the given nd_range using theSYCL default queue. |
| 162 | +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, |
| 163 | +/// Args... args). |
| 164 | +/// @tparam Dim nd_range dimension number. |
| 165 | +/// @tparam Args Types of the arguments to be passed to the kernel. |
| 166 | +/// @param range Nd_range specifying the work group and global sizes for the |
| 167 | +/// kernel. |
| 168 | +/// @param mem_size The size, in number of bytes, of the local |
| 169 | +/// memory to be allocated for kernel. |
| 170 | +/// @param args The arguments to be passed to the kernel. |
| 171 | +/// @return A SYCL event object that can be used to synchronize with the |
| 172 | +/// kernel's execution. |
| 173 | +template <auto F, int Dim, typename... Args> |
| 174 | +sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size, |
| 175 | + Args... args) { |
| 176 | + return launch<F>(range, mem_size, get_default_queue(), args...); |
| 177 | +} |
| 178 | + |
| 179 | +/// Launches a kernel with the templated F param and arguments on a |
| 180 | +/// device with a user-specified grid and block dimensions following the |
| 181 | +/// standard of other programming models using a user-defined SYCL queue. |
| 182 | +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, |
| 183 | +/// Args... args). |
| 184 | +/// @tparam Dim nd_range dimension number. |
| 185 | +/// @tparam Args Types of the arguments to be passed to the kernel. |
| 186 | +/// @param grid Grid dimensions represented with an (x, y, z) iteration space. |
| 187 | +/// @param threads Block dimensions represented with an (x, y, z) iteration |
| 188 | +/// space. |
| 189 | +/// @param mem_size The size, in number of bytes, of the local |
| 190 | +/// memory to be allocated for kernel. |
| 191 | +/// @param args The arguments to be passed to the kernel. |
| 192 | +/// @return A SYCL event object that can be used to synchronize with the |
| 193 | +/// kernel's execution. |
| 194 | +template <auto F, typename... Args> |
| 195 | +sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size, |
| 196 | + sycl::queue q, Args... args) { |
| 197 | + return launch<F>(sycl::nd_range<3>{grid * threads, threads}, mem_size, q, |
| 198 | + args...); |
| 199 | +} |
| 200 | + |
| 201 | +/// Launches a kernel with the templated F param and arguments on a |
| 202 | +/// device with a user-specified grid and block dimensions following the |
| 203 | +/// standard of other programming models using the default SYCL queue. |
| 204 | +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, |
| 205 | +/// Args... args). |
| 206 | +/// @tparam Dim nd_range dimension number. |
| 207 | +/// @tparam Args Types of the arguments to be passed to the kernel. |
| 208 | +/// @param grid Grid dimensions represented with an (x, y, z) iteration space. |
| 209 | +/// @param threads Block dimensions represented with an (x, y, z) iteration |
| 210 | +/// space. |
| 211 | +/// @param mem_size The size, in number of bytes, of the |
| 212 | +/// local memory to be allocated. |
| 213 | +/// @param args The arguments to be passed to the kernel. |
| 214 | +/// @return A SYCL event object that can be used to synchronize with the |
| 215 | +/// kernel's execution. |
| 216 | +template <auto F, typename... Args> |
| 217 | +sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size, |
| 218 | + Args... args) { |
| 219 | + return launch<F>(grid, threads, mem_size, get_default_queue(), args...); |
| 220 | +} |
| 221 | + |
| 222 | +} // namespace syclcompat |
0 commit comments