Skip to content

Commit 25c5cd9

Browse files
committed
[SYCL][UR] Implement sycl_ext_oneapi_device_wait
This commit implements the UR functionality for device-wide synchronization and the SYCL APIs using it. The latter implements the sycl_ext_oneapi_device_wait extension. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent e76da36 commit 25c5cd9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+869
-31
lines changed

llvm/include/llvm/SYCLLowerIR/DeviceConfigFile.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def Aspectext_oneapi_clock_sub_group : Aspect<"ext_oneapi_clock_sub_group">;
9898
def Aspectext_oneapi_clock_work_group : Aspect<"ext_oneapi_clock_work_group">;
9999
def Aspectext_oneapi_clock_device : Aspect<"ext_oneapi_clock_device">;
100100
def Aspectext_oneapi_is_integrated_gpu : Aspect<"ext_oneapi_is_integrated_gpu">;
101+
def Aspectext_oneapi_device_wait : Aspect<"ext_oneapi_device_wait">;
101102

102103
// Deprecated aspects
103104
def AspectInt64_base_atomics : Aspect<"int64_base_atomics">;
@@ -176,7 +177,8 @@ def : TargetInfo<"__TestAspectList",
176177
Aspectext_oneapi_clock_sub_group,
177178
Aspectext_oneapi_clock_work_group,
178179
Aspectext_oneapi_clock_device,
179-
Aspectext_oneapi_is_integrated_gpu],
180+
Aspectext_oneapi_is_integrated_gpu,
181+
Aspectext_oneapi_device_wait],
180182
[]>;
181183
// This definition serves the only purpose of testing whether the deprecated aspect list defined in here and in SYCL RT
182184
// match.

sycl/include/sycl/device.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,21 @@ class __SYCL_STANDALONE_DEBUG __SYCL_EXPORT device
365365
return profile.c_str();
366366
}
367367

368+
/// Synchronizes with all queues associated with the device.
369+
void ext_oneapi_wait();
370+
371+
/// Dispatches all unconsumed asynchronous exceptions for all queues or
372+
/// contexts associated with the queues.
373+
void ext_oneapi_throw_asynchronous();
374+
375+
/// Synchronizes with all queues associated with the device, then dispatches
376+
/// all unconsumed asynchronous exceptions for all queues or contexts
377+
/// associated with the queues.
378+
void ext_oneapi_wait_and_throw() {
379+
ext_oneapi_wait();
380+
ext_oneapi_throw_asynchronous();
381+
}
382+
368383
// TODO: Remove this diagnostics when __SYCL_WARN_IMAGE_ASPECT is removed.
369384
#if defined(__clang__)
370385
#pragma clang diagnostic pop

sycl/include/sycl/info/aspects.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,4 @@ __SYCL_ASPECT(ext_oneapi_clock_sub_group, 91)
8484
__SYCL_ASPECT(ext_oneapi_clock_work_group, 92)
8585
__SYCL_ASPECT(ext_oneapi_clock_device, 93)
8686
__SYCL_ASPECT(ext_oneapi_is_integrated_gpu, 94)
87+
__SYCL_ASPECT(ext_oneapi_device_wait, 95)

sycl/source/detail/device_impl.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,34 @@ device_impl::getImmediateProgressGuarantee(
503503
return forward_progress_guarantee::weakly_parallel;
504504
}
505505

