Skip to content

Commit c69c3ba

Browse files
[NFC][SYCL] Refactor code around device_image_impl::KernelIDs
* Clarify the reason for using `std::shared_ptr` (performance).
* Clarify that `get_kernel_ids` is preferred over `get_kernel_ids_ptr`, and update some of the existing uses of the latter to use the former.
* Change `get_kernel_ids` to return an `iterator_range` instead of a reference to the underlying container. That way we can also return an empty range when the `shared_ptr` isn't initialized, which was already possible before, even though many ctors gave the impression that it was guaranteed never to happen.
* Change those ctors to keep an empty `shared_ptr` instead of creating empty `vector`s.
1 parent 22ee417 commit c69c3ba

File tree

5 files changed

+59
-47
lines changed

5 files changed

+59
-47
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ class device_image_impl
312312
: MBinImage(BinImage), MContext(std::move(Context)),
313313
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
314314
MProgram(Program),
315-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
316315
MKernelNames{std::move(KernelNames)},
317316
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
318317
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
@@ -347,7 +346,6 @@ class device_image_impl
347346
: MBinImage(Src), MContext(std::move(Context)),
348347
MDevices(Devices.to<std::vector<device_impl *>>()),
349348
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
350-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
351349
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
352350
MOrigins(ImageOriginKernelCompiler),
353351
MRTCBinInfo(
@@ -361,7 +359,6 @@ class device_image_impl
361359
: MBinImage(Bytes), MContext(std::move(Context)),
362360
MDevices(Devices.to<std::vector<device_impl *>>()),
363361
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
364-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
365362
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
366363
MOrigins(ImageOriginKernelCompiler),
367364
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {
@@ -376,7 +373,6 @@ class device_image_impl
376373
MContext(std::move(Context)),
377374
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
378375
MProgram(Program),
379-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
380376
MKernelNames{std::move(KernelNames)},
381377
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
382378
MOrigins(ImageOriginKernelCompiler),
@@ -389,6 +385,8 @@ class device_image_impl
389385
}
390386

391387
bool has_kernel(const kernel_id &KernelIDCand) const noexcept {
388+
if (!MKernelIDs)
389+
return false;
392390
return std::binary_search(MKernelIDs->begin(), MKernelIDs->end(),
393391
KernelIDCand, LessByHash<kernel_id>{});
394392
}
@@ -414,8 +412,18 @@ class device_image_impl
414412
return false;
415413
}
416414

