Skip to content

Commit 66ba535

Browse files
authoredDec 29, 2023
add stream lock to avoid use same steam in multithread at same time (DeepLink-org#793)
* add stream lock to avoid use same steam in multithread at same time
1 parent 6b99930 commit 66ba535

File tree

5 files changed

+65
-12
lines changed

5 files changed

+65
-12
lines changed
 

‎impl/ascend/common/acloprunner.hpp

100755100644
+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "debug.hpp"
2525
#include "gil_scoped_release.hpp"
2626
#include "impl_functions.hpp"
27+
#include "stream_lock.hpp"
2728
#include "utils.hpp"
2829

2930
namespace impl {
@@ -630,6 +631,7 @@ class AclOpRunner {
630631
diopiGetStream(context_, &stream);
631632
diopi::GilScopedRelease gilReleaeGuard;
632633
if (sync_) {
634+
diopi::StreamLockGuard streamLockGuard(stream);
633635
CALL_ACLRT(aclopCompileAndExecuteV2(opname_.data(),
634636
inputIndex_,
635637
inputDescs_.data(),
@@ -643,6 +645,7 @@ class AclOpRunner {
643645
nullptr,
644646
stream));
645647
} else {
648+
diopi::StreamLockGuard streamLockGuard(stream);
646649
CALL_ACLRT(aclopCompileAndExecute(opname_.data(),
647650
inputIndex_,
648651
inputDescs_.data(),

‎impl/ascend/common/stream_lock.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include "stream_lock.hpp"
2+
3+
#include <map>
4+
#include <mutex>
5+
#include <thread>
6+
7+
namespace diopi {
8+
9+
using MutexType = std::mutex;
10+
11+
namespace {
12+
13+
MutexType* getLockForStream(void* aclrtStreamHandle) {
14+
static std::map<void*, MutexType> streamThreadMutexMap;
15+
return &streamThreadMutexMap[aclrtStreamHandle];
16+
}
17+
18+
} // namespace
19+
20+
StreamLockGuard::StreamLockGuard(void* aclrtStreamHandle) {
21+
MutexType* mutexPtr = getLockForStream(aclrtStreamHandle);
22+
mutex_ = mutexPtr;
23+
mutexPtr->lock();
24+
}
25+
26+
StreamLockGuard::~StreamLockGuard() {
27+
MutexType* mutexPtr = static_cast<MutexType*>(mutex_);
28+
mutexPtr->unlock();
29+
}
30+
31+
} // namespace diopi

‎impl/ascend/common/stream_lock.hpp

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
namespace diopi {
4+
5+
class StreamLockGuard {
6+
private:
7+
void* mutex_ = nullptr;
8+
9+
public:
10+
explicit StreamLockGuard(void* aclrtStreamHandle);
11+
~StreamLockGuard();
12+
};
13+
14+
} // namespace diopi

‎impl/ascend_npu/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ set(OLD_IMPL_SRC
655655
${OLD_IMPL_DIR}/common/format_helper.cpp
656656
${OLD_IMPL_DIR}/common/debug.cpp
657657
${OLD_IMPL_DIR}/common/generator_helper.cpp
658+
${OLD_IMPL_DIR}/common/stream_lock.cpp
658659
${OLD_IMPL_DIR}/env_vars.cpp
659660
)
660661

‎impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp

100644100755
+16-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <torch/library.h>
88

99
#include "../../../ascend/common/gil_scoped_release.hpp"
10+
#include "../../../ascend/common/stream_lock.hpp"
1011
#include "diopi_impl/helper.hpp"
1112
#include "op_plugin/AclOpsInterface.h"
1213

@@ -2132,18 +2133,21 @@ aclError OpCommandImpl::InnerRun(const string& name, AclExecParam& params, bool
21322133
}
21332134
}
21342135
#endif
2135-
ret = AclopCompileAndExecuteV2(name.c_str(),
2136-
inputSize,
2137-
const_cast<aclTensorDesc**>(params.inDesc.data()),
2138-
const_cast<aclDataBuffer**>(params.inBuffer.data()),
2139-
outputSize,
2140-
const_cast<aclTensorDesc**>(params.outDesc.data()),
2141-
params.outBuffer.data(),
2142-
params.attr,
2143-
ACL_ENGINE_SYS,
2144-
ACL_COMPILE_SYS,
2145-
NULL,
2146-
stream);
2136+
{
2137+
diopi::StreamLockGuard streamLockGuard(stream.stream());
2138+
ret = AclopCompileAndExecuteV2(name.c_str(),
2139+
inputSize,
2140+
const_cast<aclTensorDesc**>(params.inDesc.data()),
2141+
const_cast<aclDataBuffer**>(params.inBuffer.data()),
2142+
outputSize,
2143+
const_cast<aclTensorDesc**>(params.outDesc.data()),
2144+
params.outBuffer.data(),
2145+
params.attr,
2146+
ACL_ENGINE_SYS,
2147+
ACL_COMPILE_SYS,
2148+
NULL,
2149+
stream);
2150+
}
21472151
NPU_CHECK_ERROR(ret);
21482152
if (sync) {
21492153
int64_t dimSize;

0 commit comments

Comments
 (0)
Please sign in to comment.