506+
void device_impl::wait() const {
507+
// Firstly, all associated queues should be cleaned through of all
508+
// not-yet-enqueued commands and host_task.
509+
for (const std::weak_ptr<queue_impl> &WQueue : MQueues) {
510+
std::shared_ptr<queue_impl> Queue = WQueue.lock();
511+
assert(Queue && "Queue should never be dangling in the list of queues "
512+
"associated with the device!");
513+
Queue->waitForRuntimeLevelCmdsAndClear();
514+
}
515+
516+
// Then we synchronize the entire device.
517+
getAdapter().call<detail::UrApiKind::urDeviceWaitExp>(getHandleRef());
518+
}
519+
520+
void device_impl::throwAsynchronous() {
521+
std::lock_guard<std::mutex> Lock(MAsyncExceptionsMutex);
522+
for (auto &ExceptionsEntryIt : MAsyncExceptions) {
523+
exception_list Exceptions = std::move(ExceptionsEntryIt.second);
524+
std::shared_ptr<queue_impl> Queue = ExceptionsEntryIt.first.lock();
525+
if (Queue && Queue->getAsynchHandler()) {
526+
Queue->getAsynchHandler()(std::move(Exceptions));
527+
} else {
528+
// If the queue is dead, use the default handler.
529+
defaultAsyncHandler(std::move(Exceptions));
530+
}
531+
}
532+
}
533+
506534
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
507535
#define EXPORT_GET_INFO(PARAM) \
508536
template <> \

sycl/source/detail/device_impl.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,10 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
15971597
get_info_impl_nocheck<UR_DEVICE_INFO_IS_INTEGRATED_GPU>().value_or(
15981598
0);
15991599
}
1600+
CASE(ext_oneapi_device_wait) {
1601+
return get_info_impl_nocheck<UR_DEVICE_INFO_DEVICE_WAIT_SUPPORT_EXP>()
1602+
.value_or(0);
1603+
}
16001604
else {
16011605
return false; // This device aspect has not been implemented yet.
16021606
}
@@ -2292,6 +2296,22 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
22922296
return Exceptions;
22932297
}
22942298

2299+
/// Synchronizes with all queues on the device.
2300+
void wait() const;
2301+
2302+
// Dispatch all unconsumed asynchronous exception to the appropriate handlers.
2303+
void throwAsynchronous();
2304+
2305+
void registerQueue(const std::weak_ptr<queue_impl> &Q) {
2306+
std::lock_guard<std::mutex> Lock(MQueuesMutex);
2307+
MQueues.insert(Q);
2308+
}
2309+
2310+
void unregisterQueue(const std::weak_ptr<queue_impl> &Q) {
2311+
std::lock_guard<std::mutex> Lock(MQueuesMutex);
2312+
MQueues.erase(Q);
2313+
}
2314+
22952315
private:
22962316
ur_device_handle_t MDevice = 0;
22972317
// This is used for getAdapter so should be above other properties.
@@ -2302,6 +2322,13 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
23022322

23032323
const ur_device_handle_t MRootDevice;
23042324

2325+
// Devices track a list of active queues on it, to allow for synchronization
2326+
// with host_task and not-yet-enqueued commands.
2327+
std::mutex MQueuesMutex;
2328+
std::set<std::weak_ptr<queue_impl>,
2329+
std::owner_less<std::weak_ptr<queue_impl>>>
2330+
MQueues;
2331+
23052332
// Asynchronous exceptions are captured at device-level until flushed, either
23062333
// by queues, events or a synchronization on the device itself.
23072334
std::mutex MAsyncExceptionsMutex;

sycl/source/detail/queue_impl.cpp

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -889,32 +889,7 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
889889
LastEvent->wait();
890890
}
891891
} else if (!isInOrder()) {
892-
std::vector<std::weak_ptr<event_impl>> WeakEvents;
893-
{
894-
std::lock_guard<std::mutex> Lock(MMutex);
895-
WeakEvents.swap(MEventsWeak);
896-
MMissedCleanupRequests.unset(
897-
[&](MissedCleanupRequestsType &MissedCleanupRequests) {
898-
for (auto &UpdatedGraph : MissedCleanupRequests)
899-
doUnenqueuedCommandCleanup(UpdatedGraph);
900-
MissedCleanupRequests.clear();
901-
});
902-
}
903-
904-
// Wait for unenqueued or host task events, starting
905-
// from the latest submitted task in order to minimize total amount of
906-
// calls, then handle the rest with urQueueFinish.
907-
for (auto EventImplWeakPtrIt = WeakEvents.rbegin();
908-
EventImplWeakPtrIt != WeakEvents.rend(); ++EventImplWeakPtrIt) {
909-
if (std::shared_ptr<event_impl> EventImplSharedPtr =
910-
EventImplWeakPtrIt->lock()) {
911-
// A nullptr UR event indicates that urQueueFinish will not cover it,
912-
// either because it's a host task event or an unenqueued one.
913-
if (nullptr == EventImplSharedPtr->getHandle()) {
914-
EventImplSharedPtr->wait();
915-
}
916-
}
917-
}
892+
waitForRuntimeLevelCmdsAndClear();
918893
}
919894

