Skip to content

Commit 8c8f38a

Browse files
committed
[js/node] allow arenaExtendStrategy and gpuMemLimit for cuda
1 parent ae6dcc8 commit 8c8f38a

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

js/common/lib/inference-session.ts

+9
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,15 @@ export declare namespace InferenceSession {
223223
export interface CudaExecutionProviderOption extends ExecutionProviderOption {
224224
readonly name: 'cuda';
225225
deviceId?: number;
226+
gpuMemLimit?: number;
227+
228+
/**
229+
* Arena extend strategy. See
230+
* https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/arena_extend_strategy.h
231+
*
232+
* This setting is available only in ONNXRuntime (Node.js binding)
233+
*/
234+
arenaExtendStrategy?: 0 | 1;
226235
}
227236
export interface DmlExecutionProviderOption extends ExecutionProviderOption {
228237
readonly name: 'dml';

js/node/src/session_options_helper.cc

+16
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
4141
Napi::Value epValue = epList[i];
4242
std::string name;
4343
int deviceId = 0;
44+
#ifdef USE_CUDA
45+
onnxruntime::ArenaExtendStrategy arenaExtendStrategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
46+
size_t gpuMemLimit = std::numeric_limits<size_t>::max();
47+
#endif
4448
#ifdef USE_COREML
4549
int coreMlFlags = 0;
4650
#endif
@@ -59,6 +63,16 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
5963
if (obj.Has("deviceId")) {
6064
deviceId = obj.Get("deviceId").As<Napi::Number>();
6165
}
66+
#ifdef USE_CUDA
67+
if (obj.Has("arenaExtendStrategy")) {
68+
arenaExtendStrategy = static_cast<onnxruntime::ArenaExtendStrategy>(
69+
obj.Get("arenaExtendStrategy").As<Napi::Number>().Uint32Value());
70+
}
71+
if (obj.Has("gpuMemLimit")) {
72+
gpuMemLimit = static_cast<size_t>(
73+
obj.Get("gpuMemLimit").As<Napi::Number>().DoubleValue());
74+
}
75+
#endif
6276
#ifdef USE_COREML
6377
if (obj.Has("coreMlFlags")) {
6478
coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
@@ -86,6 +100,8 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
86100
OrtCUDAProviderOptionsV2* options;
87101
Ort::GetApi().CreateCUDAProviderOptions(&options);
88102
options->device_id = deviceId;
103+
options->arena_extend_strategy = arenaExtendStrategy;
104+
options->gpu_mem_limit = gpuMemLimit;
89105
sessionOptions.AppendExecutionProvider_CUDA_V2(*options);
90106
Ort::GetApi().ReleaseCUDAProviderOptions(options);
91107
#endif

0 commit comments

Comments
 (0)