Skip to content

Commit 146abac

Browse files
[NFC][SYCL] Refactor code around device_image_impl::KernelIDs
* Clarify the reason for `std::shared_ptr` (performance) * Clarify that `get_kernel_ids` is preferred over `get_kernel_ids_ptr` and update some of the existing uses of the latter to use the former * Change `get_kernel_ids` to return `iterator_range` instead of the reference to the underlying container. That way we can also return an empty range when `shared_ptr` isn't initialized, which was possible before but many ctors led to believe that is guaranteed to never happen. * Change those ctors to keep empty `shared_ptr` instead of creating empty `vector`s
1 parent 22ee417 commit 146abac

File tree

5 files changed

+54
-47
lines changed

5 files changed

+54
-47
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,7 @@ class device_image_impl
311311
private_tag)
312312
: MBinImage(BinImage), MContext(std::move(Context)),
313313
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
314-
MProgram(Program),
315-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
316-
MKernelNames{std::move(KernelNames)},
314+
MProgram(Program), MKernelNames{std::move(KernelNames)},
317315
MEliminatedKernelArgMasks{std::move(EliminatedKernelArgMasks)},
318316
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
319317
MOrigins(ImageOriginKernelCompiler),
@@ -347,7 +345,6 @@ class device_image_impl
347345
: MBinImage(Src), MContext(std::move(Context)),
348346
MDevices(Devices.to<std::vector<device_impl *>>()),
349347
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
350-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
351348
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
352349
MOrigins(ImageOriginKernelCompiler),
353350
MRTCBinInfo(
@@ -361,7 +358,6 @@ class device_image_impl
361358
: MBinImage(Bytes), MContext(std::move(Context)),
362359
MDevices(Devices.to<std::vector<device_impl *>>()),
363360
MState(bundle_state::ext_oneapi_source), MProgram(nullptr),
364-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
365361
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
366362
MOrigins(ImageOriginKernelCompiler),
367363
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {
@@ -375,9 +371,7 @@ class device_image_impl
375371
: MBinImage(static_cast<const RTDeviceBinaryImage *>(nullptr)),
376372
MContext(std::move(Context)),
377373
MDevices(Devices.to<std::vector<device_impl *>>()), MState(State),
378-
MProgram(Program),
379-
MKernelIDs(std::make_shared<std::vector<kernel_id>>()),
380-
MKernelNames{std::move(KernelNames)},
374+
MProgram(Program), MKernelNames{std::move(KernelNames)},
381375
MSpecConstsDefValBlob(getSpecConstsDefValBlob()),
382376
MOrigins(ImageOriginKernelCompiler),
383377
MRTCBinInfo(KernelCompilerBinaryInfo{Lang}) {}
@@ -389,6 +383,8 @@ class device_image_impl
389383
}
390384

391385
bool has_kernel(const kernel_id &KernelIDCand) const noexcept {
386+
if (!MKernelIDs)
387+
return false;
392388
return std::binary_search(MKernelIDs->begin(), MKernelIDs->end(),
393389
KernelIDCand, LessByHash<kernel_id>{});
394390
}
@@ -414,8 +410,18 @@ class device_image_impl
414410
return false;
415411
}
416412

