Skip to content

Commit fcb4866

Browse files
Add session config to return an error if model needs to be compiled (microsoft#24416)
### Description Adds session config option (`"session.disable_model_compile"`) that disables model compilation during session initialization. If this option is set to "1", inference session creation will fail with error code ORT_MODEL_REQUIRES_COMPILATION if compilation is required to run the model on any Execution Provider added to the session. Only the following kinds of models are valid when this option is set to "1": - Pre-compiled models that have EPContext nodes for the compiling Execution Providers in the session. - Non-compiled models that run only on non-compiling Execution Providers, like CPU EP. ### Example usage The following example (taken from a unit test) tries to load a model that requires compilation with a session that disables compilation. The session creation fails with error code `ORT_MODEL_REQUIRES_COMPILATION`. Then, the example compiles the model and loads the compiled model successfully. ```C++ // Taken from a unit test ... // Initialize session options with QNN EP Ort::SessionOptions session_options; ProviderOptions provider_options; provider_options["backend_type"] = "htp"; provider_options["offload_graph_io_quantization"] = "0"; session_options.AppendExecutionProvider("QNN", provider_options); session_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "1"); // Disable model compilation! // Create an inference session that fails with error ORT_MODEL_REQUIRES_COMPILATION try { Ort::Session session(*ort_env, input_model_file, session_options); FAIL() << "Expected Session creation to fail but it succeeded"; // Should not get here! } catch (const Ort::Exception& excpt) { OrtErrorCode error_code = excpt.GetOrtErrorCode(); std::string_view error_msg = excpt.what(); ASSERT_EQ(error_code, ORT_MODEL_REQUIRES_COMPILATION); ASSERT_THAT(error_msg, testing::HasSubstr(kQnnExecutionProvider)); } // Session creation failed because the model was not pre-compiled. // Try to compile it now. 
// Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); // Should be able to create a session with the compiled model and the original session options. Ort::Session session(*ort_env, output_model_file, session_options); ``` ### Motivation and Context Compiling models can take a very long time. Want to have a session option that requires input models that do not need to be compiled.
1 parent c19a496 commit fcb4866

File tree

6 files changed

+107
-6
lines changed

6 files changed

+107
-6
lines changed

include/onnxruntime/core/common/status.h

+5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum StatusCode {
4545
INVALID_GRAPH = 10,
4646
EP_FAIL = 11,
4747
MODEL_LOAD_CANCELED = 12,
48+
MODEL_REQUIRES_COMPILATION = 13,
4849
};
4950

5051
constexpr const char* StatusCodeToString(StatusCode status) noexcept {
@@ -75,6 +76,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept {
7576
return "EP_FAIL";
7677
case StatusCode::MODEL_LOAD_CANCELED:
7778
return "MODEL_LOAD_CANCELED";
79+
case StatusCode::MODEL_REQUIRES_COMPILATION:
80+
return "MODEL_REQUIRES_COMPILATION";
7881
default:
7982
return "GENERAL ERROR";
8083
}
@@ -109,6 +112,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
109112
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
110113
case StatusCode::MODEL_LOAD_CANCELED:
111114
return HRESULT_FROM_WIN32(ERROR_CANCELLED);
115+
case StatusCode::MODEL_REQUIRES_COMPILATION:
116+
return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED);
112117
default:
113118
return E_FAIL;
114119
}

include/onnxruntime/core/session/onnxruntime_c_api.h

+1
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ typedef enum OrtErrorCode {
256256
ORT_INVALID_GRAPH,
257257
ORT_EP_FAIL,
258258
ORT_MODEL_LOAD_CANCELED,
259+
ORT_MODEL_REQUIRES_COMPILATION,
259260
} OrtErrorCode;
260261

