Skip to content

Commit 06f765f

Browse files
[NFC][SYCL] Add device_images_range helper
1 parent c0929c3 commit 06f765f

File tree

3 files changed

+78
-63
lines changed

3 files changed

+78
-63
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,17 @@ class device_image_impl
13381338
std::unique_ptr<DynRTDeviceBinaryImage> MMergedImageStorage = nullptr;
13391339
};
13401340

1341+
using device_images_iterator =
1342+
variadic_iterator<device_image_plain,
1343+
std::vector<device_image_plain>::const_iterator>;
1344+
class device_images_range : public iterator_range<device_images_iterator> {
1345+
private:
1346+
using Base = iterator_range<device_images_iterator>;
1347+
1348+
public:
1349+
using Base::Base;
1350+
};
1351+
13411352
} // namespace detail
13421353
} // namespace _V1
13431354
} // namespace sycl

sycl/source/detail/helpers.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ template <typename SyclTy, typename... Iterators> class variadic_iterator {
9090
},
9191
It);
9292
}
93+
94+
pointer operator->() { return &this->operator*(); }
9395
};
9496

9597
// Non-owning!
@@ -175,6 +177,21 @@ template <typename iterator> class iterator_range {
175177
iterator End;
176178
const size_t Size;
177179
};
180+
181+
template <typename iterator, class Pred>
182+
bool all_of(iterator_range<iterator> R, Pred &&P) {
183+
return std::all_of(R.begin(), R.end(), std::forward<Pred>(P));
184+
}
185+
186+
template <typename iterator, class Pred>
187+
bool any_of(iterator_range<iterator> R, Pred &&P) {
188+
return std::any_of(R.begin(), R.end(), std::forward<Pred>(P));
189+
}
190+
191+
template <typename iterator, class Pred>
192+
bool none_of(iterator_range<iterator> R, Pred &&P) {
193+
return std::none_of(R.begin(), R.end(), std::forward<Pred>(P));
194+
}
178195
} // namespace detail
179196
} // namespace _V1
180197
} // namespace sycl

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 50 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,7 @@ class kernel_bundle_impl
504504

505505
if (get_bundle_state() == bundle_state::input) {
506506
// Copy spec constants values from the device images.
507-
auto MergeSpecConstants = [this](const device_image_plain &Img) {
508-
detail::device_image_impl &ImgImpl = *getSyclObjImpl(Img);
507+
for (detail::device_image_impl &ImgImpl : device_images()) {
509508
const std::map<std::string,
510509
std::vector<device_image_impl::SpecConstDescT>>
511510
&SpecConsts = ImgImpl.get_spec_const_data_ref();
@@ -521,8 +520,7 @@ class kernel_bundle_impl
521520
SpecConst.second.back().CompositeOffset +
522521
SpecConst.second.back().Size);
523522
}
524-
};
525-
std::for_each(begin(), end(), MergeSpecConstants);
523+
}
526524
}
527525

528526
for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
@@ -667,10 +665,9 @@ class kernel_bundle_impl
667665

668666
public:
669667
bool ext_oneapi_has_kernel(const std::string &Name) const {
670-
return std::any_of(begin(), end(),
671-
[&Name](const device_image_plain &DevImg) {
672-
return getSyclObjImpl(DevImg)->hasKernelName(Name);
673-
});
668+
return any_of(device_images(), [&Name](device_image_impl &DevImg) {
669+
return DevImg.hasKernelName(Name);
670+
});
674671
}
675672

676673
kernel ext_oneapi_get_kernel(const std::string &Name) const {
@@ -705,25 +702,24 @@ class kernel_bundle_impl
705702
"files and kernel_bundles successfully built from "
706703
"kernel_bundle<bundle_state::ext_oneapi_source>.");
707704

