# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: F821
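"""Test for the redundant_scatter_nd rewrite rule.

The model below scatters `updates` into `data` at every index along one axis,
so the ScatterND is a full overwrite and the rewriter is expected to remove it.
"""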

import unittest

import numpy as np
import onnx_ir as ir
import onnxruntime
from onnx_ir.passes.common import CheckerPass, ShapeInferencePass

import onnxscript.optimizer
from onnxscript import FLOAT, script
from onnxscript import opset18 as op
from onnxscript.rewriter import redundant_scatter_nd

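# Shared IR passes: onnx_check validates the model; shape_inference fills in
# value shapes and types before optimization and rewriting.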
shape_inference = ShapeInferencePass()
onnx_check = CheckerPass(True)


class RedundantScatterNdTest(unittest.TestCase):
    def test_redundant_scatter_nd(self):
        @script()
        def model_script(
            data: FLOAT[8, "N", 16], updates: FLOAT[8, "N", 16]
        ) -> FLOAT[8, "N", 16]:
            # Construct update-indices spanning an entire axis:
            axis = op.Constant(value_int=1)
            shape = op.Shape(data, start=0)
            dim = op.Gather(shape, axis, axis=0)
            full_range = op.Range(0, dim, 1)
            full_range_2d = op.Unsqueeze(full_range, [-1])
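            # full_range_2d is [[0], [1], ..., [dim-1]]: one index entry for every
            # position along the scattered axis, so the ScatterND below overwrites
            # the whole tensor with the updates.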
            # The update is applied to the data transposed to bring the updated axis to the front:
            transposed_data = op.Transpose(data, perm=[1, 0, 2])
            transposed_updates = op.Transpose(updates, perm=[1, 0, 2])
            scattered = op.ScatterND(
                transposed_data, full_range_2d, transposed_updates, reduction="none"
            )
            # Transpose the result back to the original shape:
            output = op.Transpose(scattered, perm=[1, 0, 2])
            return output

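        # Convert the script to a ModelProto and then to the IR form that the
        # checker, shape inference, optimizer, and rewriter operate on.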
        input_model_proto = model_script.to_model_proto()
        model = ir.serde.deserialize_model(input_model_proto)
        onnx_check(model)
        shape_inference(model)
        onnxscript.optimizer.fold_constants(model)
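        # The rewrite is expected to apply exactly once, to the redundant ScatterND.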
        count = redundant_scatter_nd.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        onnx_check(model)
        optimized_model_proto = ir.serde.serialize_model(model)
        # Test that both models are equivalent:
        inputs = {
            "data": np.random.rand(8, 4, 16).astype(np.float32),
            "updates": np.random.rand(8, 4, 16).astype(np.float32),
        }
        session = onnxruntime.InferenceSession(
            input_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        outputs = session.run(None, inputs)
        optimized_session = onnxruntime.InferenceSession(
            optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        optimized_outputs = optimized_session.run(None, inputs)
        for output, optimized_output in zip(outputs, optimized_outputs):
            np.testing.assert_allclose(output, optimized_output, rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
    unittest.main()