261262
typedef enum OrtOpAttrType {

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

+16
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,19 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio
348348
// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default]
349349
// “Efficient”: OS treats this workload as efficiency oriented, with low scheduling priority and efficient processor performance.
350350
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
351+
352+
// Disables model compilation during session initialization.
353+
//
354+
// If this option is set to "1", inference session creation will fail with error code ORT_MODEL_REQUIRES_COMPILATION
355+
// if compilation is required to run the model on any Execution Provider added to the session.
356+
// Only the following kinds of models are valid when this option is set to "1":
357+
// - Pre-compiled models that have EPContext nodes for the compiling Execution Providers in the session.
358+
// - Non-compiled models that run only on non-compiling Execution Providers, like CPU EP.
359+
//
360+
// See \href https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html for details about
361+
// compiled models with EPContext nodes.
362+
//
363+
// Option values:
364+
// - "0": Model compilation is not disabled. [DEFAULT]
365+
// - "1": Model compilation is disabled.
366+
static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile";

onnxruntime/core/framework/graph_partitioner.cc

+27-6
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
423423
const layout_transformation::DebugGraphFn& debug_graph_fn,
424424
const CheckLoadCancellationFn& check_load_cancellation_fn,
425425
const logging::Logger& logger, IResourceAccountant* resource_accountant,
426-
const GraphOptimizerRegistry& graph_optimizer_registry) {
426+
const GraphOptimizerRegistry& graph_optimizer_registry,
427+
bool disable_model_compile) {
427428
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
428429
// doing it here saves all providers checking for this in GetCapability
429430
if (graph.NumberOfNodes() == 0) {
@@ -440,7 +441,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
440441
transform_layout_fn, debug_graph_fn,
441442
check_load_cancellation_fn,
442443
logger, resource_accountant,
443-
graph_optimizer_registry));
444+
graph_optimizer_registry, disable_model_compile));
444445
}
445446
}
446447

@@ -529,6 +530,16 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
529530
}
530531
}
531532

533+
// Helper function that returns true if any of the nodes assigned to a compiling EP is not already compiled.
534+
auto graph_viewer_has_non_compiled_node = [](const GraphViewer& graph_viewer) -> bool {
535+
for (const auto& node : graph_viewer.Nodes()) {
536+
if (node.OpType() != "EPContext") {
537+
return true;
538+
}
539+
}
540+
return false;
541+
};
542+
532543
// NOTE: if mode_ is kAssignOnly, nodes_to_compile will be empty at this point due to logic in PlaceNode
533544
// even with single node, EP might still want to compile it.
534545
// for example, it may want to JIT an optimized kernel for LSTM with a given shape.
@@ -548,7 +559,14 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
548559
for (size_t j = 0, end = nodes_to_compile.size(); j < end; j++) {
549560
auto* node = nodes_to_compile[j];
550561
const auto& cur_capability = *capabilities_to_compile[j];
551-
viewers.push_back(std::make_unique<GraphViewer>(graph, *cur_capability.sub_graph));
562+
auto graph_viewer = std::make_unique<GraphViewer>(graph, *cur_capability.sub_graph);
563+
564+
if (disable_model_compile && graph_viewer_has_non_compiled_node(*graph_viewer)) {
565+
return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_REQUIRES_COMPILATION, "User disabled EP compilation but EP '",
566+
type, "' needs to compile one or more nodes.");
567+
}
568+
569+
viewers.push_back(std::move(graph_viewer));
552570
nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{*node, *viewers.back()});
553571
}
554572

@@ -896,7 +914,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
896914
KernelRegistryManager& kernel_registry_manager,
897915
const std::optional<ResourceAccountantMap>& acc_map,
898916
const GraphOptimizerRegistry& graph_optimizer_registry,
899-
const logging::Logger& logger) {
917+
const logging::Logger& logger, bool disable_model_compile) {
900918
bool modified_graph = false;
901919

902920
auto& graph = partition_params.graph.get();
@@ -921,7 +939,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
921939
transform_layout_function,
922940
partition_params.debug_graph_fn,
923941
check_load_cancellation_fn,
924-
logger, resource_accountant, graph_optimizer_registry));
942+
logger, resource_accountant, graph_optimizer_registry,
943+
disable_model_compile));
925944
}
926945

927946
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
@@ -1195,8 +1214,10 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
11951214
std::optional<ResourceAccountantMap> ep_acc_map;
11961215
ORT_RETURN_IF_ERROR(NodeStatsRecorder::CreateAccountants(config_options, graph.ModelPath(), ep_acc_map));
11971216

