Conversation

@galagam (Contributor) commented Dec 21, 2025

What does this PR do?

Type of change: Bug fix

Overview:
When clearing type information for shape inference, preserve value_info for outer scope variables in subgraphs. Previously, all value_info entries were cleared indiscriminately, causing shape inference failures when subgraph nodes referenced outer scope variables.
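
As a rough illustration of the fix (a minimal sketch only; the helper name and exact scoping are assumptions, see the diff excerpt in the review thread below for the real change):

```python
import onnx

def clear_subgraph_intermediate_types(subgraph: onnx.GraphProto) -> None:
    # Hypothetical helper sketching the idea: only values produced by nodes
    # inside this subgraph get their types cleared for re-inference.
    produced_here = {out for node in subgraph.node for out in node.output}
    for vi in subgraph.value_info:
        if vi.name in produced_here:
            vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
        # else: the value comes from the outer scope; its type is kept so
        # shape inference inside the subgraph can still resolve it.
```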

Testing

pytest tests/unit/onnx/autocast/test_precisionconverter.py::test_if_subgraph_outer_scope_type_preservation

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: N/A
  • Did you update Changelog?: No

…during type clearing

Signed-off-by: Gal Hubara Agam <[email protected]>
@galagam galagam requested a review from a team as a code owner December 21, 2025 13:56
@galagam galagam requested a review from ajrasane December 21, 2025 13:56
codecov bot commented Dec 21, 2025

Codecov Report

❌ Patch coverage is 73.33333% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.68%. Comparing base (cb34335) to head (fd1dd71).
⚠️ Report is 16 commits behind head on main.

Files with missing lines                      Patch %   Lines
modelopt/onnx/autocast/precisionconverter.py  73.33%    4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #717      +/-   ##
==========================================
- Coverage   74.69%   74.68%   -0.01%     
==========================================
  Files         192      192              
  Lines       18946    18956      +10     
==========================================
+ Hits        14152    14158       +6     
- Misses       4794     4798       +4     

☔ View full report in Codecov by Sentry.

@galagam (Contributor, Author) commented Dec 23, 2025

@ajrasane @gcunhase please help review

@jricker2

I am actually running into an issue that is slightly related. Hopefully I can get some clarification before looking further into creating a reproducer. I found that for a specific model with dynamic inputs, after converting to FP16 the FP16 graph cannot be run on the CPU or CUDA EPs. This is not the biggest deal since we plan to deploy on TensorRT anyway. Looking into it further, the root cause is value_info for some operators being mostly "unk", rather than the real values added after shape inference. I end up with shape mismatch errors at inference time. Removing value_info from the converted model and re-updating value_info with shape inference values resolves the issue. This leads me to wonder about the purpose of filling value_info with undefined type and generic dims before running shape inference?

@galagam (Contributor, Author) commented Dec 24, 2025

> I am actually running into an issue that is slightly related. Hopefully I can get some clarification before looking further into creating a reproducer. I found that for a specific model with dynamic inputs, after converting to FP16 the FP16 graph cannot be run on the CPU or CUDA EPs. This is not the biggest deal since we plan to deploy on TensorRT anyway. Looking into it further, the root cause is value_info for some operators being mostly "unk", rather than the real values added after shape inference. I end up with shape mismatch errors at inference time. Removing value_info from the converted model and re-updating value_info with shape inference values resolves the issue. This leads me to wonder about the purpose of filling value_info with undefined type and generic dims before running shape inference?

Hey @jricker2 !

> This is not the biggest deal since we plan to deploy on TensorRT anyway

Can you explain this statement? Do you mean deploy to TensorRT with weak typing (AKA --fp16 / kFP16 flag)? AutoCast is meant to prepare the model for deployment in TensorRT with strong typing. Please keep in mind that weak typing was deprecated and is planned to be removed in the next major release.

> Looking into it further, the root cause is value_info for some operators being mostly "unk", rather than the real values added after shape inference. I end up with shape mismatch errors at inference time.

This is a bug. Would you like to share the model so we can look into it?

> This leads me to wonder about the purpose of filling value_info with undefined type and generic dims before running shape inference?

We want to ensure consistent type inference after adding cast nodes. The simplest way to do this is using ONNX's infer_shapes function which also performs type inference. If we leave the existing types, we'll get shape/type mismatches during shape inference.
The unnecessary coupling of type and shape inference is a known issue which should eventually be resolved in ONNX.
You might also be interested in this early draft for a workaround, using a standalone implementation for type inference only. This is still in draft mode, so no guarantees... #719
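
For context, the flow described above amounts to roughly the following (a minimal sketch with plain ONNX APIs, not the actual AutoCast code; the file name is a placeholder):

```python
import onnx
from onnx import shape_inference

# Model after Cast nodes have been inserted (placeholder path).
model = onnx.load("model_with_casts.onnx")

# Reset intermediate tensor types so inference propagates the new post-cast
# types instead of conflicting with the stale pre-cast annotations.
for vi in model.graph.value_info:
    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED

# ONNX couples shape and type inference, so this re-derives both.
model = shape_inference.infer_shapes(model)
```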

@jricker2

> This is not the biggest deal since we plan to deploy on TensorRT anyway

> Can you explain this statement? Do you mean deploy to TensorRT with weak typing (AKA --fp16 / kFP16 flag)? AutoCast is meant to prepare the model for deployment in TensorRT with strong typing. Please keep in mind that weak typing was deprecated and is planned to be removed in the next major release.

Apologies for the confusion here; I worded this poorly. We use the TensorRT EP in ONNX Runtime (which works with the converted model). We will not use the CUDA/CPU EPs at any point, so the failure I'm mentioning is not a showstopper by any means. I am working on pushing folks towards strongly typed models via AutoCast.

> Looking into it further, the root cause is value_info for some operators being mostly "unk", rather than the real values added after shape inference. I end up with shape mismatch errors at inference time.

> This is a bug. Would you like to share the model so we can look into it?

Unfortunately I cannot share the full model; I will hopefully have an IP-free reproducer soon and will file an issue with it.

> This leads me to wonder about the purpose of filling value_info with undefined type and generic dims before running shape inference?

> We want to ensure consistent type inference after adding cast nodes. The simplest way to do this is using ONNX's infer_shapes function which also performs type inference. If we leave the existing types, we'll get shape/type mismatches during shape inference. The unnecessary coupling of type and shape inference is a known issue which should eventually be resolved in ONNX. You might also be interested in this early draft for a workaround, using a standalone implementation for type inference only. This is still in draft mode, so no guarantees... #719

This makes sense. I was wondering about the difference between having value_info pre-populated with undefined types/generic dims (how it currently is) versus being empty before running infer_shapes to populate it. This made the difference in making the model work for the CPU/CUDA EPs. I don't know enough about infer_shapes to know why this is the case. As a note, shape inference passes on the converted model; it is actual inference on the CPU/CUDA EPs which fails.

The local shape inference is interesting; if I have some time I will try it out. Thanks for the response.

# Clear value_info only for intermediates produced by nodes in this subgraph
for vi in g.value_info:
    if vi.name in subgraph_outputs:
        vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
Contributor
@jricker2 do you mean that replacing lines 317~320 with vi.type.ClearField("tensor_type") solves the ORT with CUDA EP issue that you're observing?
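
For reference, the alternative being suggested would look roughly like this (illustrative sketch only; whether it behaves better with the CUDA EP is exactly what is being asked above):

```python
import onnx

def clear_tensor_types(graph: onnx.GraphProto, subgraph_outputs: set[str]) -> None:
    # Hypothetical variant: drop the tensor_type field entirely instead of
    # setting elem_type to UNDEFINED.
    for vi in graph.value_info:
        if vi.name in subgraph_outputs:
            vi.type.ClearField("tensor_type")
```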

@jricker2 jricker2 Jan 8, 2026
Yes, exactly. Just commenting those lines out actually resolves the issue (lines 321-326, the non-subgraph block).

I also found that if I cleared the value_info from the original FP32 graph, ran shape inference to re-populate it, and then fed this graph into the model optimizer, I had no issues. It seems to be an issue with how the graph was produced (exported by torch, I believe). I don't own the creation/export of the original torch model, and the workaround is straightforward, so I decided not to look much further into it (also because I found debugging value_info related issues to be very time consuming; tools like DL Designer are not very helpful for this).

Anyhow, from what I can tell, having a clear value_info before shape inference is the best way to go, as opposed to pre-filling it with generic shape/type.
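
A minimal sketch of that workaround, assuming plain ONNX APIs (paths are placeholders):

```python
import onnx
from onnx import shape_inference

model = onnx.load("original_fp32.onnx")  # placeholder path

# Drop all existing value_info and let shape inference rebuild it from
# scratch before handing the model to AutoCast.
del model.graph.value_info[:]
model = shape_inference.infer_shapes(model)

onnx.save(model, "original_fp32_clean.onnx")
```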

edit: so as not to take anything away from this PR - there are no subgraphs in the model I had this issue with; I just saw that this touched similar parts of the code I was looking into, so I figured I'd ask. I don't want to hold this up.

Contributor
Thanks for the confirmation. Please let us know when you have an IP-free repro so we can check if the suggested WAR is enough to fix this issue. Thanks.

@gcunhase (Contributor) left a comment

LGTM, we can tackle the tensor_type clearing issue once we have a repro from @jricker2.

@ajrasane (Contributor) commented Jan 9, 2026

LGTM. @jricker2, please do share the IP-free model or a custom model that we can use to replicate this issue and look into it further.

@galagam galagam merged commit 7971fff into NVIDIA:main Jan 9, 2026
35 checks passed
