Skip to content

Commit dde3359

Browse files
[NFC][SYCL] Change context_impl::getDevices to return devices_range (#19456)
`devices_range` helper was added in #19405 to facilitate ongoing refactoring of using more raw ptr/ref to `*_impl` objects. Also update some of the callsites that can use that instead of more expensive `get_info<info::context::devices>()`.
1 parent 9c3a9f5 commit dde3359

File tree

8 files changed

+36
-43
lines changed

8 files changed

+36
-43
lines changed

sycl/source/detail/context_impl.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
130130
/// \return an instance of raw UR context handle.
131131
const ur_context_handle_t &getHandleRef() const;
132132

133-
/// Unlike `get_info<info::context::devices>', this function returns a
134-
/// reference.
135-
const std::vector<device> &getDevices() const { return MDevices; }
133+
devices_range getDevices() const { return MDevices; }
136134

137135
using CachedLibProgramsT =
138136
std::map<std::pair<DeviceLibExt, ur_device_handle_t>,

sycl/source/detail/device_global_map_entry.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
104104
"USM allocations should not be acquired for device_global with "
105105
"device_image_scope property.");
106106
context_impl &CtxImpl = *getSyclObjImpl(Context);
107-
device_impl &DevImpl = *getSyclObjImpl(CtxImpl.getDevices().front());
107+
device_impl &DevImpl = CtxImpl.getDevices().front();
108108
std::lock_guard<std::mutex> Lock(MDeviceToUSMPtrMapMutex);
109109

110110
auto DGUSMPtr = MDeviceToUSMPtrMap.find({&DevImpl, &CtxImpl});
@@ -153,9 +153,8 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
153153
void DeviceGlobalMapEntry::removeAssociatedResources(
154154
const context_impl *CtxImpl) {
155155
std::lock_guard<std::mutex> Lock{MDeviceToUSMPtrMapMutex};
156-
for (device Device : CtxImpl->getDevices()) {
157-
auto USMPtrIt =
158-
MDeviceToUSMPtrMap.find({getSyclObjImpl(Device).get(), CtxImpl});
156+
for (device_impl &Device : CtxImpl->getDevices()) {
157+
auto USMPtrIt = MDeviceToUSMPtrMap.find({&Device, CtxImpl});
159158
if (USMPtrIt != MDeviceToUSMPtrMap.end()) {
160159
DeviceGlobalUSMMem &USMMem = USMPtrIt->second;
161160
detail::usm::freeInternal(USMMem.MPtr, CtxImpl);

sycl/source/detail/helpers.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ template <typename iterator> class iterator_range {
102102
iterator_range(IterTy Begin, IterTy End, size_t Size)
103103
: Begin(Begin), End(End), Size(Size) {}
104104

105+
iterator_range()
106+
: iterator_range(static_cast<value_type *>(nullptr),
107+
static_cast<value_type *>(nullptr), 0) {}
108+
105109
template <typename ContainerTy>
106110
iterator_range(const ContainerTy &Container)
107111
: iterator_range(Container.begin(), Container.end(), Container.size()) {}

sycl/source/detail/image_impl.cpp

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ uint8_t GImageStreamID;
2323
#endif
2424

2525
template <typename Param>
26-
static bool checkImageValueRange(const std::vector<device> &Devices,
27-
const size_t Value) {
28-
return Value >= 1 && std::all_of(Devices.cbegin(), Devices.cend(),
29-
[Value](const device &Dev) {
30-
return Value <= Dev.get_info<Param>();
31-
});
26+
static bool checkImageValueRange(devices_range Devices, const size_t Value) {
27+
return Value >= 1 &&
28+
std::all_of(Devices.begin(), Devices.end(), [Value](device_impl &Dev) {
29+
return Value <= Dev.get_info<Param>();
30+
});
3231
}
3332

3433
template <typename T, typename... Args> static bool checkAnyImpl(T) {
@@ -345,46 +344,47 @@ void *image_impl::allocateMem(context_impl *Context, bool InitFromUserData,
345344

346345
bool image_impl::checkImageDesc(const ur_image_desc_t &Desc,
347346
context_impl *Context, void *UserPtr) {
347+
devices_range Devices = Context ? Context->getDevices() : devices_range{};
348348
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE1D, UR_MEM_TYPE_IMAGE1D_ARRAY,
349349
UR_MEM_TYPE_IMAGE2D_ARRAY, UR_MEM_TYPE_IMAGE2D) &&
350-
!checkImageValueRange<info::device::image2d_max_width>(
351-
getDevices(Context), Desc.width))
350+
!checkImageValueRange<info::device::image2d_max_width>(Devices,
351+
Desc.width))
352352
throw exception(make_error_code(errc::invalid),
353353
"For a 1D/2D image/image array, the width must be a Value "
354354
">= 1 and <= info::device::image2d_max_width");
355355

356356
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE3D) &&
357-
!checkImageValueRange<info::device::image3d_max_width>(
358-
getDevices(Context), Desc.width))
357+
!checkImageValueRange<info::device::image3d_max_width>(Devices,
358+
Desc.width))
359359
throw exception(make_error_code(errc::invalid),
360360
"For a 3D image, the width must be a Value >= 1 and <= "
361361
"info::device::image3d_max_width");
362362

363363
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE2D, UR_MEM_TYPE_IMAGE2D_ARRAY) &&
364-
!checkImageValueRange<info::device::image2d_max_height>(
365-
getDevices(Context), Desc.height))
364+
!checkImageValueRange<info::device::image2d_max_height>(Devices,
365+
Desc.height))
366366
throw exception(make_error_code(errc::invalid),
367367
"For a 2D image or image array, the height must be a Value "
368368
">= 1 and <= info::device::image2d_max_height");
369369

