@@ -41,6 +41,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
41
41
Napi::Value epValue = epList[i];
42
42
std::string name;
43
43
int deviceId = 0 ;
44
+ #ifdef USE_CUDA
45
+ onnxruntime::ArenaExtendStrategy arenaExtendStrategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo ;
46
+ size_t gpuMemLimit = std::numeric_limits<size_t >::max ();
47
+ #endif
44
48
#ifdef USE_COREML
45
49
int coreMlFlags = 0 ;
46
50
#endif
@@ -59,6 +63,16 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
59
63
if (obj.Has (" deviceId" )) {
60
64
deviceId = obj.Get (" deviceId" ).As <Napi::Number>();
61
65
}
66
+ #ifdef USE_CUDA
67
+ if (obj.Has (" arenaExtendStrategy" )) {
68
+ arenaExtendStrategy = static_cast <onnxruntime::ArenaExtendStrategy>(
69
+ obj.Get (" arenaExtendStrategy" ).As <Napi::Number>().Uint32Value ());
70
+ }
71
+ if (obj.Has (" gpuMemLimit" )) {
72
+ gpuMemLimit = static_cast <size_t >(
73
+ obj.Get (" gpuMemLimit" ).As <Napi::Number>().DoubleValue ());
74
+ }
75
+ #endif
62
76
#ifdef USE_COREML
63
77
if (obj.Has (" coreMlFlags" )) {
64
78
coreMlFlags = obj.Get (" coreMlFlags" ).As <Napi::Number>();
@@ -86,6 +100,8 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
86
100
OrtCUDAProviderOptionsV2* options;
87
101
Ort::GetApi ().CreateCUDAProviderOptions (&options);
88
102
options->device_id = deviceId;
103
+ options->arena_extend_strategy = arenaExtendStrategy;
104
+ options->gpu_mem_limit = gpuMemLimit;
89
105
sessionOptions.AppendExecutionProvider_CUDA_V2 (*options);
90
106
Ort::GetApi ().ReleaseCUDAProviderOptions (options);
91
107
#endif
0 commit comments