920895
getAdapter().call<UrApiKind::urQueueFinish>(getHandleRef());
@@ -1127,6 +1102,47 @@ void queue_impl::verifyProps(const property_list &Props) const {
11271102
CheckPropertiesWithData);
11281103
}
11291104

1105+
void queue_impl::waitForRuntimeLevelCmdsAndClear() {
1106+
if (isInOrder() && !MNoLastEventMode.load(std::memory_order_relaxed)) {
1107+
// if MLastEvent is not null and has no associated handle, we need to wait
1108+
// for it. We do not clear it however.
1109+
EventImplPtr LastEvent;
1110+
{
1111+
std::lock_guard<std::mutex> Lock(MMutex);
1112+
LastEvent = MDefaultGraphDeps.LastEventPtr;
1113+
}
1114+
if (LastEvent && nullptr == LastEvent->getHandle())
1115+
LastEvent->wait();
1116+
} else if (!isInOrder()) {
1117+
std::vector<std::weak_ptr<event_impl>> WeakEvents;
1118+
{
1119+
std::lock_guard<std::mutex> Lock(MMutex);
1120+
WeakEvents.swap(MEventsWeak);
1121+
MMissedCleanupRequests.unset(
1122+
[&](MissedCleanupRequestsType &MissedCleanupRequests) {
1123+
for (auto &UpdatedGraph : MissedCleanupRequests)
1124+
doUnenqueuedCommandCleanup(UpdatedGraph);
1125+
MissedCleanupRequests.clear();
1126+
});
1127+
}
1128+
1129+
// Wait for unenqueued or host task events, starting
1130+
// from the latest submitted task in order to minimize total amount of
1131+
// calls, then handle the rest with urQueueFinish.
1132+
for (auto EventImplWeakPtrIt = WeakEvents.rbegin();
1133+
EventImplWeakPtrIt != WeakEvents.rend(); ++EventImplWeakPtrIt) {
1134+
if (std::shared_ptr<event_impl> EventImplSharedPtr =
1135+
EventImplWeakPtrIt->lock()) {
1136+
// A nullptr UR event indicates that urQueueFinish will not cover it,
1137+
// either because it's a host task event or an unenqueued one.
1138+
if (nullptr == EventImplSharedPtr->getHandle()) {
1139+
EventImplSharedPtr->wait();
1140+
}
1141+
}
1142+
}
1143+
}
1144+
}
1145+
11301146
} // namespace detail
11311147
} // namespace _V1
11321148
} // namespace sycl

sycl/source/detail/queue_impl.hpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,10 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
241241
// `std::shared_ptr` allocations.
242242
template <typename... Ts>
243243
static std::shared_ptr<queue_impl> create(Ts &&...args) {
244-
return std::make_shared<queue_impl>(std::forward<Ts>(args)...,
245-
private_tag{});
244+
auto ImplPtr =
245+
std::make_shared<queue_impl>(std::forward<Ts>(args)..., private_tag{});
246+
ImplPtr->getDeviceImpl().registerQueue(ImplPtr);
247+
return ImplPtr;
246248
}
247249