417-
const std::vector<kernel_id> &get_kernel_ids() const noexcept {
418-
return *MKernelIDs;
413+
iterator_range<std::vector<kernel_id>::const_iterator>
414+
get_kernel_ids() const noexcept {
415+
if (MKernelIDs)
416+
return *MKernelIDs;
417+
else
418+
return {};
419+
}
420+
// This should only be used when creating new device_image_impls that have the
421+
// exact same set of kernels as the source one. In all other scenarios the
422+
// getter above is the one needed:
423+
std::shared_ptr<std::vector<kernel_id>> &get_kernel_ids_ptr() noexcept {
424+
return MKernelIDs;
419425
}
420426

421427
bool has_specialization_constants() const noexcept {
@@ -563,10 +569,6 @@ class device_image_impl
563569

564570
const context &get_context() const noexcept { return MContext; }
565571

566-
std::shared_ptr<std::vector<kernel_id>> &get_kernel_ids_ptr() noexcept {
567-
return MKernelIDs;
568-
}
569-
570572
std::vector<unsigned char> &get_spec_const_blob_ref() noexcept {
571573
return MSpecConstsBlob;
572574
}
@@ -1300,7 +1302,9 @@ class device_image_impl
13001302
ur_program_handle_t MProgram = nullptr;
13011303

13021304
// List of kernel ids available in this image, elements should be sorted
1303-
// according to LessByNameComp
1305+
// according to LessByNameComp. Shared between images for performance reasons
1306+
// (e.g. when we compile a single image it keeps the same kernels in it as the
1307+
// original source image).
13041308
std::shared_ptr<std::vector<kernel_id>> MKernelIDs;
13051309

13061310
// List of known kernel names.

sycl/source/detail/helpers.hpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class variadic_iterator {
5151
using pointer = value_type *;
5252
static_assert(std::is_same_v<reference, value_type &>);
5353

54+
variadic_iterator() = default;
5455
variadic_iterator(const variadic_iterator &) = default;
5556
variadic_iterator(variadic_iterator &&) = default;
5657
variadic_iterator(variadic_iterator &) = default;
@@ -88,7 +89,6 @@ class variadic_iterator {
8889
// Non-owning!
8990
template <typename iterator> class iterator_range {
9091
using value_type = typename iterator::value_type;
91-
using sycl_type = typename iterator::sycl_type;
9292

9393
template <typename Container, typename = void>
9494
struct has_reserve : public std::false_type {};
@@ -104,16 +104,20 @@ template <typename iterator> class iterator_range {
104104
iterator_range(IterTy Begin, IterTy End, size_t Size)
105105
: Begin(Begin), End(End), Size(Size) {}
106106

107-
iterator_range()
108-
: iterator_range(static_cast<value_type *>(nullptr),
109-
static_cast<value_type *>(nullptr), 0) {}
107+
iterator_range() : iterator_range(iterator{}, iterator{}, 0) {}
110108

111-
template <typename ContainerTy>
109+
template <typename ContainerTy, typename = std::void_t<decltype(iterator{
110+
std::declval<ContainerTy>().begin()})>>
112111
iterator_range(const ContainerTy &Container)
113112
: iterator_range(Container.begin(), Container.end(), Container.size()) {}
114113

115114
iterator_range(value_type &Obj) : iterator_range(&Obj, &Obj + 1, 1) {}
116115

116+
template <typename sycl_type,
117+
typename = std::void_t<decltype(iterator{
118+
&*getSyclObjImpl(std::declval<sycl_type>())})>,
119+
// To make it different from `ContainerTy` overload above:
120+
typename = void>
117121
iterator_range(const sycl_type &Obj)
118122
: iterator_range(&*getSyclObjImpl(Obj), (&*getSyclObjImpl(Obj) + 1), 1) {}
119123

@@ -123,13 +127,15 @@ template <typename iterator> class iterator_range {
123127
bool empty() const { return Size == 0; }
124128
decltype(auto) front() const { return *begin(); }
125129

126-
template <typename Container>
127-
std::enable_if_t<
128-
check_type_in_v<Container, std::vector<sycl_type>,
129-
std::queue<value_type *>, std::vector<value_type *>,
130-
std::vector<std::shared_ptr<value_type>>>,
131-
Container>
132-
to() const {
130+
// Only enable for ranges of `variadic_iterator` and for the containers with
131+
// proper `value_type`. The last part is important so that descendent
132+
// `devices_range` could provide its own specialization for
133+
// `to<std::vector<device_handle_t>>()`.
134+
template <typename Container, typename iterator_ = iterator,
135+
typename = std::enable_if_t<check_type_in_v<
136+
typename Container::value_type, value_type *,
137+
std::shared_ptr<value_type>, typename iterator_::sycl_type>>>
138+
Container to() const {
133139
std::conditional_t<std::is_same_v<Container, std::queue<value_type *>>,
134140
typename std::queue<value_type *>::container_type,
135141
Container>
@@ -138,31 +144,30 @@ template <typename iterator> class iterator_range {
138144
Result.reserve(size());
139145
std::transform(
140146
begin(), end(), std::back_inserter(Result), [](value_type &E) {
141-
if constexpr (std::is_same_v<Container, std::vector<sycl_type>>)
142-
return createSyclObjFromImpl<sycl_type>(E);
143-
else if constexpr (std::is_same_v<
144-
Container,
145-
std::vector<std::shared_ptr<value_type>>>)
147+
using container_value_type = typename Container::value_type;
148+
if constexpr (std::is_same_v<container_value_type,
149+
std::shared_ptr<value_type>>)
146150
return E.shared_from_this();
147-
else
151+
else if constexpr (std::is_same_v<container_value_type, value_type *>)
148152
return &E;
153+
else
154+
return createSyclObjFromImpl<container_value_type>(E);
149155
});
150156
if constexpr (std::is_same_v<Container, decltype(Result)>)
151157
return Result;
152158
else
153159
return Container{std::move(Result)};
154160
}
155161

162+
// Only enable for ranges of `variadic_iterator` above.
163+
template <typename T = iterator,
164+
typename = std::void_t<typename T::sycl_type>>
156165
bool contains(value_type &Other) const {
157166
return std::find_if(begin(), end(), [&Other](value_type &Elem) {
158167
return &Elem == &Other;
159168
}) != end();
160169
}
161170

162-
protected:
163-
template <typename Container>
164-
static constexpr bool has_reserve_v = has_reserve<Container>::value;
165-
166171
private:
167172
iterator Begin;
168173
iterator End;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ class kernel_bundle_impl
782782
if (DevImgImpl.getRTCInfo())
783783
continue;
784784

785-
const std::vector<kernel_id> &KernelIDs = DevImgImpl.get_kernel_ids();
785+
auto KernelIDs = DevImgImpl.get_kernel_ids();
786786

787787
Result.insert(Result.end(), KernelIDs.begin(), KernelIDs.end());
788788
}

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2690,9 +2690,9 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,
26902690

26912691
assert(DepState == getBinImageState(DepImage) &&
26922692
"State mismatch between main image and its dependency");
2693-
DeviceImageImplPtr DepImpl =
2694-
device_image_impl::create(DepImage, Ctx, Devs, DepState, DepKernelIDs,
2695-
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);
2693+
DeviceImageImplPtr DepImpl = device_image_impl::create(
2694+
DepImage, Ctx, Devs, DepState, std::move(DepKernelIDs),
2695+
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);
26962696

26972697
return createSyclObjFromImpl<device_image_plain>(std::move(DepImpl));
26982698
}
@@ -2905,10 +2905,8 @@ mergeImageData(const std::vector<device_image_plain> &Imgs,
29052905
for (const device_image_plain &Img : Imgs) {
29062906
device_image_impl &DeviceImageImpl = *getSyclObjImpl(Img);
29072907
// Duplicates are not expected here, otherwise urProgramLink should fail
2908-
if (DeviceImageImpl.get_kernel_ids_ptr())
2909-
KernelIDs.insert(KernelIDs.end(),
2910-
DeviceImageImpl.get_kernel_ids_ptr()->begin(),
2911-
DeviceImageImpl.get_kernel_ids_ptr()->end());
2908+
KernelIDs.insert(KernelIDs.end(), DeviceImageImpl.get_kernel_ids().begin(),
2909+
DeviceImageImpl.get_kernel_ids().end());
29122910
// To be able to answer queries about specialziation constants, the new
29132911
// device image should have the specialization constants from all the linked
29142912
// images.

sycl/source/kernel_bundle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
288288
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
289289
getSyclObjImpl(DeviceImage);
290290

291-
CombinedKernelIDs.insert(DeviceImageImpl->get_kernel_ids_ptr()->begin(),
292-
DeviceImageImpl->get_kernel_ids_ptr()->end());
291+
CombinedKernelIDs.insert(DeviceImageImpl->get_kernel_ids().begin(),
292+
DeviceImageImpl->get_kernel_ids().end());
293293
}
294294
}
295295

0 commit comments

Comments
 (0)