Skip to content

Fix SimplifiedLayerNorm fusion with node-produced Pow exponent#29196

Open
the0cp wants to merge 6 commits into
microsoft:mainfrom
the0cp:fix-simplified-layernorm-shared-cast
Open

Fix SimplifiedLayerNorm fusion with node-produced Pow exponent#29196
the0cp wants to merge 6 commits into
microsoft:mainfrom
the0cp:fix-simplified-layernorm-shared-cast

Conversation

@the0cp

@the0cp the0cp commented Jun 20, 2026

Copy link
Copy Markdown
Contributor

Description

  • Disconnect input edges that are not used by the fused SimplifiedLayerNormalization node before calling FinalizeNodeFusion.
  • Remove newly dead input producers only after their final consumer has been fused.
  • Add a full optimization-loop regression test for CPU EP fallback with a shared Cast-produced Pow exponent.
  • Validate that the Pow exponent is statically known to be scalar/one-element 2.0 before applying SimplifiedLayerNorm fusion.
  • Support epsilon on either input of Add, since Add is commutative.

Motivation and Context

A mixed-precision Cast can turn the Pow exponent from an initializer-only input into a graph edge. SimplifiedLayerNormFusion previously passed that edge to FinalizeNodeFusion, which attempted to move it to the replacement node. Since SimplifiedLayerNormalization does not have a corresponding exponent input, graph initialization failed in GetIndexFromName.

The fix removes only inputs that are not part of the replacement node and preserves shared producers until they have no remaining consumers.

Additional validation was added based on review feedback to make the fusion semantically safer: the Pow exponent must be proven to be 2.0, and epsilon is now identified by graph connectivity instead of assuming a fixed Add input index.

Fixes #29153.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a failure in the SimplifiedLayerNormFusion optimizer pass when the Pow exponent is produced by a node (e.g., from a mixed-precision Cast) instead of being an initializer, which previously could cause FinalizeNodeFusion to try to rewire a non-existent replacement input and fail graph initialization.

Changes:

  • Disconnects input edges on the first fused node that are not valid inputs to the SimplifiedLayerNormalization replacement before calling FinalizeNodeFusion.
  • Defers cleanup of newly-dead upstream producers until after each fusion, and only removes them once they have no remaining consumers.
  • Adds a regression test covering multiple SimplifiedLayerNorm patterns sharing a Cast-produced Pow exponent.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
onnxruntime/core/optimizer/layer_norm_fusion.cc Adjusts SimplifiedLayerNorm fusion finalization to ignore/clean up node-produced Pow exponent inputs safely.
onnxruntime/test/optimizer/graph_transform_test_layernorm.cc Adds a regression test for shared Cast-produced Pow exponent across multiple fused patterns.

Comment thread onnxruntime/core/optimizer/layer_norm_fusion.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test_layernorm.cc Outdated
Comment thread onnxruntime/test/optimizer/graph_transform_test_layernorm.cc Outdated
@the0cp

the0cp commented Jun 20, 2026

Copy link
Copy Markdown
Contributor Author

Addressed all three review comments

@tianleiwu tianleiwu left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary

The fix correctly targets the root cause described in #29153: FinalizeNodeFusion moves all input edges of the first fused node onto the replacement SimplifiedLayerNormalization, and a node-produced (mixed-precision Cast) Pow exponent has no corresponding input on the replacement, so the moved edge later fails in GetIndexFromName. Pre-disconnecting non-replacement input edges and only removing the producer once it becomes dead is the right approach, and the deferred cleanup (GetOutputEdgesCount() == 0 guard + RemoveNodesWithOneOutputBottomUp) safely handles the shared-producer case across multiple matched subgraphs.

Positive

  • Matching kept edges by NodeArg name against layer_norm_input_defs is robust; scale never feeds the first node so only the stray exponent edge is disconnected.
  • Deferring producer removal until after FinalizeNodeFusion and gating on GetOutputEdgesCount() == 0 correctly preserves a Cast exponent shared by two subgraphs until its last consumer is fused — exactly what the new test exercises.
  • The regression test reproduces the reported pattern (fp16 const -> Cast -> shared Pow exponent across two subgraphs) and was already hardened to use guarded find lookups instead of throwing map::at.

Minor suggestion

  • ⚠️ The disconnect/cleanup logic only inspects nodes_to_remove.front(). In the leading-Cast (GPU) path the front node is the input Cast and Pow is a middle node, so a node-produced exponent on that Pow is dropped by FinalizeNodeFusion without its producer being cleaned up, leaving a dead orphan node. This does not crash (the edge is removed, not moved), so it is out of scope for the #29153 crash, but the asymmetry is worth a short comment or a follow-up. See the inline note.

