Skip to content

Commit e1ca8ee

Browse files
RunAsync C/CXX API (microsoft#16613)
Implement RunAsync API - the session will run in a thread of intra-op thread pool. --------- Co-authored-by: Randy Shuai <[email protected]>
1 parent 2cf31a2 commit e1ca8ee

File tree

8 files changed

+308
-66
lines changed

8 files changed

+308
-66
lines changed

include/onnxruntime/core/session/onnxruntime_c_api.h

+30
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,15 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha
696696

697697
typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);
698698

699+
/** \brief Callback function for RunAsync
700+
*
701+
* \param[in] user_data User specific data that passed back to the callback
702+
* \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr
703+
* \param[out] num_outputs Number of outputs, on error, the value will be zero
704+
* \param[out] status On error, status will provide details
705+
*/
706+
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
707+
699708
/** \brief The C API
700709
*
701710
* All C API functions are defined inside this structure as pointers to functions.
@@ -4316,6 +4325,27 @@ struct OrtApi {
43164325
*/
43174326
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
43184327
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
4328+
4329+
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
4330+
*
4331+
* \param[in] session
4332+
* \param[in] run_options If nullptr, will use a default ::OrtRunOptions
4333+
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
4334+
* \param[in] input Array of ::OrtValue%s of the input values
4335+
* \param[in] input_len Number of elements in the input_names and inputs arrays
4336+
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
4337+
* \param[in] output_names_len Number of elements in the output_names and outputs array
4338+
* \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
4339+
* The array will be passed back to run_async_callback
4340+
* \param[in] run_async_callback Callback function on model run completion
4341+
* \param[in] user_data User data that pass back to run_async_callback
4342+
*/
4343+
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
4344+
_In_reads_(input_len) const char* const* input_names,
4345+
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
4346+
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
4347+
_Inout_updates_all_(output_names_len) OrtValue** output,
4348+
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
43194349
};
43204350

43214351
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

