[TRT RTX EP] Implement GetEPContextNodes() #24901
base: main
Conversation
cc @gedoensmax @ankan-ban @ishwar-raut1 @chilo-ms @jywu-msft to review

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 5 pipeline(s).
@@ -266,6 +268,7 @@ class NvExecutionProvider : public IExecutionProvider {
  std::string cache_prefix_;
  std::string op_types_to_exclude_;
  int nv_profile_index_ = 0;
  std::vector<std::unique_ptr<onnxruntime::Model>> ep_context_models_;
@@ -72,7 +72,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
                                           const int64_t embed_mode,
                                           const std::string compute_capability,
                                           const std::string onnx_model_path,
                                           const logging::Logger* logger,
                                           std::vector<std::unique_ptr<onnxruntime::Model>>& ep_context_models,
                                           const std::string& ep_context_node_name) {
  auto model_build = graph_viewer.CreateModel(*logger);
Seems weird to me. Why create a model from an existing graph_viewer? How is its lifecycle controlled relative to the existing graph_viewer? The ep_context_model instance needs to stay valid as long as the EP instance is alive.
All the EP needs to do for GetEpContextNodes() is create a model instance that holds all the EPContext nodes for the graph partitioner to query. The EP only needs to add all EPContext nodes into that model instance.
onnxruntime/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Lines 1179 to 1191 in 9705b17

qnn_ep_context_model_ = Factory<Model>::Create(std::string{"qnn_ep_context_model"}, false, logger);
ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
                                              context_buffer.get(),
                                              buffer_size,
                                              qnn_backend_manager_->GetSdkVersion(),
                                              fused_nodes_and_graphs,
                                              qnn_models_,
                                              context_model_path,
                                              qnn_context_embed_mode_,
                                              max_spill_fill_buffer_size,
                                              logger,
                                              share_ep_contexts_,
                                              stop_share_ep_contexts_));
The name of CreateCtxModel() might be a bit confusing. It has some historic reasons:
- At the beginning, when the TRT EP first implemented the EP Context feature, it didn't follow the rule of implementing GetEpContextNodes().
- It only supported a model containing a single EPContext node, meaning the whole graph could be run by TRT; there was no partitioning.
- It didn't implement the EP API's GetEpContextNodes(); instead it directly created the ONNX model, which is why this function has this name: it really creates an ONNX model and dumps it to a file.
Right now, since the RTX EP and TRT EP are going to follow the EP Context implementation rule and implement GetEpContextNodes(), we could rename it to avoid confusion.
Sharing some source between TRT and TRT RTX seems like a good idea but will be kind of awkward. Any opinion on this? Otherwise let's duplicate the changes and source for now.

There was a fix for the Web CI pipeline; please merge the code from the latest main branch.
When running with the TRT RTX EP, will it also handle the case where the model contains contrib ops that fall back to the CUDA EP or CPU? I assume it will. If that's the case, the following change needs to be added to this PR as well.
If we will be filling the
Implements GetEPContextNodes()