708-
auto It =
709-
std::find_if(begin(), end(), [&Name](const device_image_plain &DevImg) {
710-
return getSyclObjImpl(DevImg)->hasKernelName(Name);
711-
});
712-
if (It == end())
705+
auto It = std::find_if(device_images().begin(), device_images().end(),
706+
[&Name](device_image_impl &DevImg) {
707+
return DevImg.hasKernelName(Name);
708+
});
709+
if (It == device_images().end())
713710
throw sycl::exception(make_error_code(errc::invalid),
714711
"kernel '" + Name + "' not found in kernel_bundle");
715712

716-
return getSyclObjImpl(*It)->adjustKernelName(Name);
713+
return It->adjustKernelName(Name);
717714
}
718715

719716
bool ext_oneapi_has_device_global(const std::string &Name) const {
720717
std::string MangledName = mangleDeviceGlobalName(Name);
721718
return (MDeviceGlobals.size() &&
722719
MDeviceGlobals.tryGetEntryLockless(MangledName)) ||
723-
std::any_of(begin(), end(),
724-
[&MangledName](const device_image_plain &DeviceImage) {
725-
return getSyclObjImpl(DeviceImage)
726-
->hasDeviceGlobalName(MangledName);
720+
std::any_of(device_images().begin(), device_images().end(),
721+
[&MangledName](device_image_impl &DeviceImage) {
722+
return DeviceImage.hasDeviceGlobalName(MangledName);
727723
});
728724
}
729725

@@ -803,51 +799,43 @@ class kernel_bundle_impl
803799
}
804800

805801
bool has_kernel(const kernel_id &KernelID) const noexcept {
806-
return std::any_of(begin(), end(),
807-
[&KernelID](const device_image_plain &DeviceImage) {
808-
return DeviceImage.has_kernel(KernelID);
809-
});
802+
return any_of(device_images(), [&KernelID](device_image_impl &DeviceImage) {
803+
return DeviceImage.has_kernel(KernelID);
804+
});
810805
}
811806

812807
bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept {
813-
return std::any_of(
814-
begin(), end(),
815-
[&KernelID, &Dev](const device_image_plain &DeviceImage) {
816-
return DeviceImage.has_kernel(KernelID, Dev);
817-
});
808+
return any_of(device_images(),
809+
[&KernelID, &Dev](device_image_impl &DeviceImage) {
810+
return DeviceImage.has_kernel(KernelID, Dev);
811+
});
818812
}
819813

820814
bool contains_specialization_constants() const noexcept {
821-
return std::any_of(
822-
begin(), end(), [](const device_image_plain &DeviceImage) {
823-
return getSyclObjImpl(DeviceImage)->has_specialization_constants();
824-
});
815+
return any_of(device_images(), [](device_image_impl &DeviceImage) {
816+
return DeviceImage.has_specialization_constants();
817+
});
825818
}
826819

827820
bool native_specialization_constant() const noexcept {
828821
return contains_specialization_constants() &&
829-
std::all_of(begin(), end(),
830-
[](const device_image_plain &DeviceImage) {
831-
return getSyclObjImpl(DeviceImage)
832-
->all_specialization_constant_native();
833-
});
822+
all_of(device_images(), [](device_image_impl &DeviceImage) {
823+
return DeviceImage.all_specialization_constant_native();
824+
});
834825
}
835826

836827
bool has_specialization_constant(const char *SpecName) const noexcept {
837-
return std::any_of(begin(), end(),
838-
[SpecName](const device_image_plain &DeviceImage) {
839-
return getSyclObjImpl(DeviceImage)
840-
->has_specialization_constant(SpecName);
841-
});
828+
return any_of(device_images(), [SpecName](device_image_impl &DeviceImage) {
829+
return DeviceImage.has_specialization_constant(SpecName);
830+
});
842831
}
843832

844833
void set_specialization_constant_raw_value(const char *SpecName,
845834
const void *Value,
846835
size_t Size) noexcept {
847836
if (has_specialization_constant(SpecName))
848-
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
849-
getSyclObjImpl(DeviceImage)
850-
->set_specialization_constant_raw_value(SpecName, Value);
837+
for (device_image_impl &DeviceImage : device_images())
838+
DeviceImage.set_specialization_constant_raw_value(SpecName, Value);
851839
else {
852840
std::vector<unsigned char> &Val = MSpecConstValues[std::string{SpecName}];
853841
Val.resize(Size);
@@ -857,10 +845,9 @@ class kernel_bundle_impl
857845

858846
void get_specialization_constant_raw_value(const char *SpecName,
859847
void *ValueRet) const noexcept {
860-
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
861-
if (getSyclObjImpl(DeviceImage)->has_specialization_constant(SpecName)) {
862-
getSyclObjImpl(DeviceImage)
863-
->get_specialization_constant_raw_value(SpecName, ValueRet);
848+
for (device_image_impl &DeviceImage : device_images())
849+
if (DeviceImage.has_specialization_constant(SpecName)) {
850+
DeviceImage.get_specialization_constant_raw_value(SpecName, ValueRet);
864851
return;
865852
}
866853

@@ -879,19 +866,21 @@ class kernel_bundle_impl
879866
}
880867

881868
bool is_specialization_constant_set(const char *SpecName) const noexcept {
882-
bool SetInDevImg = std::any_of(
883-
begin(), end(), [SpecName](const device_image_plain &DeviceImage) {
884-
return getSyclObjImpl(DeviceImage)
885-
->is_specialization_constant_set(SpecName);
869+
bool SetInDevImg =
870+
any_of(device_images(), [SpecName](device_image_impl &DeviceImage) {
871+
return DeviceImage.is_specialization_constant_set(SpecName);
886872
});
887873
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
888874
}
889875

876+
// Don't use these two for code under `source/detail`, they are only needed to
877+
// communicate across DSO boundary.
890878
const device_image_plain *begin() const { return MUniqueDeviceImages.data(); }
891-
892879
const device_image_plain *end() const {
893880
return MUniqueDeviceImages.data() + MUniqueDeviceImages.size();
894881
}
882+
// ...use that instead.
883+
device_images_range device_images() const { return MUniqueDeviceImages; }
895884

896885
size_t size() const noexcept { return MUniqueDeviceImages.size(); }
897886

@@ -931,28 +920,26 @@ class kernel_bundle_impl
931920
}
932921

933922
bool hasSourceBasedImages() const noexcept {
934-
return std::any_of(begin(), end(), [](const device_image_plain &DevImg) {
935-
return getSyclObjImpl(DevImg)->getOriginMask() &
936-
ImageOriginKernelCompiler;
923+
return any_of(device_images(), [](device_image_impl &DevImg) {
924+
return DevImg.getOriginMask() & ImageOriginKernelCompiler;
937925
});
938926
}
939927

940928
bool hasSYCLBINImages() const noexcept {
941-
return std::any_of(begin(), end(), [](const device_image_plain &DevImg) {
942-
return getSyclObjImpl(DevImg)->getOriginMask() & ImageOriginSYCLBIN;
929+
return any_of(device_images(), [](device_image_impl &DevImg) {
930+
return DevImg.getOriginMask() & ImageOriginSYCLBIN;
943931
});
944932
}
945933

946934
bool hasSYCLOfflineImages() const noexcept {
947-
return std::any_of(begin(), end(), [](const device_image_plain &DevImg) {
948-
return getSyclObjImpl(DevImg)->getOriginMask() & ImageOriginSYCLOffline;
935+
return any_of(device_images(), [](device_image_impl &DevImg) {
936+
return DevImg.getOriginMask() & ImageOriginSYCLOffline;
949937
});
950938
}
951939

952940
bool allSourceBasedImages() const noexcept {
953-
return std::all_of(begin(), end(), [](const device_image_plain &DevImg) {
954-
return getSyclObjImpl(DevImg)->getOriginMask() &
955-
ImageOriginKernelCompiler;
941+
return all_of(device_images(), [](device_image_impl &DevImg) {
942+
return DevImg.getOriginMask() & ImageOriginKernelCompiler;
956943
});
957944
}
958945

0 commit comments

Comments
 (0)