Eliminate unnecessary ScatterND #2422

Merged
merged 4 commits into from
Jun 30, 2025
4 changes: 4 additions & 0 deletions onnxscript/rewriter/__init__.py
@@ -8,6 +8,7 @@
"pattern",
"rewrite",
"RewritePass",
"MatchResult",
]

import onnx
@@ -21,7 +22,9 @@
    collapse_slices,
    no_op,
    pattern,
    redundant_scatter_nd,
)
from onnxscript.rewriter._basics import MatchResult

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
@@ -30,6 +33,7 @@
    *cast_constant_of_shape.rules.rules,
    *collapse_slices.rules.rules,
    *basic_rules.basic_optimization_rules().rules,
    *redundant_scatter_nd.rules.rules,
)


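Because the new rule set is appended to _DEFAULT_REWRITE_RULES, it runs as part of the stock rewriter entry point. The sketch below is illustrative only and not part of this PR; it assumes rewriter.rewrite accepts an onnx.ModelProto and returns a model of the same type (as the _ModelProtoOrIr TypeVar above suggests), and "model.onnx" is a placeholder path.

import onnx

from onnxscript import rewriter

# Hypothetical usage: load a model and apply the default rewrite rules,
# which now include the redundant ScatterND elimination.
model_proto = onnx.load("model.onnx")  # placeholder path
rewritten = rewriter.rewrite(model_proto)
onnx.save(rewritten, "model_rewritten.onnx")
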
65 changes: 65 additions & 0 deletions onnxscript/rewriter/redundant_scatter_nd.py
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rewrite rule to eliminate redundant ScatterND operations.

Identify ScatterND(data, indices, updates) that can be replaced by Identity(updates).
This is generated by the translation of `x[:, ...] = y` in PyTorch.
The specific pattern is that the updated indices take the form [[0], ..., [S-1]] for the first dimension,
where S is the size of the first dimension of the updated-data tensor.
In effect, the scatter-update ends up being an assignment of a new value to the entire tensor.
"""

from __future__ import annotations

import onnx_ir as ir

import onnxscript.rewriter
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import pattern as orp


def fail(*args):
    return onnxscript.rewriter.MatchResult().fail(*args)


class ScatterAll(orp.RewriteRuleClassBase):
    def pattern(self, op, data, axis, transposed_data, updates):
        # Construct update-indices spanning an entire axis:
        shape = op.Shape(data, start=0)
        dim = op.Gather(shape, axis, axis=0)
        full_range = op.Range(0, dim, 1)
        full_range_2d = op.Unsqueeze(full_range, [-1])
Contributor:
I don't see these ops in the repro: pytorch/pytorch#157289

Contributor:
Maybe you can delete
remove_redundant_scatternd = pattern.RewriteRule(

Collaborator (Author):
I think we can consolidate the rules separately. (I am thinking of trying out Copilot to do it.)

Collaborator (Author):
May need to make other dimensions symbolic. Otherwise, all of these ops will be constant-folded, and the indices become a constant.

        # The update is applied to the data transposed to bring the updated axis to the front:
        return op.ScatterND(transposed_data, full_range_2d, updates, reduction="none")

    def check(self, context, data, axis, transposed_data, **_):
        # Check that updated-indices represent the full range of the first dimension of the transposed data.
        # That is: check that data.shape[axis] matches transposed_data.shape[0].
        axis_value = ir_utils.get_singleton_value(axis)
        if not isinstance(axis_value, int):
            return fail("Axis value must be a constant integer.", axis)

        shape: ir.Shape | None = data.shape
        if shape is None:
            return fail("Data shape is not statically known.", data)

        updated_dim_value = shape[axis_value]
        transposed_data_shape: ir.Shape | None = transposed_data.shape
        if transposed_data_shape is None:
            return fail("Transposed data shape is not statically known.", transposed_data)

        actual_dim_value = transposed_data_shape[0]
        if updated_dim_value != actual_dim_value:
            # The first dimension of the transposed data does not match the updated dimension,
            # so we cannot apply this rule.
            return fail(
"The first dimension of the transposed data does not match the updated dimension.",
data,
transposed_data,
)
return True

    def rewrite(self, op, updates, **_):
        return op.Identity(updates)


rule = ScatterAll.rule()

rules = orp.RewriteRuleSet([rule])
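As a sanity check on the reasoning in the module docstring, the following NumPy sketch (illustrative only, not part of this PR) replays ScatterND's reference semantics with full-range indices [[0], ..., [S-1]] on the first axis: every slice of the data is overwritten, so the result is exactly the updates tensor, which is why the node can be replaced by Identity(updates).

import numpy as np

data = np.zeros((3, 4), dtype=np.float32)
updates = np.arange(12, dtype=np.float32).reshape(3, 4)
indices = np.array([[0], [1], [2]], dtype=np.int64)  # [[0], ..., [S-1]] with S = 3

# Reference ScatterND (reduction="none"): each indices[i] selects one slice of
# the output to overwrite with updates[i].
out = data.copy()
for i in range(indices.shape[0]):
    out[tuple(indices[i])] = updates[i]

assert np.array_equal(out, updates)  # the scatter is equivalent to Identity(updates)
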
70 changes: 70 additions & 0 deletions onnxscript/rewriter/redundant_scatter_nd_test.py
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: F821

import unittest

import numpy as np
import onnx_ir as ir
import onnxruntime
from onnx_ir.passes.common import CheckerPass, ShapeInferencePass

import onnxscript.optimizer
from onnxscript import FLOAT, script
from onnxscript import opset18 as op
from onnxscript.rewriter import redundant_scatter_nd

shape_inference = ShapeInferencePass()
onnx_check = CheckerPass(True)


class RedundantScatterNdTest(unittest.TestCase):
    def test_redundant_scatter_nd(self):
        @script()
        def model_script(
            data: FLOAT[8, "N", 16], updates: FLOAT[8, "N", 16]
        ) -> FLOAT[8, "N", 16]:
            # Construct update-indices spanning an entire axis:
            axis = op.Constant(value_int=1)
            shape = op.Shape(data, start=0)
            dim = op.Gather(shape, axis, axis=0)
            full_range = op.Range(0, dim, 1)
            full_range_2d = op.Unsqueeze(full_range, [-1])

            # The update is applied to the data transposed to bring the updated axis to the front:
            transposed_data = op.Transpose(data, perm=[1, 0, 2])
            transposed_updates = op.Transpose(updates, perm=[1, 0, 2])
            scattered = op.ScatterND(
                transposed_data, full_range_2d, transposed_updates, reduction="none"
            )
            # Transpose the result back to the original shape:
            output = op.Transpose(scattered, perm=[1, 0, 2])
            return output


        input_model_proto = model_script.to_model_proto()
        model = ir.serde.deserialize_model(input_model_proto)
        onnx_check(model)
        shape_inference(model)
        onnxscript.optimizer.fold_constants(model)
        count = redundant_scatter_nd.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        onnx_check(model)
        optimized_model_proto = ir.serde.serialize_model(model)
        # Test that both models are equivalent:
        inputs = {
            "data": np.random.rand(8, 4, 16).astype(np.float32),
            "updates": np.random.rand(8, 4, 16).astype(np.float32),
        }
        session = onnxruntime.InferenceSession(
            input_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        outputs = session.run(None, inputs)
        optimized_session = onnxruntime.InferenceSession(
            optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        optimized_outputs = optimized_session.run(None, inputs)
        for output, optimized_output in zip(outputs, optimized_outputs):
            np.testing.assert_allclose(output, optimized_output, rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
    unittest.main()