417-
const std::vector<kernel_id> &get_kernel_ids() const noexcept {
418-
return *MKernelIDs;
415+
iterator_range<std::vector<kernel_id>::const_iterator>
416+
get_kernel_ids() const noexcept {
417+
if (MKernelIDs)
418+
return *MKernelIDs;
419+
else
420+
return {};
421+
}
422+
// This should only be used when creating new device_image_impls that have the
423+
// exact same set of kernels as the source one. In all other scenarios the
424+
// getter above is the one needed:
425+
std::shared_ptr<std::vector<kernel_id>> &get_kernel_ids_ptr() noexcept {
426+
return MKernelIDs;
419427
}
420428

421429
bool has_specialization_constants() const noexcept {
@@ -563,10 +571,6 @@ class device_image_impl
563571

564572
const context &get_context() const noexcept { return MContext; }
565573

566-
std::shared_ptr<std::vector<kernel_id>> &get_kernel_ids_ptr() noexcept {
567-
return MKernelIDs;
568-
}
569-
570574
std::vector<unsigned char> &get_spec_const_blob_ref() noexcept {
571575
return MSpecConstsBlob;
572576
}
@@ -1300,7 +1304,9 @@ class device_image_impl
13001304
ur_program_handle_t MProgram = nullptr;
13011305

13021306
// List of kernel ids available in this image, elements should be sorted
1303-
// according to LessByNameComp
1307+
// according to LessByNameComp. Shared between images for performance reasons
1308+
// (e.g. when we compile a single image it keeps the same kernels in it as the
1309+
// original source image).
13041310
std::shared_ptr<std::vector<kernel_id>> MKernelIDs;
13051311

13061312
// List of known kernel names.

sycl/source/detail/helpers.hpp

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class variadic_iterator {
5151
using pointer = value_type *;
5252
static_assert(std::is_same_v<reference, value_type &>);
5353

54+
variadic_iterator() = default;
5455
variadic_iterator(const variadic_iterator &) = default;
5556
variadic_iterator(variadic_iterator &&) = default;
5657
variadic_iterator(variadic_iterator &) = default;
@@ -88,7 +89,6 @@ class variadic_iterator {
8889
// Non-owning!
8990
template <typename iterator> class iterator_range {
9091
using value_type = typename iterator::value_type;
91-
using sycl_type = typename iterator::sycl_type;
9292

9393
template <typename Container, typename = void>
9494
struct has_reserve : public std::false_type {};
@@ -104,16 +104,20 @@ template <typename iterator> class iterator_range {
104104
iterator_range(IterTy Begin, IterTy End, size_t Size)
105105
: Begin(Begin), End(End), Size(Size) {}
106106

107-
iterator_range()
108-
: iterator_range(static_cast<value_type *>(nullptr),
109-
static_cast<value_type *>(nullptr), 0) {}
107+
iterator_range() : iterator_range(iterator{}, iterator{}, 0) {}
110108

111-
template <typename ContainerTy>
109+
template <typename ContainerTy, typename = std::void_t<decltype(iterator{
110+
std::declval<ContainerTy>().begin()})>>
112111
iterator_range(const ContainerTy &Container)
113112
: iterator_range(Container.begin(), Container.end(), Container.size()) {}
114113

115114
iterator_range(value_type &Obj) : iterator_range(&Obj, &Obj + 1, 1) {}
116115

116+
template <typename sycl_type,
117+
typename = std::void_t<decltype(iterator{
118+
&*getSyclObjImpl(std::declval<sycl_type>())})>,
119+
// To make it different from `ContainerTy` overload above:
120+
typename = void>
117121
iterator_range(const sycl_type &Obj)
118122
: iterator_range(&*getSyclObjImpl(Obj), (&*getSyclObjImpl(Obj) + 1), 1) {}
119123

@@ -123,13 +127,27 @@ template <typename iterator> class iterator_range {
123127
bool empty() const { return Size == 0; }
124128
decltype(auto) front() const { return *begin(); }
125129

126-
template <typename Container>
127-
std::enable_if_t<
128-
check_type_in_v<Container, std::vector<sycl_type>,
129-
std::queue<value_type *>, std::vector<value_type *>,
130-
std::vector<std::shared_ptr<value_type>>>,
131-
Container>
132-
to() const {
130+
bool contains(value_type &Other) const {
131+
return std::find_if(begin(), end(), [&Other](value_type &Elem) {
132+
return &Elem == &Other;
133+
}) != end();
134+
}
135+
136+
private:
137+
template <typename T, typename = void>
138+
struct allowed_elem_type
139+
: public std::bool_constant<
140+
check_type_in_v<T, value_type *, std::shared_ptr<value_type>>> {};
141+
142+
template <typename T>
143+
struct allowed_elem_type<T, std::void_t<decltype(createSyclObjFromImpl<T>(
144+
std::declval<value_type &>()))>>
145+
: public std::true_type {};
146+
147+
public:
148+
template <typename Container, typename = std::enable_if_t<allowed_elem_type<
149+
typename Container::value_type>::value>>
150+
Container to() const {
133151
std::conditional_t<std::is_same_v<Container, std::queue<value_type *>>,
134152
typename std::queue<value_type *>::container_type,
135153
Container>
@@ -138,31 +156,21 @@ template <typename iterator> class iterator_range {
138156
Result.reserve(size());
139157
std::transform(
140158
begin(), end(), std::back_inserter(Result), [](value_type &E) {
141-
if constexpr (std::is_same_v<Container, std::vector<sycl_type>>)
142-
return createSyclObjFromImpl<sycl_type>(E);
143-
else if constexpr (std::is_same_v<
144-
Container,
145-
std::vector<std::shared_ptr<value_type>>>)
159+
using container_value_type = typename Container::value_type;
160+
if constexpr (std::is_same_v<container_value_type,
161+
std::shared_ptr<value_type>>)
146162
return E.shared_from_this();
147-
else
163+
else if constexpr (std::is_same_v<container_value_type, value_type *>)
148164
return &E;
165+
else
166+
return createSyclObjFromImpl<container_value_type>(E);
149167
});
150168
if constexpr (std::is_same_v<Container, decltype(Result)>)
151169
return Result;
152170
else
153171
return Container{std::move(Result)};
154172
}
155173

156-
bool contains(value_type &Other) const {
157-
return std::find_if(begin(), end(), [&Other](value_type &Elem) {
158-
return &Elem == &Other;
159-
}) != end();
160-
}
161-
162-
protected:
163-
template <typename Container>
164-
static constexpr bool has_reserve_v = has_reserve<Container>::value;
165-
166174
private:
167175
iterator Begin;
168176
iterator End;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ class kernel_bundle_impl
782782
if (DevImgImpl.getRTCInfo())
783783
continue;
784784

785-
const std::vector<kernel_id> &KernelIDs = DevImgImpl.get_kernel_ids();
785+
auto KernelIDs = DevImgImpl.get_kernel_ids();
786786

787787
Result.insert(Result.end(), KernelIDs.begin(), KernelIDs.end());
788788
}

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,7 +2691,7 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,
26912691
assert(DepState == getBinImageState(DepImage) &&
26922692
"State mismatch between main image and its dependency");
26932693
DeviceImageImplPtr DepImpl =
2694-
device_image_impl::create(DepImage, Ctx, Devs, DepState, DepKernelIDs,
2694+
device_image_impl::create(DepImage, Ctx, Devs, DepState, std::move(DepKernelIDs),
26952695
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);
26962696

26972697
return createSyclObjFromImpl<device_image_plain>(std::move(DepImpl));
@@ -2905,10 +2905,8 @@ mergeImageData(const std::vector<device_image_plain> &Imgs,
29052905
for (const device_image_plain &Img : Imgs) {
29062906
device_image_impl &DeviceImageImpl = *getSyclObjImpl(Img);
29072907
// Duplicates are not expected here, otherwise urProgramLink should fail
2908-
if (DeviceImageImpl.get_kernel_ids_ptr())
2909-
KernelIDs.insert(KernelIDs.end(),
2910-
DeviceImageImpl.get_kernel_ids_ptr()->begin(),
2911-
DeviceImageImpl.get_kernel_ids_ptr()->end());
2908+
KernelIDs.insert(KernelIDs.end(), DeviceImageImpl.get_kernel_ids().begin(),
2909+
DeviceImageImpl.get_kernel_ids().end());
29122910
// To be able to answer queries about specialization constants, the new
29132911
// device image should have the specialization constants from all the linked
29142912
// images.

sycl/source/kernel_bundle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
288288
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
289289
getSyclObjImpl(DeviceImage);
290290

291-
CombinedKernelIDs.insert(DeviceImageImpl->get_kernel_ids_ptr()->begin(),
292-
DeviceImageImpl->get_kernel_ids_ptr()->end());
291+
CombinedKernelIDs.insert(DeviceImageImpl->get_kernel_ids().begin(),
292+
DeviceImageImpl->get_kernel_ids().end());
293293
}
294294
}
295295

0 commit comments

Comments
 (0)