Skip to content

Commit cd9c02f

Browse files
Allow EpContext models with input/output models completely in buffers (microsoft#24463)
### Description
Re-enables (and fixes) generation of compiled EpContext models with **both** the input and output models stored in buffers.

### Motivation and Context
A previous PR (microsoft#24176) inadvertently added a check that disabled storing both the input and output models in buffers. However, this functionality is needed. The regression was actually fortunate, as investigating it led to the discovery of a bug.
1 parent fcb4866 commit cd9c02f

File tree

4 files changed

+101
-15
lines changed

4 files changed

+101
-15
lines changed

onnxruntime/core/framework/graph_partitioner.cc

+11-7
Original file line numberDiff line numberDiff line change
@@ -818,11 +818,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
818818
return std::make_pair(false, static_cast<const Node*>(nullptr));
819819
};
820820

821+
bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr &&
822+
ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
823+
ep_context_gen_options.output_model_buffer_allocator != nullptr;
824+
821825
std::filesystem::path context_cache_path;
822-
ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
823-
graph.ModelPath(),
824-
context_cache_path,
825-
ep_context_gen_options.overwrite_existing_output_file));
826+
if (!saving_to_buffer || !graph.ModelPath().empty()) {
827+
ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path,
828+
graph.ModelPath(),
829+
context_cache_path,
830+
ep_context_gen_options.overwrite_existing_output_file));
831+
}
826832

827833
Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(),
828834
graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path
@@ -882,9 +888,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
882888

883889
ModelSavingOptions model_saving_options{ini_size_threshold};
884890

