Skip to content

Commit ff0a132

Browse files
authored
Eliminate unnecessary ScatterND (#2422)
Identify ScatterND(data, indices, updates) that can be replaced by Identity(updates). This is generated by the translation of `x[:, ...] = y` in PyTorch. The specific pattern is that the updated indices take the form [[0], ..., [S-1]] for the first dimension, where S is the size of the first dimension of the updated-data tensor. In effect, the scatter-update ends up being an assignment of a new value to the entire tensor. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 7b89760 commit ff0a132

File tree

3 files changed

+139
-0
lines changed

3 files changed

+139
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"pattern",
99
"rewrite",
1010
"RewritePass",
11+
"MatchResult",
1112
]
1213

1314
import onnx
@@ -21,7 +22,9 @@
2122
collapse_slices,
2223
no_op,
2324
pattern,
25+
redundant_scatter_nd,
2426
)
27+
from onnxscript.rewriter._basics import MatchResult
2528

2629
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
2730
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
@@ -30,6 +33,7 @@
3033
*cast_constant_of_shape.rules.rules,
3134
*collapse_slices.rules.rules,
3235
*basic_rules.basic_optimization_rules().rules,
36+
*redundant_scatter_nd.rules.rules,
3337
)
3438

3539

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Rewrite rule to eliminate redundant ScatterND operations.
4+
5+
Identify ScatterND(data, indices, updates) that can be replaced by Identity(updates).
6+
This is generated by the translation of `x[:, ...] = y` in PyTorch.
7+
The specific pattern is that the updated indices take the form [[0], ..., [S-1]] for the first dimension,
8+
where S is the size of the first dimension of the updated-data tensor.
9+
In effect, the scatter-update ends up being an assignment of a new value to the entire tensor.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import onnx_ir as ir
15+
16+
import onnxscript.rewriter
17+
from onnxscript.rewriter import _ir_utils as ir_utils
18+
from onnxscript.rewriter import pattern as orp
19+
20+
21+
def fail(*args):
22+
return onnxscript.rewriter.MatchResult().fail(*args)
23+
24+
25+
class ScatterAll(orp.RewriteRuleClassBase):
26+
def pattern(self, op, data, axis, transposed_data, updates):
27+
# Construct update-indices spanning an entire axis:
28+
shape = op.Shape(data, start=0)
29+
dim = op.Gather(shape, axis, axis=0)
30+
full_range = op.Range(0, dim, 1)
31+
full_range_2d = op.Unsqueeze(full_range, [-1])
32+
# The update is applied to the data transposed to bring the updated axis to the front:
33+
return op.ScatterND(transposed_data, full_range_2d, updates, reduction="none")
34+
35+
def check(self, context, data, axis, transposed_data, **_):
36+
# Check that updated-indices represent the full range of the first dimension of the transposed data.
37+
# That is: check that the data.shape[axis] matches transposed_data.shape[0].
38+
axis_value = ir_utils.get_singleton_value(axis)
39+
if not isinstance(axis_value, int):
40+
return fail("Axis value must be a constant integer.", axis)
41+
shape: ir.Shape | None = data.shape
42+
if shape is None:
43+
return fail("Data shape is not statically known.", data)
44+
updated_dim_value = shape[axis_value]
45+
transposed_data_shape: ir.Shape | None = transposed_data.shape
46+
if transposed_data_shape is None:
47+
return fail("Transposed data shape is not statically known.", transposed_data)
48+
actual_dim_value = transposed_data_shape[0]
49+
if updated_dim_value != actual_dim_value:
50+
# The first dimension of the transposed data does not match the updated dimension,
51+
# so we cannot apply this rule.
52+
return fail(
53+
"The first dimension of the transposed data does not match the updated dimension.",
54+
data,
55+
transposed_data,
56+
)
57+
return True
58+
59+
def rewrite(self, op, updates, **_):
60+
return op.Identity(updates)
61+
62+
63+
rule = ScatterAll.rule()
64+
65+
rules = orp.RewriteRuleSet([rule])
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
# ruff: noqa: F821
4+
5+
import unittest
6+
7+
import numpy as np
8+
import onnx_ir as ir
9+
import onnxruntime
10+
from onnx_ir.passes.common import CheckerPass, ShapeInferencePass
11+
12+
import onnxscript.optimizer
13+
from onnxscript import FLOAT, script
14+
from onnxscript import opset18 as op
15+
from onnxscript.rewriter import redundant_scatter_nd
16+
17+
shape_inference = ShapeInferencePass()
18+
onnx_check = CheckerPass(True)
19+
20+
21+
class RedundantScatterNdTest(unittest.TestCase):
22+
def test_redundant_scatter_nd(self):
23+
@script()
24+
def model_script(
25+
data: FLOAT[8, "N", 16], updates: FLOAT[8, "N", 16]
26+
) -> FLOAT[8, "N", 16]:
27+
# Construct update-indices spanning an entire axis:
28+
axis = op.Constant(value_int=1)
29+
shape = op.Shape(data, start=0)
30+
dim = op.Gather(shape, axis, axis=0)
31+
full_range = op.Range(0, dim, 1)
32+
full_range_2d = op.Unsqueeze(full_range, [-1])
33+
# The update is applied to the data transposed to bring the updated axis to the front:
34+
transposed_data = op.Transpose(data, perm=[1, 0, 2])
35+
transposed_updates = op.Transpose(updates, perm=[1, 0, 2])
36+
scattered = op.ScatterND(
37+
transposed_data, full_range_2d, transposed_updates, reduction="none"
38+
)
39+
# Transpose the result back to the original shape:
40+
output = op.Transpose(scattered, perm=[1, 0, 2])
41+
return output
42+
43+
input_model_proto = model_script.to_model_proto()
44+
model = ir.serde.deserialize_model(input_model_proto)
45+
onnx_check(model)
46+
shape_inference(model)
47+
onnxscript.optimizer.fold_constants(model)
48+
count = redundant_scatter_nd.rules.apply_to_model(model)
49+
self.assertEqual(count, 1)
50+
onnx_check(model)
51+
optimized_model_proto = ir.serde.serialize_model(model)
52+
# Test that both models are equivalent:
53+
inputs = {
54+
"data": np.random.rand(8, 4, 16).astype(np.float32),
55+
"updates": np.random.rand(8, 4, 16).astype(np.float32),
56+
}
57+
session = onnxruntime.InferenceSession(
58+
input_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
59+
)
60+
outputs = session.run(None, inputs)
61+
optimized_session = onnxruntime.InferenceSession(
62+
optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
63+
)
64+
optimized_outputs = optimized_session.run(None, inputs)
65+
for output, optimized_output in zip(outputs, optimized_outputs):
66+
np.testing.assert_allclose(output, optimized_output, rtol=1e-6, atol=1e-6)
67+
68+
69+
if __name__ == "__main__":
70+
unittest.main()

0 commit comments

Comments
 (0)