+18
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,24 @@ struct SessionImpl : ConstSessionImpl<T> {
10671067

10681068
void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
10691069

1070+
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
1071+
*
1072+
* Wraps OrtApi::RunAsync
1073+
*
1074+
* \param[in] run_options
1075+
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
1076+
* \param[in] input_values Array of ::OrtValue%s of the input values
1077+
* \param[in] input_count Number of elements in the input_names and inputs arrays
1078+
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
1079+
* \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
1080+
* The array will be passed back to the callback
1081+
* \param[in] output_count Number of elements in the output_names and outputs array
1082+
* \param[in] callback Callback function on model run completion
1083+
* \param[in] user_data User data that pass back to the callback
1084+
*/
1085+
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1086+
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1087+
10701088
/** \brief End profiling and return a copy of the profiling file name.
10711089
*
10721090
* \param allocator to allocate memory for the copy of the string returned

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

+10
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,16 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
972972
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
973973
}
974974

975+
template <typename T>
976+
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
977+
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
978+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
979+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
980+
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
981+
ort_input_values, input_count, output_names, output_count,
982+
ort_output_values, callback, user_data));
983+
}
984+
975985
template <typename T>
976986
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
977987
char* out = nullptr;

onnxruntime/core/session/inference_session.cc

+110
Original file line numberDiff line numberDiff line change
@@ -2300,6 +2300,116 @@ Status InferenceSession::Run(const RunOptions& run_options,
23002300
return retval;
23012301
}
23022302

2303+
Status InferenceSession::Run(const RunOptions& run_options,
2304+
gsl::span<const char* const> feed_names,
2305+
gsl::span<const OrtValue* const> feeds,
2306+
gsl::span<const char* const> fetch_names,
2307+
gsl::span<OrtValue*> fetches) {
2308+
size_t num_feeds = feed_names.size();
2309+
size_t num_fetches = fetch_names.size();
2310+
InlinedVector<std::string> feed_name_vec;
2311+
feed_name_vec.reserve(num_feeds);
2312+
InlinedVector<OrtValue> feed_vec;
2313+
feed_vec.reserve(num_feeds);
2314+
2315+
for (size_t i = 0; i != num_feeds; ++i) {
2316+
if (feed_names[i] == nullptr || feed_names[i][0] == '\0') {
2317+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
2318+
}
2319+
2320+
if (!feeds[i]) {
2321+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_names[i]).c_str());
2322+
}
2323+
2324+
feed_name_vec.emplace_back(feed_names[i]);
2325+
feed_vec.emplace_back(*feeds[i]);
2326+
}
2327+
2328+
// Create output feed
2329+
InlinedVector<std::string> fetch_name_vec;
2330+
fetch_name_vec.reserve(num_fetches);
2331+
for (size_t i = 0; i != num_fetches; ++i) {
2332+
if (fetch_names[i] == nullptr || fetch_names[i][0] == '\0') {
2333+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
2334+
}
2335+
fetch_name_vec.emplace_back(fetch_names[i]);
2336+
}
2337+
2338+
std::vector<OrtValue> fetch_vec;
2339+
fetch_vec.reserve(num_fetches);
2340+
for (size_t i = 0; i != num_fetches; ++i) {
2341+
if (fetches[i] != nullptr) {
2342+
fetch_vec.emplace_back(*fetches[i]);
2343+
} else {
2344+
fetch_vec.emplace_back();
2345+
}
2346+
}
2347+
2348+
Status status;
2349+
status = Run(run_options, feed_name_vec, feed_vec, fetch_name_vec, &fetch_vec, nullptr);
2350+
2351+
if (!status.IsOK())
2352+
return status;
2353+
2354+
// We do it in two loops to make sure copy __ctors does not throw
2355+
InlinedVector<std::unique_ptr<OrtValue>> fetch_unique_ptrs;
2356+
fetch_unique_ptrs.reserve(num_fetches);
2357+
for (size_t i = 0; i != num_fetches; ++i) {
2358+
if (fetches[i] == nullptr) {
2359+
fetch_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetch_vec[i]));
2360+
} else {
2361+
fetch_unique_ptrs.emplace_back();
2362+
}
2363+
}
2364+
2365+
for (size_t i = 0; i != num_fetches; ++i) {
2366+
if (fetches[i] == nullptr) {
2367+
ORT_ENFORCE(fetch_unique_ptrs[i] != nullptr);
2368+
fetches[i] = fetch_unique_ptrs[i].release();
2369+
}
2370+
}
2371+
return Status::OK();
2372+
}
2373+
2374+
common::Status InferenceSession::RunAsync(const RunOptions* run_options,
2375+
gsl::span<const char* const> feed_names,
2376+
gsl::span<const OrtValue* const> feeds,
2377+
gsl::span<const char* const> fetch_names,
2378+
gsl::span<OrtValue*> fetches,
2379+
RunAsyncCallbackFn callback,
2380+
void* user_data) {
2381+
size_t num_fetches = fetch_names.size();
2382+
if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
2383+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
2384+
}
2385+
std::function<void()> run_fn = [=]() {
2386+
ORT_TRY {
2387+
Status status;
2388+
if (run_options) {
2389+
status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
2390+
} else {
2391+
RunOptions default_run_options;
2392+
status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
2393+
}
2394+
if (status.IsOK()) {
2395+
callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status));
2396+
} else {
2397+
callback(user_data, {}, 0, ToOrtStatus(status));
2398+
}
2399+
}
2400+
ORT_CATCH(const std::exception& ex) {
2401+
ORT_HANDLE_EXCEPTION([=]() {
2402+
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what())));
2403+
});
2404+
}
2405+
ORT_CATCH(...) {
2406+
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception")));
2407+
}
2408+
}; // run_fn
2409+
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
2410+
return Status::OK();
2411+
}
2412+
23032413
common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
23042414
std::vector<OrtValue>* p_fetches) {
23052415
return Run(RunOptions(), feeds, output_names, p_fetches);

onnxruntime/core/session/inference_session.h

+14
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,20 @@ class InferenceSession {
305305
std::vector<OrtValue>* p_fetches,
306306
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);
307307

308+
[[nodiscard]] common::Status Run(const RunOptions& run_options,
309+
gsl::span<const char* const> feed_names,
310+
gsl::span<const OrtValue* const> feeds,
311+
gsl::span<const char* const> fetch_names,
312+
gsl::span<OrtValue*> fetches);
313+
314+
[[nodiscard]] common::Status RunAsync(const RunOptions* run_options,
315+
gsl::span<const char* const> feed_names,
316+
gsl::span<const OrtValue* const> feeds,
317+
gsl::span<const char* const> fetch_names,
318+
gsl::span<OrtValue*> fetches,
319+
RunAsyncCallbackFn callback,
320+
void* user_data = nullptr);
321+
308322
/**
309323
* Run a pre-loaded and pre-intialized model.
310324
* Multiple threads are allowed to run this function; hence its thread-safe.

onnxruntime/core/session/onnxruntime_c_api.cc

+40-64
Original file line numberDiff line numberDiff line change
@@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
817817
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
818818
_In_reads_(input_len) const char* const* input_names,
819819
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
820-
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
820+
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
821821
_Inout_updates_all_(output_names_len) OrtValue** output) {
822822
API_IMPL_BEGIN
823823
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
824824

825-
InlinedVector<std::string> feed_names;
826-
feed_names.reserve(input_len);
827-
InlinedVector<OrtValue> feeds;
828-
feeds.reserve(input_len);
829-
830-
for (size_t i = 0; i != input_len; ++i) {
831-
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
832-
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
833-
}
834-
835-
if (!input[i]) {
836-
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
837-
MakeString("NULL input supplied for input ", input_names[i]).c_str());
838-
}
839-
840-
feed_names.emplace_back(input_names[i]);
841-
feeds.emplace_back(*input[i]);
842-
}
843-
844-
// Create output feed
845-
InlinedVector<std::string> output_names;
846-
output_names.reserve(output_names_len);
847-
for (size_t i = 0; i != output_names_len; ++i) {
848-
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
849-
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
850-
}
851-
output_names.emplace_back(output_names1[i]);
852-
}
853-
854-
std::vector<OrtValue> fetches;
855-
fetches.reserve(output_names_len);
856-
for (size_t i = 0; i != output_names_len; ++i) {
857-
if (output[i] != nullptr) {
858-
fetches.emplace_back(*output[i]);
859-
} else {
860-
fetches.emplace_back();
861-
}
862-
}
825+
gsl::span<const char* const> input_names_span(input_names, input_len);
826+
gsl::span<const OrtValue* const> input_span(input, input_len);
827+
gsl::span<const char* const> output_name_span(output_names, output_names_len);
828+
gsl::span<OrtValue*> output_span(output, output_names_len);
863829

864830
Status status;
865-
if (run_options == nullptr) {
866-
OrtRunOptions op;
867-
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
831+
if (run_options) {
832+
status = session->Run(*run_options,
833+
input_names_span,
834+
input_span,
835+
output_name_span,
836+
output_span);
868837
} else {
869-
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
838+
const RunOptions default_run_options;
839+
status = session->Run(default_run_options,
840+
input_names_span,
841+
input_span,
842+
output_name_span,
843+
output_span);
870844
}
845+
return ToOrtStatus(status);
846+
API_IMPL_END
847+
}
871848

872-
if (!status.IsOK())
873-
return ToOrtStatus(status);
874-
875-
// We do it in two loops to make sure copy __ctors does not throw
876-
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
877-
output_unique_ptrs.reserve(output_names_len);
878-
for (size_t i = 0; i != output_names_len; ++i) {
879-
if (output[i] == nullptr) {
880-
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
881-
} else {
882-
output_unique_ptrs.emplace_back();
883-
}
884-
}
849+
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
850+
_In_reads_(input_len) const char* const* input_names,
851+
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
852+
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
853+
_Inout_updates_all_(output_names_len) OrtValue** output,
854+
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
855+
API_IMPL_BEGIN
856+
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
885857

886-
assert(output_unique_ptrs.size() == output_names_len);
858+
gsl::span<const char* const> input_names_span(input_names, input_len);
859+
gsl::span<const OrtValue* const> input_span(input, input_len);
860+
gsl::span<const char* const> output_name_span(output_names, output_names_len);
861+
gsl::span<OrtValue*> output_span(output, output_names_len);
887862

888-
for (size_t i = 0; i != output_names_len; ++i) {
889-
if (output[i] == nullptr) {
890-
assert(output_unique_ptrs[i] != nullptr);
891-
output[i] = output_unique_ptrs[i].release();
892-
}
893-
}
894-
return nullptr;
863+
return ToOrtStatus(session->RunAsync(run_options,
864+
input_names_span,
865+
input_span,
866+
output_name_span,
867+
output_span,
868+
run_async_callback,
869+
user_data));
895870
API_IMPL_END
896871
}
897872

@@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = {
27352710
&OrtApis::GetROCMProviderOptionsAsString,
27362711
&OrtApis::ReleaseROCMProviderOptions,
27372712
&OrtApis::CreateAndRegisterAllocatorV2,
2713+
&OrtApis::RunAsync,
27382714
};
27392715

27402716
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

onnxruntime/core/session/ort_apis.h

+7
Original file line numberDiff line numberDiff line change
@@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions
478478

479479
ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
480480
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
481+
482+
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
483+
_In_reads_(input_len) const char* const* input_names,
484+
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
485+
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
486+
_Inout_updates_all_(output_names_len) OrtValue** outputs,
487+
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
481488
} // namespace OrtApis

0 commit comments

Comments
 (0)