248250
~queue_impl() {
@@ -253,6 +255,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
253255
// notification and destroy the trace event for this queue.
254256
destructorNotification();
255257
#endif
258+
MDevice.unregisterQueue(weak_from_this());
256259
auto status =
257260
getAdapter().call_nocheck<UrApiKind::urQueueRelease>(MQueue);
258261
// If loader is already closed, it'll return a not-initialized status
@@ -704,6 +707,17 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
704707
}
705708
#endif
706709

710+
/// Returns the async_handler associated with the queue.
711+
const async_handler &getAsynchHandler() const noexcept {
712+
return MAsyncHandler;
713+
}
714+
715+
/// Waits for all not-yet-enqueued and host_task commands in the queue and
716+
/// clears the events associated with the queue (if out-of-order.)
717+
/// Note: This should only be called if the queue is guaranteed to be
718+
/// synchronized by the caller.
719+
void waitForRuntimeLevelCmdsAndClear();
720+
707721
protected:
708722
template <typename HandlerType = handler>
709723
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {

sycl/source/detail/ur_device_info_ret_types.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,5 @@ MAP(UR_DEVICE_INFO_CLOCK_SUB_GROUP_SUPPORT_EXP, ur_bool_t)
197197
MAP(UR_DEVICE_INFO_CLOCK_WORK_GROUP_SUPPORT_EXP, ur_bool_t)
198198
MAP(UR_DEVICE_INFO_CLOCK_DEVICE_SUPPORT_EXP, ur_bool_t)
199199
MAP(UR_DEVICE_INFO_IS_INTEGRATED_GPU, ur_bool_t)
200+
MAP(UR_DEVICE_INFO_DEVICE_WAIT_SUPPORT_EXP, ur_bool_t)
200201
// clang-format on

sycl/source/device.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,5 +344,9 @@ detail::string device::ext_oneapi_cl_profile_impl() const {
344344
return detail::string{profile};
345345
}
346346

347+
void device::ext_oneapi_wait() { impl->wait(); }
348+
349+
void device::ext_oneapi_throw_asynchronous() { impl->throwAsynchronous(); }
350+
347351
} // namespace _V1
348352
} // namespace sycl

sycl/test-e2e/DeviceWait/basic.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl/detail/core.hpp>
5+
#include <sycl/properties/all_properties.hpp>
6+
7+
#include <array>
8+
#include <vector>
9+
10+
constexpr size_t NContexts = 2;
11+
constexpr size_t NQueues = 6;
12+
13+
int main() {
14+
sycl::device D;
15+
std::array<sycl::context, NContexts> Contexts{sycl::context{D},
16+
sycl::context{D}};
17+
std::array<sycl::queue, NQueues> Queues{
18+
sycl::queue{Contexts[0], D},
19+
sycl::queue{Contexts[0], D, sycl::property::queue::in_order()},
20+
sycl::queue{Contexts[0], D},
21+
sycl::queue{Contexts[1], D, sycl::property::queue::in_order()},
22+
sycl::queue{Contexts[1], D},
23+
sycl::queue{Contexts[1], D, sycl::property::queue::in_order()}};
24+
25+
std::vector<sycl::event> Events;
26+
Events.reserve(NQueues);
27+
for (sycl::queue &Q : Queues) {
28+
sycl::event E = Q.single_task([]() {
29+
volatile int value = 1024 * 1024;
30+
while (--value)
31+
;
32+
});
33+
Events.push_back(std::move(E));
34+
}
35+
36+
D.ext_oneapi_wait();
37+
38+
int Failed = 0;
39+
for (size_t I = 0; I < Events.size(); ++I) {
40+
sycl::info::event_command_status EventStatus =
41+
Events[I].get_info<sycl::info::event::command_execution_status>();
42+
if (EventStatus != sycl::info::event_command_status::complete) {
43+
std::cout << "Unexpected event status for event at " << I << std::endl;
44+
++Failed;
45+
}
46+
}
47+
return Failed;
48+
}

0 commit comments

Comments
 (0)