@@ -504,8 +504,7 @@ class kernel_bundle_impl
504
504
505
505
if (get_bundle_state () == bundle_state::input) {
506
506
// Copy spec constants values from the device images.
507
- auto MergeSpecConstants = [this ](const device_image_plain &Img) {
508
- detail::device_image_impl &ImgImpl = *getSyclObjImpl (Img);
507
+ for (detail::device_image_impl &ImgImpl : device_images ()) {
509
508
const std::map<std::string,
510
509
std::vector<device_image_impl::SpecConstDescT>>
511
510
&SpecConsts = ImgImpl.get_spec_const_data_ref ();
@@ -521,8 +520,7 @@ class kernel_bundle_impl
521
520
SpecConst.second .back ().CompositeOffset +
522
521
SpecConst.second .back ().Size );
523
522
}
524
- };
525
- std::for_each (begin (), end (), MergeSpecConstants);
523
+ }
526
524
}
527
525
528
526
for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
@@ -667,10 +665,9 @@ class kernel_bundle_impl
667
665
668
666
public:
669
667
bool ext_oneapi_has_kernel (const std::string &Name) const {
670
- return std::any_of (begin (), end (),
671
- [&Name](const device_image_plain &DevImg) {
672
- return getSyclObjImpl (DevImg)->hasKernelName (Name);
673
- });
668
+ return any_of (device_images (), [&Name](device_image_impl &DevImg) {
669
+ return DevImg.hasKernelName (Name);
670
+ });
674
671
}
675
672
676
673
kernel ext_oneapi_get_kernel (const std::string &Name) const {
@@ -705,25 +702,24 @@ class kernel_bundle_impl
705
702
" files and kernel_bundles successfully built from "
706
703
" kernel_bundle<bundle_state::ext_oneapi_source>." );
707
704
708
- auto It =
709
- std::find_if ( begin (), end (), [&Name](const device_image_plain &DevImg) {
710
- return getSyclObjImpl ( DevImg)-> hasKernelName (Name);
711
- });
712
- if (It == end ())
705
+ auto It = std::find_if ( device_images (). begin (), device_images (). end (),
706
+ [&Name](device_image_impl &DevImg) {
707
+ return DevImg. hasKernelName (Name);
708
+ });
709
+ if (It == device_images (). end ())
713
710
throw sycl::exception (make_error_code (errc::invalid),
714
711
" kernel '" + Name + " ' not found in kernel_bundle" );
715
712
716
- return getSyclObjImpl (*It) ->adjustKernelName (Name);
713
+ return It ->adjustKernelName (Name);
717
714
}
718
715
719
716
bool ext_oneapi_has_device_global (const std::string &Name) const {
720
717
std::string MangledName = mangleDeviceGlobalName (Name);
721
718
return (MDeviceGlobals.size () &&
722
719
MDeviceGlobals.tryGetEntryLockless (MangledName)) ||
723
- std::any_of (begin (), end (),
724
- [&MangledName](const device_image_plain &DeviceImage) {
725
- return getSyclObjImpl (DeviceImage)
726
- ->hasDeviceGlobalName (MangledName);
720
+ std::any_of (device_images ().begin (), device_images ().end (),
721
+ [&MangledName](device_image_impl &DeviceImage) {
722
+ return DeviceImage.hasDeviceGlobalName (MangledName);
727
723
});
728
724
}
729
725
@@ -803,51 +799,43 @@ class kernel_bundle_impl
803
799
}
804
800
805
801
bool has_kernel (const kernel_id &KernelID) const noexcept {
806
- return std::any_of (begin (), end (),
807
- [&KernelID](const device_image_plain &DeviceImage) {
808
- return DeviceImage.has_kernel (KernelID);
809
- });
802
+ return any_of (device_images (), [&KernelID](device_image_impl &DeviceImage) {
803
+ return DeviceImage.has_kernel (KernelID);
804
+ });
810
805
}
811
806
812
807
bool has_kernel (const kernel_id &KernelID, const device &Dev) const noexcept {
813
- return std::any_of (
814
- begin (), end (),
815
- [&KernelID, &Dev](const device_image_plain &DeviceImage) {
816
- return DeviceImage.has_kernel (KernelID, Dev);
817
- });
808
+ return any_of (device_images (),
809
+ [&KernelID, &Dev](device_image_impl &DeviceImage) {
810
+ return DeviceImage.has_kernel (KernelID, Dev);
811
+ });
818
812
}
819
813
820
814
bool contains_specialization_constants () const noexcept {
821
- return std::any_of (
822
- begin (), end (), [](const device_image_plain &DeviceImage) {
823
- return getSyclObjImpl (DeviceImage)->has_specialization_constants ();
824
- });
815
+ return any_of (device_images (), [](device_image_impl &DeviceImage) {
816
+ return DeviceImage.has_specialization_constants ();
817
+ });
825
818
}
826
819
827
820
bool native_specialization_constant () const noexcept {
828
821
return contains_specialization_constants () &&
829
- std::all_of (begin (), end (),
830
- [](const device_image_plain &DeviceImage) {
831
- return getSyclObjImpl (DeviceImage)
832
- ->all_specialization_constant_native ();
833
- });
822
+ all_of (device_images (), [](device_image_impl &DeviceImage) {
823
+ return DeviceImage.all_specialization_constant_native ();
824
+ });
834
825
}
835
826
836
827
bool has_specialization_constant (const char *SpecName) const noexcept {
837
- return std::any_of (begin (), end (),
838
- [SpecName](const device_image_plain &DeviceImage) {
839
- return getSyclObjImpl (DeviceImage)
840
- ->has_specialization_constant (SpecName);
841
- });
828
+ return any_of (device_images (), [SpecName](device_image_impl &DeviceImage) {
829
+ return DeviceImage.has_specialization_constant (SpecName);
830
+ });
842
831
}
843
832
844
833
void set_specialization_constant_raw_value (const char *SpecName,
845
834
const void *Value,
846
835
size_t Size) noexcept {
847
836
if (has_specialization_constant (SpecName))
848
- for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
849
- getSyclObjImpl (DeviceImage)
850
- ->set_specialization_constant_raw_value (SpecName, Value);
837
+ for (device_image_impl &DeviceImage : device_images ())
838
+ DeviceImage.set_specialization_constant_raw_value (SpecName, Value);
851
839
else {
852
840
std::vector<unsigned char > &Val = MSpecConstValues[std::string{SpecName}];
853
841
Val.resize (Size);
@@ -857,10 +845,9 @@ class kernel_bundle_impl
857
845
858
846
void get_specialization_constant_raw_value (const char *SpecName,
859
847
void *ValueRet) const noexcept {
860
- for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
861
- if (getSyclObjImpl (DeviceImage)->has_specialization_constant (SpecName)) {
862
- getSyclObjImpl (DeviceImage)
863
- ->get_specialization_constant_raw_value (SpecName, ValueRet);
848
+ for (device_image_impl &DeviceImage : device_images ())
849
+ if (DeviceImage.has_specialization_constant (SpecName)) {
850
+ DeviceImage.get_specialization_constant_raw_value (SpecName, ValueRet);
864
851
return ;
865
852
}
866
853
@@ -879,19 +866,21 @@ class kernel_bundle_impl
879
866
}
880
867
881
868
bool is_specialization_constant_set (const char *SpecName) const noexcept {
882
- bool SetInDevImg = std::any_of (
883
- begin (), end (), [SpecName](const device_image_plain &DeviceImage) {
884
- return getSyclObjImpl (DeviceImage)
885
- ->is_specialization_constant_set (SpecName);
869
+ bool SetInDevImg =
870
+ any_of (device_images (), [SpecName](device_image_impl &DeviceImage) {
871
+ return DeviceImage.is_specialization_constant_set (SpecName);
886
872
});
887
873
return SetInDevImg || MSpecConstValues.count (std::string{SpecName}) != 0 ;
888
874
}
889
875
876
+ // Don't use these two for code under `source/detail`, they are only needed to
877
+ // communicate across DSO boundary.
890
878
const device_image_plain *begin () const { return MUniqueDeviceImages.data (); }
891
-
892
879
const device_image_plain *end () const {
893
880
return MUniqueDeviceImages.data () + MUniqueDeviceImages.size ();
894
881
}
882
+ // ...use that instead.
883
+ device_images_range device_images () const { return MUniqueDeviceImages; }
895
884
896
885
size_t size () const noexcept { return MUniqueDeviceImages.size (); }
897
886
@@ -931,28 +920,26 @@ class kernel_bundle_impl
931
920
}
932
921
933
922
bool hasSourceBasedImages () const noexcept {
934
- return std::any_of (begin (), end (), [](const device_image_plain &DevImg) {
935
- return getSyclObjImpl (DevImg)->getOriginMask () &
936
- ImageOriginKernelCompiler;
923
+ return any_of (device_images (), [](device_image_impl &DevImg) {
924
+ return DevImg.getOriginMask () & ImageOriginKernelCompiler;
937
925
});
938
926
}
939
927
940
928
bool hasSYCLBINImages () const noexcept {
941
- return std:: any_of (begin (), end (), [](const device_image_plain &DevImg) {
942
- return getSyclObjImpl ( DevImg)-> getOriginMask () & ImageOriginSYCLBIN;
929
+ return any_of (device_images (), [](device_image_impl &DevImg) {
930
+ return DevImg. getOriginMask () & ImageOriginSYCLBIN;
943
931
});
944
932
}
945
933
946
934
bool hasSYCLOfflineImages () const noexcept {
947
- return std:: any_of (begin (), end (), [](const device_image_plain &DevImg) {
948
- return getSyclObjImpl ( DevImg)-> getOriginMask () & ImageOriginSYCLOffline;
935
+ return any_of (device_images (), [](device_image_impl &DevImg) {
936
+ return DevImg. getOriginMask () & ImageOriginSYCLOffline;
949
937
});
950
938
}
951
939
952
940
bool allSourceBasedImages () const noexcept {
953
- return std::all_of (begin (), end (), [](const device_image_plain &DevImg) {
954
- return getSyclObjImpl (DevImg)->getOriginMask () &
955
- ImageOriginKernelCompiler;
941
+ return all_of (device_images (), [](device_image_impl &DevImg) {
942
+ return DevImg.getOriginMask () & ImageOriginKernelCompiler;
956
943
});
957
944
}
958
945
0 commit comments