Fix SimplifiedLayerNorm fusion with node-produced Pow exponent #29196
the0cp wants to merge 6 commits into
Pull request overview
This PR fixes a failure in the SimplifiedLayerNormFusion optimizer pass when the Pow exponent is produced by a node (e.g., from a mixed-precision Cast) instead of being an initializer, which previously could cause FinalizeNodeFusion to try to rewire a non-existent replacement input and fail graph initialization.
Changes:
- Disconnects input edges on the first fused node that are not valid inputs to the `SimplifiedLayerNormalization` replacement before calling `FinalizeNodeFusion`.
- Defers cleanup of newly-dead upstream producers until after each fusion, and only removes them once they have no remaining consumers.
- Adds a regression test covering multiple SimplifiedLayerNorm patterns sharing a Cast-produced Pow exponent.
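The disconnect-then-deferred-cleanup order listed above can be sketched on a toy graph. This is an illustration only: `ToyGraph` and `FusePattern` are hypothetical names, not the onnxruntime `Graph` API and not the PR's code, and the "fusion" step is reduced to removing the matched `Pow` node.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>

// Toy model only; names are hypothetical and this is not the onnxruntime Graph API.
struct ToyGraph {
  std::map<int, std::set<int>> output_edges;  // node id -> ids of its consumers

  void AddNode(int id) { output_edges[id]; }
  bool HasNode(int id) const { return output_edges.count(id) != 0; }
  void AddEdge(int producer, int consumer) { output_edges[producer].insert(consumer); }
  void RemoveEdge(int producer, int consumer) {
    auto it = output_edges.find(producer);
    if (it != output_edges.end()) it->second.erase(consumer);
  }
  std::size_t OutputEdgeCount(int id) const {
    auto it = output_edges.find(id);
    return it == output_edges.end() ? 0 : it->second.size();
  }
  void RemoveNode(int id) { output_edges.erase(id); }
};

// Simulates fusing one matched pattern whose Pow exponent comes from
// `exponent_producer`. Returns true if the producer was removed as dead.
bool FusePattern(ToyGraph& g, int pow_node, int exponent_producer) {
  // 1. Disconnect the edge that the replacement node has no input slot for.
  g.RemoveEdge(exponent_producer, pow_node);
  // 2. Stand-in for the fusion itself: the matched Pow node disappears.
  g.RemoveNode(pow_node);
  // 3. Deferred cleanup: drop the producer only once nothing consumes it.
  if (g.HasNode(exponent_producer) && g.OutputEdgeCount(exponent_producer) == 0) {
    g.RemoveNode(exponent_producer);
    return true;
  }
  return false;
}
```

With one producer feeding two `Pow` nodes, the first call to `FusePattern` leaves the producer in place because it still has a consumer; the second call removes it.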
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `onnxruntime/core/optimizer/layer_norm_fusion.cc` | Adjusts SimplifiedLayerNorm fusion finalization to ignore/clean up node-produced Pow exponent inputs safely. |
| `onnxruntime/test/optimizer/graph_transform_test_layernorm.cc` | Adds a regression test for shared Cast-produced Pow exponent across multiple fused patterns. |
Addressed all three review comments.
tianleiwu left a comment
Summary
The fix correctly targets the root cause described in #29153: FinalizeNodeFusion moves all input edges of the first fused node onto the replacement SimplifiedLayerNormalization, and a node-produced (mixed-precision Cast) Pow exponent has no corresponding input on the replacement, so the moved edge later fails in GetIndexFromName. Pre-disconnecting non-replacement input edges and only removing the producer once it becomes dead is the right approach, and the deferred cleanup (GetOutputEdgesCount() == 0 guard + RemoveNodesWithOneOutputBottomUp) safely handles the shared-producer case across multiple matched subgraphs.
Positive
- Matching kept edges by `NodeArg` name against `layer_norm_input_defs` is robust; `scale` never feeds the first node so only the stray exponent edge is disconnected.
- Deferring producer removal until after `FinalizeNodeFusion` and gating on `GetOutputEdgesCount() == 0` correctly preserves a `Cast` exponent shared by two subgraphs until its last consumer is fused, which is exactly what the new test exercises.
- The regression test reproduces the reported pattern (fp16 const -> `Cast` -> shared `Pow` exponent across two subgraphs) and was already hardened to use guarded `find` lookups instead of throwing `map::at`.
Minor suggestion
⚠️ The disconnect/cleanup logic only inspects `nodes_to_remove.front()`. In the leading-`Cast` (GPU) path the front node is the input `Cast` and `Pow` is a middle node, so a node-produced exponent on that `Pow` is dropped by `FinalizeNodeFusion` without its producer being cleaned up, leaving a dead orphan node. This does not crash (the edge is removed, not moved), so it is out of scope for the #29153 crash, but the asymmetry is worth a short comment or a follow-up. See the inline note.
Overall this looks correct and well-tested for the reported scenario.
Please address the following test coverage gaps:
Thanks for the detailed review! I pushed additional coverage and validation updates:
I intentionally kept the remaining mixed partial-fusion shared-Cast case out of this PR for now: the current regression already covers the reported crash path and shared-producer lifetime, while the partial-fusion variant is additional edge-case coverage that the fix does not require. I can add it as a follow-up if you think it is necessary for this PR.
tianleiwu left a comment
Reviewed the latest head (cb08b0e). The fix is correct and well-tested.
The core approach is sound: non-replacement input edges of the fused front node are disconnected before FinalizeNodeFusion (so the exponent edge is no longer moved onto a replacement node that lacks that input), and dead producers are removed afterward only when their last consumer is gone (GetOutputEdgesCount() == 0), which correctly handles the shared-Cast case. Tracking NodeIndex values (not pointers) with a nullptr re-fetch guard is safe against bottom-up cascade removals. The leading-Cast path is handled by explicitly tracking the Pow exponent producer since the Pow is a middle node removed by FinalizeNodeFusion. GetOtherAddInput correctly identifies epsilon by connectivity instead of a fixed Add input index, and IsPowExponentTwo is a genuine correctness improvement over the previous unconditional 2.0 assumption.
Test coverage is thorough (shared-Cast exponent in both leading-cast variants, commutative epsilon, exponent-not-two rejection, and a full optimization-loop CPU-fallback regression). All previously raised threads are resolved and addressed.
One low-priority suggestion below; otherwise this looks good.
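The point about tracking `NodeIndex` values with a `nullptr` re-fetch guard can be illustrated with a small stand-alone sketch. Everything here (`ToyNode`, `ToyNodeTable`, `RemoveDeadCandidates`) is a hypothetical stand-in and not the onnxruntime API; the assumption is only that a lookup by index returns `nullptr` once the node is gone.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a graph node; only the consumer count matters here.
struct ToyNode {
  int output_edge_count = 0;
};

// Hypothetical node table: indices stay stable, removed slots become nullptr.
class ToyNodeTable {
 public:
  std::size_t Add(int output_edge_count) {
    nodes_.push_back(std::make_unique<ToyNode>());
    nodes_.back()->output_edge_count = output_edge_count;
    return nodes_.size() - 1;
  }
  // Returns nullptr when the index is out of range or the node was removed.
  ToyNode* Get(std::size_t index) {
    return index < nodes_.size() ? nodes_[index].get() : nullptr;
  }
  void Remove(std::size_t index) {
    if (index < nodes_.size()) nodes_[index].reset();
  }

 private:
  std::vector<std::unique_ptr<ToyNode>> nodes_;
};

// Removes every candidate that still exists and has no consumers.
// Returns the number of nodes actually removed.
std::size_t RemoveDeadCandidates(ToyNodeTable& table,
                                 const std::vector<std::size_t>& candidates) {
  std::size_t removed = 0;
  for (std::size_t index : candidates) {
    ToyNode* node = table.Get(index);  // re-fetch by index on every iteration
    if (node == nullptr) continue;     // already removed, e.g. by an earlier cascade
    if (node->output_edge_count == 0) {
      table.Remove(index);
      ++removed;
    }
  }
  return removed;
}
```

Storing raw pointers in the candidate list instead would leave a dangling pointer as soon as an earlier removal freed the same node; re-fetching by index turns that case into a skipped iteration.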
The macOS xnnpack failure appears unrelated: the C++ tests passed and the failure is a Python quantization test hitting a DNS/urlretrieve error. The wasm failure also does not show an error from this change; layer_norm_fusion.cc compiled successfully.
Description
Motivation and Context
A mixed-precision Cast can turn the Pow exponent from an initializer-only input into a graph edge. SimplifiedLayerNormFusion previously passed that edge to FinalizeNodeFusion, which attempted to move it to the replacement node. Since SimplifiedLayerNormalization does not have a corresponding exponent input, graph initialization failed in GetIndexFromName.
The fix removes only inputs that are not part of the replacement node and preserves shared producers until they have no remaining consumers.
Additional validation was added based on review feedback to make the fusion semantically safer: the Pow exponent must be proven to be 2.0, and epsilon is now identified by graph connectivity instead of assuming a fixed Add input index.
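As a rough illustration of the two validations described above, the sketch below checks that an exponent is a single constant equal to 2 and picks epsilon as whichever `Add` input is not the upstream output. The helper names `IsScalarTwo` and `OtherAddInputIndex` are hypothetical, and the sketch works on plain strings and floats; it does not reproduce the PR's `IsPowExponentTwo` or `GetOtherAddInput`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: accept the exponent only if it is a single constant
// value equal to 2. An empty or multi-element constant is rejected.
bool IsScalarTwo(const std::vector<float>& constant_values) {
  return constant_values.size() == 1 && constant_values[0] == 2.0f;
}

// Hypothetical helper: given the two input names of an Add node and the name
// of the upstream output feeding it, return the index (0 or 1) of the OTHER
// input, i.e. the epsilon candidate. Returns nullopt when neither or both
// inputs match, since the choice would then be ambiguous.
std::optional<std::size_t> OtherAddInputIndex(
    const std::array<std::string, 2>& add_inputs,
    const std::string& upstream_output) {
  const bool first_matches = add_inputs[0] == upstream_output;
  const bool second_matches = add_inputs[1] == upstream_output;
  if (first_matches == second_matches) return std::nullopt;
  return first_matches ? std::size_t{1} : std::size_t{0};
}
```

Because the lookup is by connectivity, swapping the two `Add` inputs changes the returned index but still identifies the same epsilon operand, which is the commutative case a fixed input index would miss.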
Fixes #29153.