[TRT RTX EP] Implement GetEPContextNodes() #24901


Open · wants to merge 3 commits into main

Conversation

@thevishalagarwal

Implements GetEPContextNodes()


@HectorSVC (Contributor)

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 5 pipeline(s).

HectorSVC added the ep:NvRTX (NV RTX execution provider) label on May 29, 2025
@@ -266,6 +268,7 @@ class NvExecutionProvider : public IExecutionProvider {
std::string cache_prefix_;
std::string op_types_to_exclude_;
int nv_profile_index_ = 0;
std::vector<std::unique_ptr<onnxruntime::Model>> ep_context_models_;
Contributor

> ep_context_models_

The model instance here is just there to hold the EPContext nodes. Keeping one instance is enough; why is a list of models needed?
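For illustration, a single-instance version might look like the sketch below. The class and member names follow the PR, but the snippet itself is hypothetical, not the actual patch:

// Sketch only: replace the vector with a single Model that owns every
// EPContext node this EP produces. One model per EP instance suffices,
// since GetEpContextNodes() just walks its main graph.
class NvExecutionProvider : public IExecutionProvider {
  // ...
 private:
  std::unique_ptr<onnxruntime::Model> ep_context_model_;  // instead of ep_context_models_
};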

@@ -72,7 +72,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger) {
const logging::Logger* logger,
std::vector<std::unique_ptr<onnxruntime::Model>>& ep_context_models, const std::string& ep_context_node_name) {
auto model_build = graph_viewer.CreateModel(*logger);
@HectorSVC (Contributor) May 29, 2025

> auto model_build = graph_viewer.CreateModel(*logger);

Seems weird to me. Why create a model from an existing graph_viewer, and how is its lifecycle tied to that graph_viewer? The ep_context_model instance needs to stay valid for as long as the EP instance is alive.
All the EP needs to do for GetEpContextNodes() is create a model instance that holds all the EPContext nodes for the graph partitioner to query; the EP only needs to add the EPContext nodes into that model instance. For reference, this is how the QNN EP does it:

qnn_ep_context_model_ = Factory<Model>::Create(std::string{"qnn_ep_context_model"}, false, logger);
ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
context_buffer.get(),
buffer_size,
qnn_backend_manager_->GetSdkVersion(),
fused_nodes_and_graphs,
qnn_models_,
context_model_path,
qnn_context_embed_mode_,
max_spill_fill_buffer_size,
logger,
share_ep_contexts_,
stop_share_ep_contexts_));
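
The query side then just walks that model's graph. A minimal sketch of the corresponding override for this EP could look like the following (the single ep_context_model_ member is assumed; this mirrors the QNN EP pattern rather than this PR's code):

// Sketch, not the actual patch: return the EPContext nodes held by the
// single ep_context_model_ (hypothetical member) so that the graph
// partitioner can query them.
const InlinedVector<const Node*> NvExecutionProvider::GetEpContextNodes() const {
  InlinedVector<const Node*> ep_context_nodes;
  if (ep_context_model_) {
    const auto& graph = ep_context_model_->MainGraph();
    for (const auto& node : graph.Nodes()) {
      ep_context_nodes.push_back(graph.GetNode(node.Index()));
    }
  }
  return ep_context_nodes;
}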

@chilo-ms (Contributor) Jun 3, 2025

The name CreateCtxModel() might be a bit confusing. There are some historic reasons for it:

  • When the TRT EP first implemented the EP Context feature, it didn't follow the rule of implementing GetEpContextNodes().
  • It only supported a model containing a single EPContext node, meaning the whole graph could be run by TRT; there was no partitioning.
  • Instead of implementing the EP API's GetEpContextNodes(), it directly created the ONNX model, which is why the function has this name: it really does create an ONNX model and dump it to a file.

Contributor

Right now, since the RTX EP and the TRT EP are going to follow the EP Context implementation rule and implement GetEpContextNodes(), we could rename it to avoid confusion.

Contributor

Sharing some source between TRT and TRT RTX seems like a good idea, but it would be kind of awkward. Any opinion on this? Otherwise, let's duplicate the changes and source for now.

HectorSVC requested review from jywu-msft and chilo-ms on May 29, 2025 22:00
@HectorSVC (Contributor)

There was a fix for the Web CI pipeline; please merge in the code from the latest main branch.

@chilo-ms (Contributor) commented Jun 3, 2025

When running with the TRT RTX EP, will it also handle the case where the model contains contrib ops that fall back to the CUDA EP or the CPU? I assume it will.

If that's the case, the following change needs to be added to this PR as well (see the sketch after this list):

  • In GetCapability, the EP should be able to correctly parse an EP Context model that contains a mix of EPContext nodes and regular ONNX nodes.
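A rough sketch of what that check inside GetCapability could look like. The "source" attribute name comes from the com.microsoft.EPContext contrib op; the EP-name value compared against it here is an assumption for illustration:

// Illustrative only: claim EPContext nodes tagged for this EP and leave
// everything else (other EPs' context nodes, plain ONNX nodes) alone.
for (const auto& node : graph_viewer.Nodes()) {
  if (node.OpType() == "EPContext") {
    const auto& attrs = node.GetAttributes();
    auto it = attrs.find("source");
    // "NvTensorRTRTXExecutionProvider" is an assumed value for illustration.
    if (it != attrs.end() && it->second.s() == "NvTensorRTRTXExecutionProvider") {
      // add this node to the capability this EP claims
    }
    // otherwise skip it; it belongs to whichever EP "source" names
  }
}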

@gedoensmax (Contributor)

If we fill in the source attribute on an EPContext node, will the partitioner automatically force that node onto an EP? I think if there are EPContext nodes, no real partitioning is needed anymore 🤔

Labels
ep:NvRTX NV RTX execution provider