370370
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE3D) &&
371-
!checkImageValueRange<info::device::image3d_max_height>(
372-
getDevices(Context), Desc.height))
371+
!checkImageValueRange<info::device::image3d_max_height>(Devices,
372+
Desc.height))
373373
throw exception(make_error_code(errc::invalid),
374374
"For a 3D image, the heightmust be a Value >= 1 and <= "
375375
"info::device::image3d_max_height");
376376

377377
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE3D) &&
378-
!checkImageValueRange<info::device::image3d_max_depth>(
379-
getDevices(Context), Desc.depth))
378+
!checkImageValueRange<info::device::image3d_max_depth>(Devices,
379+
Desc.depth))
380380
throw exception(make_error_code(errc::invalid),
381381
"For a 3D image, the depth must be a Value >= 1 and <= "
382382
"info::device::image2d_max_depth");
383383

384384
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE1D_ARRAY,
385385
UR_MEM_TYPE_IMAGE2D_ARRAY) &&
386-
!checkImageValueRange<info::device::image_max_array_size>(
387-
getDevices(Context), Desc.arraySize))
386+
!checkImageValueRange<info::device::image_max_array_size>(Devices,
387+
Desc.arraySize))
388388
throw exception(make_error_code(errc::invalid),
389389
"For a 1D and 2D image array, the array_size must be a "
390390
"Value >= 1 and <= info::device::image_max_array_size.");
@@ -451,12 +451,6 @@ bool image_impl::checkImageFormat(const ur_image_format_t &Format,
451451
return true;
452452
}
453453

454-
std::vector<device> image_impl::getDevices(context_impl *Context) {
455-
if (!Context)
456-
return {};
457-
return Context->get_info<info::context::devices>();
458-
}
459-
460454
void image_impl::sampledImageConstructorNotification(
461455
const detail::code_location &CodeLoc, void *UserObj, const void *HostObj,
462456
uint32_t Dim, size_t Range[3], image_format Format,

sycl/source/detail/image_impl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class accessor;
3535
class handler;
3636

3737
namespace detail {
38+
class devices_range;
3839

3940
// utility functions and typedefs for image_impl
4041
using image_allocator = aligned_allocator<byte>;
@@ -297,8 +298,6 @@ class image_impl final : public SYCLMemObjT {
297298
void unsampledImageDestructorNotification(void *UserObj);
298299

299300
private:
300-
std::vector<device> getDevices(context_impl *Context);
301-
302301
ur_mem_type_t getImageType() {
303302
if (MDimensions == 1)
304303
return (MIsArrayImage ? UR_MEM_TYPE_IMAGE1D_ARRAY : UR_MEM_TYPE_IMAGE1D);

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ static bool isDeviceBinaryTypeSupported(context_impl &ContextImpl,
124124
if (ContextBackend == backend::ext_oneapi_cuda)
125125
return false;
126126

127-
const std::vector<device> &Devices = ContextImpl.getDevices();
127+
devices_range Devices = ContextImpl.getDevices();
128128

129129
// Program type is SPIR-V, so we need a device compiler to do JIT.
130-
for (const device &D : Devices) {
130+
for (device_impl &D : Devices) {
131131
if (!D.get_info<info::device::is_compiler_available>())
132132
return false;
133133
}
@@ -143,7 +143,7 @@ static bool isDeviceBinaryTypeSupported(context_impl &ContextImpl,
143143
return true;
144144
}
145145

146-
for (const device &D : Devices) {
146+
for (device_impl &D : Devices) {
147147
// We need cl_khr_il_program extension to be present
148148
// and we can call clCreateProgramWithILKHR using the extension
149149
std::vector<std::string> Extensions =

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,9 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
213213
// which means that there is already an allocation(cl_mem) in some context.
214214
// Registering this allocation in the SYCL graph.
215215

216-
std::vector<sycl::device> Devices =
217-
InteropCtxPtr->get_info<info::context::devices>();
218-
assert(Devices.size() != 0);
219-
device_impl &Dev = *detail::getSyclObjImpl(Devices[0]);
216+
devices_range Devices = InteropCtxPtr->getDevices();
217+
assert(!Devices.empty());
218+
device_impl &Dev = Devices.front();
220219

221220
// Since all the Scheduler commands require queue but we have only context
222221
// here, we need to create a dummy queue bound to the context and one of the
@@ -675,7 +674,7 @@ static bool checkHostUnifiedMemory(context_impl *Ctx) {
675674
if (Ctx == nullptr)
676675
return true;
677676

678-
for (const device &Device : Ctx->getDevices()) {
677+
for (device_impl &Device : Ctx->getDevices()) {
679678
if (!Device.get_info<info::device::host_unified_memory>())
680679
return false;
681680
}

sycl/source/detail/usm/usm_impl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,13 +581,13 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) {
581581

582582
// Check if ptr is a host allocation
583583
if (get_pointer_type(Ptr, Ctxt) == alloc::host) {
584-
auto Devs = detail::getSyclObjImpl(Ctxt)->getDevices();
584+
detail::devices_range Devs = detail::getSyclObjImpl(Ctxt)->getDevices();
585585
if (Devs.size() == 0)
586586
throw exception(make_error_code(errc::invalid),
587587
"No devices in passed context!");
588588

589589
// Just return the first device in the context
590-
return Devs[0];
590+
return detail::createSyclObjFromImpl<device>(Devs.front());
591591
}
592592

593593
ur_device_handle_t DeviceId;

0 commit comments

Comments
 (0)