885-
if (ep_context_gen_options.output_model_buffer_ptr != nullptr &&
886-
ep_context_gen_options.output_model_buffer_size_ptr != nullptr &&
887-
ep_context_gen_options.output_model_buffer_allocator != nullptr) {
891+
if (saving_to_buffer) {
888892
ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve());
889893
// TODO(adrianlizarraga): Investigate if we can make this more memory efficient.
890894
// May be able to use allocator to directly allocate the ModelProto to avoid a copy.

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

+7
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ Status CreateEPContextNodes(Model* model,
247247
} else {
248248
context_bin_path = context_model_path;
249249
}
250+
251+
if (context_bin_path.empty()) {
252+
// Context bin path is empty, so just use the graph name (e.g., "QNNExecutionProvider_QNN_13728744673520368385_2_0").
253+
// This happens if both the input model and output model are stored in buffers (i.e., there are no paths).
254+
context_bin_path = ToPathString(graph_name);
255+
}
256+
250257
context_bin_path = context_bin_path + ToPathString("_qnn.bin");
251258
context_cache_name = std::filesystem::path(context_bin_path).filename().string();
252259

onnxruntime/core/session/utils.cc

+11-6
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,18 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,
5252

5353
// If ep.context_enable is set, then ep.context_file_path is expected; otherwise ORT doesn't know where to generate the _ctx.onnx file
5454
if (options && model_path == nullptr) {
55-
auto ep_context_enable = options->value.config_options.GetConfigEntry(kOrtSessionOptionEpContextEnable);
56-
auto ep_context_file_path = options->value.config_options.GetConfigEntry(kOrtSessionOptionEpContextFilePath);
57-
if (ep_context_enable.has_value() && ep_context_enable.value() == "1" && (!ep_context_file_path.has_value() || (ep_context_file_path.has_value() && ep_context_file_path.value().empty()))) {
55+
EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions();
56+
57+
// This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case
58+
// the user used the older SessionOptions' configuration entries to generate a compiled model.
59+
if (ep_ctx_gen_options.enable &&
60+
ep_ctx_gen_options.output_model_file_path.empty() &&
61+
ep_ctx_gen_options.output_model_buffer_ptr == nullptr) {
5862
return OrtApis::CreateStatus(ORT_FAIL,
59-
"CreateSessionFromArray is called with ep.context_enable enabled but an \
60-
empty ep.context_file_path. The system does not know where to generate the \
61-
EP context model. Please specify a valid ep.context_file_path.");
63+
"Inference session was configured with EPContext model generation enabled but "
64+
"without a valid location (e.g., file or buffer) for the output model. "
65+
"Please specify a valid ep.context_file_path via SessionOption configs "
66+
"or use the OrtCompileApi to compile a model to a file or buffer.");
6267
}
6368
}
6469

onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

+72-2
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,75 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) {
436436

437437
// Check that the compiled model has the expected number of EPContext nodes.
438438
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
439+
allocator.Free(output_model_buffer);
440+
}
441+
442+
// Test using the CompileModel() API with settings:
443+
// - input model from buffer
444+
// - save output model to buffer
445+
// - test enabling AND disabling embed mode for context binary in EPContext node attributes
446+
TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInBuffers) {
447+
// Create a test model and serialize it to a buffer.
448+
TestModel test_model;
449+
CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
450+
std::string model_data = test_model.Serialize();
451+
452+
// Initialize session options with QNN EP
453+
Ort::SessionOptions session_options;
454+
ProviderOptions provider_options;
455+
provider_options["backend_type"] = "htp";
456+
provider_options["offload_graph_io_quantization"] = "0";
457+
session_options.AppendExecutionProvider("QNN", provider_options);
458+
459+
Ort::AllocatorWithDefaultOptions allocator;
460+
461+
// Test embed mode enabled.
462+
{
463+
void* output_model_buffer = nullptr;
464+
size_t output_model_buffer_size = 0;
465+
466+
// Create model compilation options from the session options.
467+
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
468+
compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
469+
compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
470+
compile_options.SetEpContextEmbedMode(true);
471+
472+
// Compile the model.
473+
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
474+
ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
475+
476+
// Make sure the compiled model was saved to the buffer.
477+
ASSERT_TRUE(output_model_buffer != nullptr);
478+
ASSERT_TRUE(output_model_buffer_size > 0);
479+
480+
// Check that the compiled model has the expected number of EPContext nodes.
481+
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
482+
allocator.Free(output_model_buffer);
483+
}
484+
485+
// Test embed mode disabled.
486+
{
487+
void* output_model_buffer = nullptr;
488+
size_t output_model_buffer_size = 0;
489+
490+
// Create model compilation options from the session options.
491+
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
492+
compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
493+
compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
494+
compile_options.SetEpContextEmbedMode(false);
495+
496+
// Compile the model.
497+
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
498+
ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
499+
500+
// Make sure the compiled model was saved to the buffer.
501+
ASSERT_TRUE(output_model_buffer != nullptr);
502+
ASSERT_TRUE(output_model_buffer_size > 0);
503+
504+
// Check that the compiled model has the expected number of EPContext nodes.
505+
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
506+
allocator.Free(output_model_buffer);
507+
}
439508
}
440509

441510
// Test using the CompileModel() API with settings:
@@ -485,6 +554,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu
485554

486555
// Check that the compiled model has the expected number of EPContext nodes.
487556
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
557+
allocator.Free(output_model_buffer);
488558
}
489559

490560
// Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary
@@ -1566,7 +1636,7 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) {
15661636
ORT_CATCH(const std::exception& e) {
15671637
ORT_HANDLE_EXCEPTION([&e]() {
15681638
std::string e_message1(std::string(e.what()));
1569-
ASSERT_TRUE(e_message1.find("Please specify a valid ep.context_file_path.") != std::string::npos);
1639+
ASSERT_TRUE(e_message1.find("Please specify a valid ep.context_file_path") != std::string::npos);
15701640
});
15711641
}
15721642

@@ -1577,7 +1647,7 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) {
15771647
ORT_CATCH(const std::exception& ex) {
15781648
ORT_HANDLE_EXCEPTION([&ex]() {
15791649
std::string e_message2(std::string(ex.what()));
1580-
ASSERT_TRUE(e_message2.find("Please specify a valid ep.context_file_path.") != std::string::npos);
1650+
ASSERT_TRUE(e_message2.find("Please specify a valid ep.context_file_path") != std::string::npos);
15811651
});
15821652
}
15831653
}

0 commit comments

Comments
 (0)