Skip to content

[NFC][SYCL] std::shared_ptr<device_image_impl> cleanups #19506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions sycl/include/sycl/kernel_bundle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,12 @@ class __SYCL_EXPORT kernel_id : public detail::OwnerLessBase<kernel_id> {

namespace detail {
class device_image_impl;
using DeviceImageImplPtr = std::shared_ptr<device_image_impl>;

// The class is used as a base for device_image for "untemplating" public
// methods.
class __SYCL_EXPORT device_image_plain {
public:
device_image_plain(const detail::DeviceImageImplPtr &Impl)
device_image_plain(std::shared_ptr<device_image_impl> &&Impl)
: impl(std::move(Impl)) {}

bool operator==(const device_image_plain &RHS) const {
Expand All @@ -124,7 +123,7 @@ class __SYCL_EXPORT device_image_plain {
ur_native_handle_t getNative() const;

protected:
detail::DeviceImageImplPtr impl;
std::shared_ptr<device_image_impl> impl;

template <class Obj>
friend const decltype(Obj::impl) &
Expand Down Expand Up @@ -191,7 +190,7 @@ class device_image : public detail::device_image_plain,
#endif // _HAS_STD_BYTE

private:
device_image(detail::DeviceImageImplPtr Impl)
device_image(std::shared_ptr<detail::device_image_impl> &&Impl)
: device_image_plain(std::move(Impl)) {}

template <class Obj>
Expand Down Expand Up @@ -736,7 +735,7 @@ namespace detail {

// Stable selector function type for passing through library boundaries
using DevImgSelectorImpl =
std::function<bool(const detail::DeviceImageImplPtr &DevImgImpl)>;
std::function<bool(const std::shared_ptr<device_image_impl> &DevImgImpl)>;

// Internal non-template versions of the get_kernel_bundle API, which are used
// by the public ones
Expand Down Expand Up @@ -769,7 +768,7 @@ kernel_bundle<State> get_kernel_bundle(const context &Ctx,
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

detail::DevImgSelectorImpl SelectorWrapper =
[Selector](const detail::DeviceImageImplPtr &DevImg) {
[Selector](const std::shared_ptr<detail::device_image_impl> &DevImg) {
return Selector(
detail::createSyclObjFromImpl<sycl::device_image<State>>(DevImg));
};
Expand Down
13 changes: 6 additions & 7 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,12 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
// this by pre-building the device image and extracting kernel info. We can't
// do the same to user images, since they may contain references to undefined
// symbols (e.g. when kernel_bundle is supposed to be joined with another).
auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
auto DevImgImpl =
device_image_impl::create(nullptr, TargetContext, Devices, State,
KernelIDs, UrProgram, ImageOriginInterop);
device_image_plain DevImg{DevImgImpl};

return kernel_bundle_impl::create(TargetContext, Devices, DevImg);
return kernel_bundle_impl::create(
TargetContext, Devices,
device_image_plain{
device_image_impl::create(nullptr, TargetContext, Devices, State,
std::make_shared<std::vector<kernel_id>>(),
UrProgram, ImageOriginInterop)});
}

// TODO: Unused. Remove when allowed.
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ const RTDeviceBinaryImage *retrieveKernelBinary(queue_impl &Queue,
}

if (KernelCG->MSyclKernel != nullptr)
return KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref();
return KernelCG->MSyclKernel->getDeviceImage().get_bin_image_ref();

if (auto KernelBundleImpl = KernelCG->getKernelBundle())
if (auto SyclKernelImpl = KernelBundleImpl->tryGetKernel(KernelName))
// Retrieve the device image from the kernel bundle.
return SyclKernelImpl->getDeviceImage()->get_bin_image_ref();
return SyclKernelImpl->getDeviceImage().get_bin_image_ref();

context_impl &ContextImpl = Queue.getContextImpl();
return &detail::ProgramManager::getInstance().getDeviceImage(
Expand Down
29 changes: 14 additions & 15 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ class kernel_bundle_impl

// Interop constructor.
// Delegates to the (context, devices, tag) constructor and then records
// \p DevImage in both MDeviceImages and MUniqueDeviceImages.
kernel_bundle_impl(context Ctx, devices_range Devs,
                   device_image_plain &&DevImage, private_tag Tag)
    : kernel_bundle_impl(std::move(Ctx), Devs, Tag) {
  // Copy into the unique-image list first: DevImage must not be read once it
  // has been moved from, so the move has to be the last use.
  MUniqueDeviceImages.emplace_back(DevImage);
  MDeviceImages.emplace_back(std::move(DevImage));
}

Expand Down Expand Up @@ -162,9 +162,9 @@ class kernel_bundle_impl
InputBundleImpl.MDeviceImages) {
// Skip images which are not compatible with devices provided
if (std::none_of(get_devices().begin(), get_devices().end(),
[&DevImgWithDeps](device_impl &Dev) {
return getSyclObjImpl(DevImgWithDeps.getMain())
->compatible_with_device(Dev);
[&MainImg = *getSyclObjImpl(DevImgWithDeps.getMain())](
device_impl &Dev) {
return MainImg.compatible_with_device(Dev);
}))
continue;

Expand Down Expand Up @@ -249,8 +249,7 @@ class kernel_bundle_impl
// images with specialization constants in separation.
// TODO: Remove when spec const overwriting issue has been fixed in L0.
std::vector<const DevImgPlainWithDeps *> ImagesWithSpecConsts;
std::unordered_set<std::shared_ptr<device_image_impl>>
ImagesWithSpecConstsSet;
std::unordered_set<device_image_impl *> ImagesWithSpecConstsSet;
for (const kernel_bundle<bundle_state::object> &ObjectBundle :
ObjectBundles) {
for (const DevImgPlainWithDeps &DeviceImageWithDeps :
Expand All @@ -265,7 +264,7 @@ class kernel_bundle_impl

ImagesWithSpecConsts.push_back(&DeviceImageWithDeps);
for (const device_image_plain &DevImg : DeviceImageWithDeps)
ImagesWithSpecConstsSet.insert(getSyclObjImpl(DevImg));
ImagesWithSpecConstsSet.insert(&*getSyclObjImpl(DevImg));
}
}

Expand All @@ -284,8 +283,7 @@ class kernel_bundle_impl
// been seen before or the device image implementation is in the
// image set already.
if ((BinImg && SeenBinImgs.find(BinImg) != SeenBinImgs.end()) ||
ImagesWithSpecConstsSet.find(DevImgImpl) !=
ImagesWithSpecConstsSet.end())
ImagesWithSpecConstsSet.count(&*DevImgImpl))
continue;
SeenBinImgs.insert(BinImg);
DevImagesSet.insert(DevImgImpl);
Expand Down Expand Up @@ -401,9 +399,9 @@ class kernel_bundle_impl
ImagesWithSpecConsts) {
// Skip images which are not compatible with devices provided
if (std::none_of(get_devices().begin(), get_devices().end(),
[DeviceImageWithDeps](device_impl &Dev) {
return getSyclObjImpl(DeviceImageWithDeps->getMain())
->compatible_with_device(Dev);
[&MainImg = *getSyclObjImpl(
DeviceImageWithDeps->getMain())](device_impl &Dev) {
return MainImg.compatible_with_device(Dev);
}))
continue;

Expand Down Expand Up @@ -1016,9 +1014,10 @@ class kernel_bundle_impl
MContext, KernelID.get_name(), /*PropList=*/{},
SelectedImage->get_ur_program_ref());

ur_program_handle_t UrProgram = SelectedImage->get_ur_program_ref();
return std::make_shared<kernel_impl>(
Kernel, *detail::getSyclObjImpl(MContext), SelectedImage, *this,
ArgMask, SelectedImage->get_ur_program_ref(), CacheMutex);
Kernel, *detail::getSyclObjImpl(MContext), std::move(SelectedImage),
*this, ArgMask, UrProgram, CacheMutex);
}

std::shared_ptr<kernel_impl>
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/kernel_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ kernel_impl::kernel_impl(ur_kernel_handle_t Kernel, context_impl &Context,
}

kernel_impl::kernel_impl(ur_kernel_handle_t Kernel, context_impl &ContextImpl,
DeviceImageImplPtr DeviceImageImpl,
std::shared_ptr<device_image_impl> &&DeviceImageImpl,
const kernel_bundle_impl &KernelBundleImpl,
const KernelArgMask *ArgMask,
ur_program_handle_t Program, std::mutex *CacheMutex)
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/kernel_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class kernel_impl {
/// \param ContextImpl is a valid SYCL context
/// \param KernelBundleImpl is a valid instance of kernel_bundle_impl
kernel_impl(ur_kernel_handle_t Kernel, context_impl &ContextImpl,
DeviceImageImplPtr DeviceImageImpl,
std::shared_ptr<device_image_impl> &&DeviceImageImpl,
const kernel_bundle_impl &KernelBundleImpl,
const KernelArgMask *ArgMask, ur_program_handle_t Program,
std::mutex *CacheMutex);
Expand Down Expand Up @@ -213,7 +213,7 @@ class kernel_impl {
bool isInteropOrSourceBased() const noexcept;
bool hasSYCLMetadata() const noexcept;

const DeviceImageImplPtr &getDeviceImage() const { return MDeviceImageImpl; }
device_image_impl &getDeviceImage() const { return *MDeviceImageImpl; }

ur_native_handle_t getNative() const {
adapter_impl &Adapter = MContext->getAdapter();
Expand Down Expand Up @@ -247,7 +247,7 @@ class kernel_impl {
const std::shared_ptr<context_impl> MContext;
const ur_program_handle_t MProgram = nullptr;
bool MCreatedFromSource = true;
const DeviceImageImplPtr MDeviceImageImpl;
const std::shared_ptr<device_image_impl> MDeviceImageImpl;
const KernelBundleImplPtr MKernelBundleImpl;
bool MIsInterop = false;
mutable std::mutex MNoncacheableEnqueueMutex;
Expand Down
33 changes: 14 additions & 19 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2493,11 +2493,9 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
KernelIDs = m_BinImg2KernelIDs[BinImage];
}

DeviceImageImplPtr Impl =
return createSyclObjFromImpl<device_image_plain>(
device_image_impl::create(BinImage, Ctx, Dev, ImgState, KernelIDs,
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);

return createSyclObjFromImpl<device_image_plain>(std::move(Impl));
/*PIProgram=*/nullptr, ImageOriginSYCLOffline));
}

std::vector<DevImgPlainWithDeps>
Expand Down Expand Up @@ -2655,7 +2653,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
if (ImgInfoPair.second.RequirementCounter == 0)
continue;

DeviceImageImplPtr MainImpl = device_image_impl::create(
std::shared_ptr<device_image_impl> MainImpl = device_image_impl::create(
ImgInfoPair.first, Ctx, Devs, ImgInfoPair.second.State,
ImgInfoPair.second.KernelIDs, /*PIProgram=*/nullptr,
ImageOriginSYCLOffline);
Expand Down Expand Up @@ -2690,11 +2688,10 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,

assert(DepState == getBinImageState(DepImage) &&
"State mismatch between main image and its dependency");
DeviceImageImplPtr DepImpl =
device_image_impl::create(DepImage, Ctx, Devs, DepState, DepKernelIDs,
/*PIProgram=*/nullptr, ImageOriginSYCLOffline);

return createSyclObjFromImpl<device_image_plain>(std::move(DepImpl));
return createSyclObjFromImpl<device_image_plain>(
device_image_impl::create(DepImage, Ctx, Devs, DepState, DepKernelIDs,
/*PIProgram=*/nullptr, ImageOriginSYCLOffline));
}

void ProgramManager::bringSYCLDeviceImageToState(
Expand Down Expand Up @@ -2863,7 +2860,7 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,

std::optional<detail::KernelCompilerBinaryInfo> RTCInfo =
InputImpl.getRTCInfo();
DeviceImageImplPtr ObjectImpl = device_image_impl::create(
std::shared_ptr<device_image_impl> ObjectImpl = device_image_impl::create(
InputImpl.get_bin_image_ref(), InputImpl.get_context(), Devs,
bundle_state::object, InputImpl.get_kernel_ids_ptr(), Prog,
InputImpl.get_spec_const_data_ref(),
Expand Down Expand Up @@ -3064,15 +3061,14 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
}
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);

DeviceImageImplPtr ExecutableImpl = device_image_impl::create(
// TODO: Make multiple sets of device images organized by devices they are
// compiled for.
return {createSyclObjFromImpl<device_image_plain>(device_image_impl::create(
NewBinImg, Context, Devs, bundle_state::executable, std::move(KernelIDs),
LinkedProg, std::move(NewSpecConstMap), std::move(NewSpecConstBlob),
CombinedOrigins, std::move(MergedRTCInfo), std::move(MergedKernelNames),
std::move(MergedEliminatedKernelArgMasks), std::move(MergedImageStorage));

// TODO: Make multiple sets of device images organized by devices they are
// compiled for.
return {createSyclObjFromImpl<device_image_plain>(std::move(ExecutableImpl))};
std::move(MergedEliminatedKernelArgMasks),
std::move(MergedImageStorage)))};
}

// The function duplicates most of the code from existing getBuiltPIProgram.
Expand Down Expand Up @@ -3146,13 +3142,12 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
}
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);

DeviceImageImplPtr ExecImpl = device_image_impl::create(
return createSyclObjFromImpl<device_image_plain>(device_image_impl::create(
ResultBinImg, Context, Devs, bundle_state::executable,
std::move(KernelIDs), ResProgram, std::move(SpecConstMap),
std::move(SpecConstBlob), CombinedOrigins, std::move(MergedRTCInfo),
std::move(MergedKernelNames), std::move(MergedEliminatedKernelArgMasks),
std::move(MergedImageStorage));
return createSyclObjFromImpl<device_image_plain>(std::move(ExecImpl));
std::move(MergedImageStorage)));
}

// When caching is enabled, the returned UrKernel will already have
Expand Down
38 changes: 17 additions & 21 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2400,10 +2400,9 @@ static void SetArgBasedOnType(

static ur_result_t SetKernelParamsAndLaunch(
queue_impl &Queue, std::vector<ArgDesc> &Args,
const std::shared_ptr<device_image_impl> &DeviceImageImpl,
ur_kernel_handle_t Kernel, NDRDescT &NDRDesc,
std::vector<ur_event_handle_t> &RawEvents, detail::event_impl *OutEventImpl,
const KernelArgMask *EliminatedArgMask,
device_image_impl *DeviceImageImpl, ur_kernel_handle_t Kernel,
NDRDescT &NDRDesc, std::vector<ur_event_handle_t> &RawEvents,
detail::event_impl *OutEventImpl, const KernelArgMask *EliminatedArgMask,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
bool IsCooperative, bool KernelUsesClusterLaunch,
uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
Expand All @@ -2418,8 +2417,7 @@ static ur_result_t SetKernelParamsAndLaunch(
std::vector<unsigned char> Empty;
Kernel = Scheduler::getInstance().completeSpecConstMaterialization(
Queue, BinImage, KernelName,
DeviceImageImpl.get() ? DeviceImageImpl->get_spec_const_blob_ref()
: Empty);
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
}

if (KernelFuncPtr && !KernelHasSpecialCaptures) {
Expand Down Expand Up @@ -2449,9 +2447,8 @@ static ur_result_t SetKernelParamsAndLaunch(
} else {
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl.get(),
getMemAllocationFunc, Queue.getContextImpl(), Arg,
NextTrueIndex);
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImpl(), Arg, NextTrueIndex);
};
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
}
Expand Down Expand Up @@ -2537,14 +2534,14 @@ static ur_result_t SetKernelParamsAndLaunch(
return Error;
}

static std::tuple<ur_kernel_handle_t, std::shared_ptr<device_image_impl>,
static std::tuple<ur_kernel_handle_t, device_image_impl *,
const KernelArgMask *>
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,
device_impl &DeviceImpl,
std::vector<FastKernelCacheValPtr> &KernelCacheValsToRelease) {

ur_kernel_handle_t UrKernel = nullptr;
std::shared_ptr<device_image_impl> DeviceImageImpl = nullptr;
device_image_impl *DeviceImageImpl = nullptr;
const KernelArgMask *EliminatedArgMask = nullptr;
kernel_bundle_impl *KernelBundleImplPtr = CommandGroup.MKernelBundle.get();

Expand All @@ -2556,7 +2553,7 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,
CommandGroup.MKernelName)
: std::shared_ptr<kernel_impl>{nullptr}) {
UrKernel = SyclKernelImpl->getHandleRef();
DeviceImageImpl = SyclKernelImpl->getDeviceImage();
DeviceImageImpl = &SyclKernelImpl->getDeviceImage();
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else {
FastKernelCacheValPtr FastKernelCacheVal =
Expand All @@ -2568,8 +2565,7 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,
// To keep UrKernel valid, we return FastKernelCacheValPtr.
KernelCacheValsToRelease.push_back(std::move(FastKernelCacheVal));
}
return std::make_tuple(UrKernel, std::move(DeviceImageImpl),
EliminatedArgMask);
return std::make_tuple(UrKernel, DeviceImageImpl, EliminatedArgMask);
}

ur_result_t enqueueImpCommandBufferKernel(
Expand All @@ -2586,7 +2582,7 @@ ur_result_t enqueueImpCommandBufferKernel(
std::vector<FastKernelCacheValPtr> FastKernelCacheValsToRelease;

ur_kernel_handle_t UrKernel = nullptr;
std::shared_ptr<device_image_impl> DeviceImageImpl = nullptr;
device_image_impl *DeviceImageImpl = nullptr;
const KernelArgMask *EliminatedArgMask = nullptr;

context_impl &ContextImpl = *sycl::detail::getSyclObjImpl(Ctx);
Expand All @@ -2610,10 +2606,10 @@ ur_result_t enqueueImpCommandBufferKernel(
}

adapter_impl &Adapter = ContextImpl.getAdapter();
auto SetFunc = [&Adapter, &UrKernel, &DeviceImageImpl, &ContextImpl,
&getMemAllocationFunc](sycl::detail::ArgDesc &Arg,
size_t NextTrueIndex) {
sycl::detail::SetArgBasedOnType(Adapter, UrKernel, DeviceImageImpl.get(),
auto SetFunc = [&Adapter, &UrKernel, &ContextImpl, &getMemAllocationFunc,
DeviceImageImpl](sycl::detail::ArgDesc &Arg,
size_t NextTrueIndex) {
sycl::detail::SetArgBasedOnType(Adapter, UrKernel, DeviceImageImpl,
getMemAllocationFunc, ContextImpl, Arg,
NextTrueIndex);
};
Expand Down Expand Up @@ -2695,7 +2691,7 @@ void enqueueImpKernel(
const KernelArgMask *EliminatedArgMask;

std::shared_ptr<kernel_impl> SyclKernelImpl;
std::shared_ptr<device_image_impl> DeviceImageImpl;
device_image_impl *DeviceImageImpl = nullptr;
FastKernelCacheValPtr KernelCacheVal;

if (nullptr != MSyclKernel) {
Expand All @@ -2717,7 +2713,7 @@ void enqueueImpKernel(
? KernelBundleImplPtr->tryGetKernel(KernelName)
: std::shared_ptr<kernel_impl>{nullptr})) {
Kernel = SyclKernelImpl->getHandleRef();
DeviceImageImpl = SyclKernelImpl->getDeviceImage();
DeviceImageImpl = &SyclKernelImpl->getDeviceImage();

Program = DeviceImageImpl->get_ur_program_ref();

Expand Down
Loading
Loading