Skip to content

Commit 8f06a1d

Browse files
[SYCL] Add one more WA on Win for shutdown process (#19195)
WA for host task threads handling. We must call join() or detach() on host task execution thread to avoid UB. DLLMain parameter lpReserved == NULL if library is unloaded via FreeLibrary. In this case we can't join threads within DllMain call due to global loader lock and DLL_THREAD_DETACH signalling. lpReserved != NULL if library is unloaded during process termination. In this case Windows terminates threads but leave them in signalled state, prevents DLL_THREAD_DETACH notification and we can call join() as NOP. Note that FreeLibrary called with sycl RT handle does not guarantee which path will be used. Windows can (and actually does) simply postpone actual unloading till the end of program. --------- Signed-off-by: Tikhomirova, Kseniya <[email protected]>
1 parent cbb7dbe commit 8f06a1d

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

sycl/source/detail/global_handler.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using LockGuard = std::lock_guard<SpinLock>;
3838
SpinLock GlobalHandler::MSyclGlobalHandlerProtector{};
3939

4040
// forward decl
41-
void shutdown_early();
41+
void shutdown_early(bool);
4242
void shutdown_late();
4343
#ifdef _WIN32
4444
BOOL isLinkedStatically();
@@ -287,12 +287,12 @@ struct StaticVarShutdownHandler {
287287
// If statically linked, DllMain will not be called. So we do its work
288288
// here.
289289
if (isLinkedStatically()) {
290-
shutdown_early();
290+
shutdown_early(true);
291291
}
292292

293293
shutdown_late();
294294
#else
295-
shutdown_early();
295+
shutdown_early(true);
296296
#endif
297297
} catch (std::exception &e) {
298298
__SYCL_REPORT_EXCEPTION_TO_STREAM(
@@ -347,7 +347,10 @@ void GlobalHandler::drainThreadPool() {
347347
MHostTaskThreadPool.Inst->drain();
348348
}
349349

350-
void shutdown_early() {
350+
// Note: this function can be called on Windows twice:
351+
// 1) when library is unloaded via FreeLibrary
352+
// 2) when process is being terminated
353+
void shutdown_early(bool CanJoinThreads = true) {
351354
const LockGuard Lock{GlobalHandler::MSyclGlobalHandlerProtector};
352355
GlobalHandler *&Handler = GlobalHandler::getInstancePtr();
353356
if (!Handler)
@@ -366,8 +369,10 @@ void shutdown_early() {
366369
// upon its release
367370
Handler->prepareSchedulerToRelease(true);
368371

369-
if (Handler->MHostTaskThreadPool.Inst)
370-
Handler->MHostTaskThreadPool.Inst->finishAndWait();
372+
if (Handler->MHostTaskThreadPool.Inst) {
373+
Handler->MHostTaskThreadPool.Inst->finishAndWait(CanJoinThreads);
374+
Handler->MHostTaskThreadPool.Inst.reset(nullptr);
375+
}
371376

372377
// This releases OUR reference to the default context, but
373378
// other may yet have refs
@@ -428,7 +433,14 @@ extern "C" __SYCL_EXPORT BOOL WINAPI DllMain(HINSTANCE hinstDLL,
428433
std::cout << "---> DLL_PROCESS_DETACH syclx.dll\n" << std::endl;
429434

430435
try {
431-
shutdown_early();
436+
// WA for threads handling. We must call join() or detach() on host task
437+
// execution thread to avoid UB. lpReserved == NULL if library is unloaded
438+
// via FreeLibrary. In this case we can't join threads within DllMain call
439+
// due to global loader lock and DLL_THREAD_DETACH signalling. lpReserved
440+
// != NULL if library is unloaded during process termination. In this case
441+
// Windows terminates threads but leave them in signalled state, prevents
442+
// DLL_THREAD_DETACH notification and we can call join() as NOP.
443+
shutdown_early(lpReserved != NULL);
432444
} catch (std::exception &e) {
433445
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in DLL_PROCESS_DETACH", e);
434446
return FALSE;

sycl/source/detail/global_handler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class GlobalHandler {
9898

9999
bool OkToDefer = true;
100100

101-
friend void shutdown_early();
101+
friend void shutdown_early(bool);
102102
friend void shutdown_late();
103103
friend class ObjectUsageCounter;
104104
static GlobalHandler *&getInstancePtr();

sycl/source/detail/thread_pool.hpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,43 @@ class ThreadPool {
3232
bool MStop = false;
3333
std::atomic_uint MJobsInPool;
3434

35+
#ifdef _WIN32
36+
class ThreadExitTracker {
37+
public:
38+
void wait(size_t ThreadCount) {
39+
std::unique_lock<std::mutex> lk(MWorkerExitMutex);
40+
MWorkerExitCV.wait(
41+
lk, [&ThreadCount, this] { return MWorkerExitCount == ThreadCount; });
42+
}
43+
44+
void signalAboutExit() {
45+
{
46+
std::lock_guard<std::mutex> lk(MWorkerExitMutex);
47+
MWorkerExitCount++;
48+
}
49+
MWorkerExitCV.notify_one();
50+
}
51+
52+
private:
53+
std::mutex MWorkerExitMutex;
54+
std::condition_variable MWorkerExitCV;
55+
size_t MWorkerExitCount{};
56+
} WinThreadExitTracker;
57+
#endif
58+
3559
void worker() {
3660
GlobalHandler::instance().registerSchedulerUsage(/*ModifyCounter*/ false);
3761
std::unique_lock<std::mutex> Lock(MJobQueueMutex);
3862
while (true) {
3963
MDoSmthOrStop.wait(Lock,
4064
[this]() { return !MJobQueue.empty() || MStop; });
4165

42-
if (MStop)
43-
break;
66+
if (MStop) {
67+
#ifdef _WIN32
68+
WinThreadExitTracker.signalAboutExit();
69+
#endif
70+
return;
71+
}
4472

4573
std::function<void()> Job = std::move(MJobQueue.front());
4674
MJobQueue.pop();
@@ -76,21 +104,33 @@ class ThreadPool {
76104
~ThreadPool() {
77105
try {
78106
#ifndef _WIN32
79-
finishAndWait();
107+
finishAndWait(true);
80108
#endif
81109
} catch (std::exception &e) {
82110
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ThreadPool", e);
83111
}
84112
}
85113

86-
void finishAndWait() {
114+
void finishAndWait(bool CanJoinThreads) {
87115
{
88116
std::lock_guard<std::mutex> Lock(MJobQueueMutex);
89117
MStop = true;
90118
}
91119

92120
MDoSmthOrStop.notify_all();
93121

122+
#ifdef _WIN32
123+
if (!CanJoinThreads) {
124+
WinThreadExitTracker.wait(MThreadCount);
125+
for (std::thread &Thread : MLaunchedThreads)
126+
Thread.detach();
127+
return;
128+
}
129+
#else
130+
// We always can join on Linux.
131+
std::ignore = CanJoinThreads;
132+
#endif
133+
94134
for (std::thread &Thread : MLaunchedThreads)
95135
if (Thread.joinable())
96136
Thread.join();

0 commit comments

Comments
 (0)