@@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
817
817
ORT_API_STATUS_IMPL (OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
818
818
_In_reads_ (input_len) const char* const * input_names,
819
819
_In_reads_(input_len) const OrtValue* const * input, size_t input_len,
820
- _In_reads_(output_names_len) const char* const * output_names1 , size_t output_names_len,
820
+ _In_reads_(output_names_len) const char* const * output_names , size_t output_names_len,
821
821
_Inout_updates_all_(output_names_len) OrtValue** output) {
822
822
API_IMPL_BEGIN
823
823
auto session = reinterpret_cast <::onnxruntime::InferenceSession*>(sess);
824
824
825
- InlinedVector<std::string> feed_names;
826
- feed_names.reserve (input_len);
827
- InlinedVector<OrtValue> feeds;
828
- feeds.reserve (input_len);
829
-
830
- for (size_t i = 0 ; i != input_len; ++i) {
831
- if (input_names[i] == nullptr || input_names[i][0 ] == ' \0 ' ) {
832
- return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " input name cannot be empty" );
833
- }
834
-
835
- if (!input[i]) {
836
- return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT,
837
- MakeString (" NULL input supplied for input " , input_names[i]).c_str ());
838
- }
839
-
840
- feed_names.emplace_back (input_names[i]);
841
- feeds.emplace_back (*input[i]);
842
- }
843
-
844
- // Create output feed
845
- InlinedVector<std::string> output_names;
846
- output_names.reserve (output_names_len);
847
- for (size_t i = 0 ; i != output_names_len; ++i) {
848
- if (output_names1[i] == nullptr || output_names1[i][0 ] == ' \0 ' ) {
849
- return OrtApis::CreateStatus (ORT_INVALID_ARGUMENT, " output name cannot be empty" );
850
- }
851
- output_names.emplace_back (output_names1[i]);
852
- }
853
-
854
- std::vector<OrtValue> fetches;
855
- fetches.reserve (output_names_len);
856
- for (size_t i = 0 ; i != output_names_len; ++i) {
857
- if (output[i] != nullptr ) {
858
- fetches.emplace_back (*output[i]);
859
- } else {
860
- fetches.emplace_back ();
861
- }
862
- }
825
+ gsl::span<const char * const > input_names_span (input_names, input_len);
826
+ gsl::span<const OrtValue* const > input_span (input, input_len);
827
+ gsl::span<const char * const > output_name_span (output_names, output_names_len);
828
+ gsl::span<OrtValue*> output_span (output, output_names_len);
863
829
864
830
Status status;
865
- if (run_options == nullptr ) {
866
- OrtRunOptions op;
867
- status = session->Run (op, feed_names, feeds, output_names, &fetches, nullptr );
831
+ if (run_options) {
832
+ status = session->Run (*run_options,
833
+ input_names_span,
834
+ input_span,
835
+ output_name_span,
836
+ output_span);
868
837
} else {
869
- status = session->Run (*run_options, feed_names, feeds, output_names, &fetches, nullptr );
838
+ const RunOptions default_run_options;
839
+ status = session->Run (default_run_options,
840
+ input_names_span,
841
+ input_span,
842
+ output_name_span,
843
+ output_span);
870
844
}
845
+ return ToOrtStatus (status);
846
+ API_IMPL_END
847
+ }
871
848
872
- if (!status.IsOK ())
873
- return ToOrtStatus (status);
874
-
875
- // We do it in two loops to make sure copy __ctors does not throw
876
- InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
877
- output_unique_ptrs.reserve (output_names_len);
878
- for (size_t i = 0 ; i != output_names_len; ++i) {
879
- if (output[i] == nullptr ) {
880
- output_unique_ptrs.emplace_back (std::make_unique<OrtValue>(fetches[i]));
881
- } else {
882
- output_unique_ptrs.emplace_back ();
883
- }
884
- }
849
+ ORT_API_STATUS_IMPL (OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
850
+ _In_reads_ (input_len) const char* const * input_names,
851
+ _In_reads_(input_len) const OrtValue* const * input, size_t input_len,
852
+ _In_reads_(output_names_len) const char* const * output_names, size_t output_names_len,
853
+ _Inout_updates_all_(output_names_len) OrtValue** output,
854
+ _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
855
+ API_IMPL_BEGIN
856
+ auto session = reinterpret_cast <::onnxruntime::InferenceSession*>(sess);
885
857
886
- assert (output_unique_ptrs.size () == output_names_len);
858
+ gsl::span<const char * const > input_names_span (input_names, input_len);
859
+ gsl::span<const OrtValue* const > input_span (input, input_len);
860
+ gsl::span<const char * const > output_name_span (output_names, output_names_len);
861
+ gsl::span<OrtValue*> output_span (output, output_names_len);
887
862
888
- for ( size_t i = 0 ; i != output_names_len; ++i) {
889
- if (output[i] == nullptr ) {
890
- assert (output_unique_ptrs[i] != nullptr );
891
- output[i] = output_unique_ptrs[i]. release ();
892
- }
893
- }
894
- return nullptr ;
863
+ return ToOrtStatus (session-> RunAsync (run_options,
864
+ input_names_span,
865
+ input_span,
866
+ output_name_span,
867
+ output_span,
868
+ run_async_callback,
869
+ user_data)) ;
895
870
API_IMPL_END
896
871
}
897
872
@@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = {
2735
2710
&OrtApis::GetROCMProviderOptionsAsString,
2736
2711
&OrtApis::ReleaseROCMProviderOptions,
2737
2712
&OrtApis::CreateAndRegisterAllocatorV2,
2713
+ &OrtApis::RunAsync,
2738
2714
};
2739
2715
2740
2716
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
0 commit comments