1217+
bool disable_model_compile = config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "1";
11981218
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_,
1199-
ep_acc_map, *graph_optimizer_registry_, logger));
1219+
ep_acc_map, *graph_optimizer_registry_, logger,
1220+
disable_model_compile));
12001221

12011222
if (ep_context_gen_options.enable) {
12021223
ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_gen_options, logger));

onnxruntime/core/session/model_compilation_options.cc

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ ModelCompilationOptions::ModelCompilationOptions(const OrtEnv& env, const OrtSes
2323

2424
// Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
2525
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
26+
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK());
2627
}
2728

2829
void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
@@ -195,6 +196,7 @@ Status ModelCompilationOptions::CheckOutputModelSettings() const {
195196

196197
Status ModelCompilationOptions::Check() const {
197198
ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable);
199+
ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0");
198200
ORT_RETURN_IF_ERROR(CheckInputModelSettings());
199201
ORT_RETURN_IF_ERROR(CheckOutputModelSettings());
200202
return Status::OK();

onnxruntime/test/providers/qnn/qnn_ep_context_test.cc

+56
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,62 @@ static void CheckEpContextNodeCounts(void* model_buffer, size_t model_buffer_siz
272272
std::filesystem::remove(output_model_path);
273273
}
274274

275+
// Test workflow that:
276+
// - Creates session that disables EP compilation.
277+
// - Session creation fails because input model is not pre-compiled.
278+
// - Uses OrtCompileApi to compile the model.
279+
// - Recreates session with the compiled model.
280+
TEST_F(QnnHTPBackendTests, CompileApi_DisableEpCompile_ThenCompileExplicitly) {
281+
const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_disable_compile_input.onnx");
282+
const ORTCHAR_T* output_model_file = ORT_TSTR("./compileapi_disable_compile_output.onnx");
283+
std::filesystem::remove(input_model_file);
284+
std::filesystem::remove(output_model_file);
285+
286+
// Create a test model and save it to a file.
287+
TestModel test_model;
288+
CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
289+
ASSERT_STATUS_OK(test_model.Save(input_model_file));
290+
291+
// Initialize session options with QNN EP
292+
Ort::SessionOptions so;
293+
ProviderOptions provider_options;
294+
provider_options["backend_type"] = "htp";
295+
provider_options["offload_graph_io_quantization"] = "0";
296+
297+
so.AppendExecutionProvider("QNN", provider_options);
298+
so.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "1"); // Disable model compilation!
299+
300+
// Create an inference session that fails with error ORT_MODEL_REQUIRES_COMPILATION
301+
try {
302+
Ort::Session session(*ort_env, input_model_file, so);
303+
FAIL() << "Expected Session creation to fail but it succeeded"; // Should not get here!
304+
} catch (const Ort::Exception& excpt) {
305+
OrtErrorCode error_code = excpt.GetOrtErrorCode();
306+
std::string_view error_msg = excpt.what();
307+
ASSERT_EQ(error_code, ORT_MODEL_REQUIRES_COMPILATION);
308+
ASSERT_THAT(error_msg, testing::HasSubstr(kQnnExecutionProvider));
309+
}
310+
311+
// Session creation failed because the model was not pre-compiled.
312+
// Try to compile it now.
313+
314+
// Create model compilation options from the session options.
315+
Ort::ModelCompilationOptions compile_options(*ort_env, so);
316+
compile_options.SetInputModelPath(input_model_file);
317+
compile_options.SetOutputModelPath(output_model_file);
318+
319+
// Compile the model.
320+
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
321+
ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
322+
323+
// Make sure the compiled model was generated and has the expected number of EPContext nodes.
324+
ASSERT_TRUE(std::filesystem::exists(output_model_file));
325+
CheckEpContextNodeCounts(output_model_file, 2, 2);
326+
327+
// Should be able to create a session with the compiled model and the original session options.
328+
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so)));
329+
}
330+
275331
// Test using the CompileModel() API with settings:
276332
// - input model file
277333
// - output model file

0 commit comments

Comments
 (0)