Overall this looks correct and well-tested for the reported scenario.

Comment thread onnxruntime/core/optimizer/layer_norm_fusion.cc

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Comment thread onnxruntime/test/optimizer/graph_transform_test_layernorm.cc Outdated
@yuslepukhin

Copy link
Copy Markdown
Member

Please, address the following test coverage gaps:

  • No full optimization loop test — The test runs SimplifiedLayerNormFusion in isolation. It doesn't exercise the actual crash scenario: L2 → InsertCastTransformer → L4 (FuseInitializersTransformer triggers loop) → L2 repeat. The Cast node in the test is manually constructed, rather than being produced by InsertCastTransformer during the loop.

  • No test with CPU EP assignment — The real bug requires nodes to be assigned to CPU EP (so the fp16 skip check at line 691 fires in the first L2 pass, then InsertCastTransformer converts to fp32, then the second L2 pass succeeds). The PR test doesn't set execution providers on nodes.

  • No test where the shared exponent Cast has multiple consumers across iterations — In the real scenario, the Cast node feeds many Pow nodes. When one pattern is fused, the Cast still has other consumers. The test uses two patterns (good), but both are fused in the same transformer pass. It doesn't test the case where one pattern fuses and the other doesn't (e.g., because of some validation check failure), leaving the Cast partially consumed.

  • Epsilon at Add input[0] — Add is commutative. If a model has the epsilon at index 0 and ReduceMean output at index 1, line 703 reads the wrong initializer name. No test covers this ordering.

  • Double precision stash_type path — Line 716-718 sets stash_type to DOUBLE when either x_input or scale is double. No test exercises this.

  • No test verifying the Pow exponent value is 2.0 — The fusion matches any Pow node regardless of exponent value. If the exponent is not 2 (e.g., 3), the fusion still fires and produces mathematically incorrect results. There's no validation that the Pow exponent is actually 2. (This is a pre-existing correctness issue, not a gap in the PR itself.)

  • has_leading_cast = true only tested with skip_device_check_ = true — The test passes has_leading_cast to the constructor, but in production the leading Cast path is gated on is_gpu_ep

  • The test doesn't simulate actual CUDA EP assignment, so it relies on skip_device_check_ flag.

@the0cp

the0cp commented Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

Thanks for the detailed review!

I pushed additional coverage and validation updates:

  • Added a full optimization-loop regression test that exercises the L2 -> InsertCastTransformer -> L4 -> L2 repeat path with CPU EP assignment and a shared Cast-produced Pow exponent.
  • Hardened the fusion so Pow is only accepted when the exponent can be statically proven to be scalar/one-element 2.0, including the Cast-produced exponent case from the reported issue.
  • Added coverage for epsilon on Add input[0], since Add is commutative.
  • Updated the leading-Cast test path to use actual CUDA EP assignment instead of relying on skip_device_check_.

I intentionally kept the remaining mixed partial-fusion shared-Cast case out of this PR for now because the current regression already covers the reported crash path and shared-producer lifetime, while the partial-fusion variant is additional edge-case coverage rather than required for the fix. I can add it as follow-up if you think it is necessary for this PR.

@tianleiwu tianleiwu left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the latest head (cb08b0e). The fix is correct and well-tested.

The core approach is sound: non-replacement input edges of the fused front node are disconnected before FinalizeNodeFusion (so the exponent edge is no longer moved onto a replacement node that lacks that input), and dead producers are removed afterward only when their last consumer is gone (GetOutputEdgesCount() == 0), which correctly handles the shared-Cast case. Tracking NodeIndex values (not pointers) with a nullptr re-fetch guard is safe against bottom-up cascade removals. The leading-Cast path is handled by explicitly tracking the Pow exponent producer since the Pow is a middle node removed by FinalizeNodeFusion. GetOtherAddInput correctly identifies epsilon by connectivity instead of a fixed Add input index, and IsPowExponentTwo is a genuine correctness improvement over the previous unconditional 2.0 assumption.

Test coverage is thorough (shared-Cast exponent in both leading-cast variants, commutative epsilon, exponent-not-two rejection, and a full optimization-loop CPU-fallback regression). All previously raised threads are resolved and addressed.

One low-priority suggestion below; otherwise this looks good.

Comment thread onnxruntime/core/optimizer/layer_norm_fusion.cc
@the0cp

the0cp commented Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

The macOS xnnpack failure appears unrelated: the C++ tests passed and the failure is a Python quantization test hitting a DNS/urlretrieve error. The wasm failure also does not show an error from this change, layer_norm_fusion.cc compiled successfully.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

cannot load mixedbread-ai model with optimization enabled

4 participants