
Commit 1c81147

[Rewriter]: apply suggestions (#2301)
1 parent f86e1fd commit 1c81147

File tree: 2 files changed, +81 -87 lines

onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 70 additions & 77 deletions
@@ -1,23 +1,23 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 """Fuses Pad nodes into preceding nodes. Supported fusion patterns:
-- PadConv -> Conv
-- PadConvInteger -> ConvInteger
+- ConvPad -> Conv
+- ConvIntegerPad -> ConvInteger

 To make some rules possible, we implicitly transform `auto_pad` attribute into its explicit list.
 """

-import typing
+from __future__ import annotations
+
+from typing import List, Sequence

 import numpy as np
 import onnx_ir as ir

 from onnxscript.rewriter import pattern as orp


-def fill_pads_with_axes(
-    pads: typing.Sequence[int], axes: typing.Sequence[int], rank: int
-) -> typing.List[int]:
+def fill_pads_with_axes(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]:
     new_pads = [0] * 2 * rank
     N = len(axes)
     for start_idx, axis in enumerate(axes):
@@ -26,43 +26,39 @@ def fill_pads_with_axes(
     return new_pads


-def read_conv_attributes(ir_conv: ir.Node) -> dict[str, typing.Sequence[int] | str]:
+def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
     # Read attributes
     attributes = {}
-    if (kernel_shape := ir_conv.attributes.get("kernel_shape", None)) is not None:
-        attributes["kernel_shape"] = kernel_shape.as_ints()
-    else:
-        attributes["kernel_shape"] = ir_conv.inputs[1].shape[2:]
-    if (strides := ir_conv.attributes.get("strides", None)) is not None:
-        attributes["strides"] = strides.as_ints()
-    else:
-        attributes["strides"] = [1] * len(ir_conv.inputs[0].shape[2:])
-    if (auto_pad := ir_conv.attributes.get("auto_pad", None)) is not None:
-        attributes["auto_pad"] = auto_pad.as_string()
-    else:
-        attributes["auto_pad"] = "NOTSET"
-    if (pads := ir_conv.attributes.get("pads", None)) is not None:
-        attributes["pads"] = pads.as_ints()
+    ir_attributes = ir_conv.attributes
+    attributes["kernel_shape"] = ir_attributes.get_ints(
+        "kernel_shape", ir_conv.inputs[1].shape[2:]
+    )
+    attributes["strides"] = ir_attributes.get_ints(
+        "strides", [1] * len(ir_conv.inputs[0].shape[2:])
+    )
+    attributes["auto_pad"] = ir_attributes.get_string("auto_pad", "NOTSET")
+    if "pads" in ir_attributes:
+        attributes["pads"] = ir_attributes.get_ints("pads")
     return attributes


 class _FusePadConvBase(orp.RewriteRuleClassBase):
     """Interface for PadConv nodes fusion."""

-    def __init__(self, name: str, as_function: bool = False):
+    def __init__(self, as_function: bool = False):
         # Remove nodes is set to False to remove unused nodes after the rewrite.
-        super().__init__(name=name, remove_nodes=False, as_function=as_function)
+        super().__init__(remove_nodes=False, as_function=as_function)

     def rewrite(
         self, op: ir.tape.Tape, x: ir.Value, pad: ir.Value, conv: ir.Value
     ) -> ir.Value:
-        pnode = pad.producer()
-        cnode = conv.producer()
+        pad_node = pad.producer()
+        conv_node = conv.producer()

         # Retrieve the padding and axes
         x_rank = len(x.shape)
-        pad_pads = pnode.inputs[1].const_value.numpy().tolist()
-        if len(pnode.inputs) > 3 and (axes := pnode.inputs[3]) is not None:
+        pad_pads = pad_node.inputs[1].const_value.numpy().tolist()
+        if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None:
             axes = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
         else:
             axes = list(range(x_rank))
@@ -74,41 +70,40 @@ def rewrite(
         new_pads = pad_pads[2:x_rank] + pad_pads[x_rank + 2 :]

         # Replace conv pads = new + old
-        conv_attr: typing.Mapping[str, ir.Attr] = cnode.attributes.copy()
+        conv_attr = conv_node.attributes.copy()
         if "pads" in conv_attr:
             new_pads = [x + y for x, y in zip(conv_attr["pads"].as_ints(), new_pads)]
-        conv_attr["pads"] = ir.convenience.convert_attribute("pads", new_pads)
+        conv_attr.add(ir.AttrInt64s("pads", new_pads))

         return op.op(
-            cnode.op_type,
-            inputs=(x, *cnode.inputs[1:]),
+            conv_node.op_type,
+            inputs=(x, *conv_node.inputs[1:]),
             attributes=conv_attr,
-            domain=cnode.domain,
-            name=cnode.name,
+            domain=conv_node.domain,
+            name=conv_node.name,
         )

     def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
         del context  # Unused
         check_result = orp.MatchResult()
-        pnode = pad.producer()
+        pad_node = pad.producer()
         x_rank = len(x.shape)

         # Pad constraints: attributes
-        if (mode := pnode.attributes.get("mode", None)) and mode.as_string() != "constant":
-            return check_result.fail(f"{pnode.name} mode must be 'constant'.")
+        if (mode := pad_node.attributes.get("mode", None)) and mode.as_string() != "constant":
+            return check_result.fail(f"{pad_node.name} mode must be 'constant'.")

         # Pad constraints: inputs
-        if (pads := pnode.inputs[1]).const_value is None:
+        if (pads := pad_node.inputs[1]).const_value is None:
             return check_result.fail(f"{pads.name} is not a constant/initializer.")
-        if len(pnode.inputs) > 2 and (constant_value := pnode.inputs[2]) is not None:
+        if len(pad_node.inputs) > 2 and (constant_value := pad_node.inputs[2]) is not None:
             if constant_value.const_value is None:
                 return check_result.fail(
                     f"{constant_value.name} is not a constant/initializer."
                 )
             elif constant_value.const_value.numpy().item() != 0:
                 return check_result.fail(f"{constant_value.name} must be equal to 0.")
-        axes = list(range(x_rank))
-        if len(pnode.inputs) > 3 and (axes := pnode.inputs[3]) is not None:
+        if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None:
             if axes.const_value is None:
                 return check_result.fail(f"{axes.name} is not a constant/initializer.")
             axes_list = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
@@ -126,9 +121,6 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
 class FusePadConv(_FusePadConvBase):
     """Replaces ``Pad(Conv(x))`` with ``Conv(x)``."""

-    def __init__(self, as_function: bool = False):
-        super().__init__(name="FusePadConv", as_function=as_function)
-
     def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
         return op.Conv(
             op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
@@ -142,18 +134,17 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
             return check_result

         # Conv constraints: attributes
-        cnode = conv.producer()
-        if (apad := cnode.attributes.get("auto_pad", None)) and apad.as_string() != "NOTSET":
-            return check_result.fail(f"{cnode.name} auto_pad must be 'NOTSET'.")
+        conv_node = conv.producer()
+        if (
+            apad := conv_node.attributes.get("auto_pad", None)
+        ) and apad.as_string() != "NOTSET":
+            return check_result.fail(f"{conv_node.name} auto_pad must be 'NOTSET'.")
         return check_result


 class FusePadConvInteger(FusePadConv):
     """Replaces ``Pad(ConvInteger(x))`` with ``ConvInteger(x)``."""

-    def __init__(self, as_function: bool = False):
-        super(FusePadConv, self).__init__(name="FusePadConvInteger", as_function=as_function)
-
     def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
         return op.ConvInteger(
             op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
@@ -165,66 +156,68 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
 class _NormalizePadFormatBase(orp.RewriteRuleClassBase):
     """Interface to normalize pad attributes in conv nodes."""

+    @staticmethod
     def compute_pads(
-        self,
-        input_shape: typing.Sequence[int],
-        output_shape: typing.Sequence[int],
-        attributes: dict[str, typing.Sequence[int] | str],
-    ) -> typing.Sequence[int]:
+        input_shape: Sequence[int],
+        output_shape: Sequence[int],
+        attributes: dict[str, Sequence[int] | str],
+    ) -> Sequence[int]:
         raise NotImplementedError("Child have to implement this function")

     def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
-        cnode = conv.producer()
+        conv_node = conv.producer()

         # Read spatial dimensions and attributes
-        input_shape = cnode.inputs[0].shape[2:]
-        output_shape = cnode.outputs[0].shape[2:]
-        attributes = read_conv_attributes(cnode)
+        input_shape = conv_node.inputs[0].shape[2:]
+        output_shape = conv_node.outputs[0].shape[2:]
+        attributes = read_conv_attributes(conv_node)

         # Convert auto_pad mode into an explicit list
         pads = self.compute_pads(input_shape, output_shape, attributes)

         # Replace auto_pad, forcing to the explicit list
-        conv_attr: typing.Mapping[str, ir.Attr] = cnode.attributes.copy()
-        conv_attr["auto_pad"] = ir.convenience.convert_attribute("auto_pad", "NOTSET")
+        conv_attr = conv_node.attributes.copy()
+        conv_attr.add(ir.AttrString("auto_pad", "NOTSET"))
         if any(x != 0 for x in pads):
-            conv_attr["pads"] = ir.convenience.convert_attribute("pads", pads)
+            conv_attr.add(ir.AttrInt64s("pads", pads))

         return op.op(
-            cnode.op_type,
-            inputs=cnode.inputs,
+            conv_node.op_type,
+            inputs=conv_node.inputs,
             attributes=conv_attr,
-            domain=cnode.domain,
-            name=cnode.name,
+            domain=conv_node.domain,
+            name=conv_node.name,
         )

     def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
         del context
         check_result = orp.MatchResult()

         # Conv constraints: attributes
-        cnode = conv.producer()
-        auto_pad = cnode.attributes.get("auto_pad", None)
+        conv_node = conv.producer()
+        auto_pad = conv_node.attributes.get("auto_pad", None)
         if auto_pad is None or auto_pad.as_string() == "NOTSET":
-            return check_result.fail(f"{cnode.name} auto_pad must be different to 'NOTSET'.")
+            return check_result.fail(
+                f"{conv_node.name} auto_pad must be different to 'NOTSET'."
+            )

         # Conv constraints: inputs/outputs
-        if cnode.inputs[0].shape is None:
-            return check_result.fail(f"Input shapes are not defined on {cnode.name}.")
-        if cnode.outputs[0].shape is None:
-            return check_result.fail(f"Output shapes are not defined on {cnode.name}.")
+        if conv_node.inputs[0].shape is None:
+            return check_result.fail(f"Input shapes are not defined on {conv_node.name}.")
+        if conv_node.outputs[0].shape is None:
+            return check_result.fail(f"Output shapes are not defined on {conv_node.name}.")
         return check_result


 class NormalizePadFormatConv(_NormalizePadFormatBase):
     """Convert auto_pad attribute into 'NOTSET' in Conv nodes ."""

+    @staticmethod
     def compute_pads(
-        self,
-        input_shape: typing.Sequence[int],
-        output_shape: typing.Sequence[int],
-        attributes: dict[str, typing.Sequence[int] | str],
-    ) -> typing.Sequence[int]:
+        input_shape: Sequence[int],
+        output_shape: Sequence[int],
+        attributes: dict[str, Sequence[int] | str],
+    ) -> Sequence[int]:
         # Compute pads, following auto_pad/pads attributes
         if attributes["auto_pad"] in ["NOTSET", "VALID"]:
             return attributes.get("pads", [0] * len(input_shape) * 2)
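
For reference, a minimal standalone sketch of the pad-merging arithmetic performed in `_FusePadConvBase.rewrite`: the Pad operator's `[begins..., ends...]` list is stripped of its batch and channel entries and added element-wise to the Conv node's existing explicit `pads`. The `merge_pads` helper below is illustrative only and is not part of this commit.

from __future__ import annotations

from typing import List, Sequence


def merge_pads(pad_pads: Sequence[int], conv_pads: Sequence[int] | None, x_rank: int) -> List[int]:
    # Drop the batch/channel begin and end entries, keeping only spatial pads.
    spatial = list(pad_pads[2:x_rank]) + list(pad_pads[x_rank + 2 :])
    # Add element-wise to the Conv node's existing explicit pads, if any.
    return spatial if conv_pads is None else [p + c for p, c in zip(spatial, conv_pads)]


# NCHW input (rank 4): Pad adds 1 pixel on each side of H and 2 on each side of W;
# the Conv already pads 1 everywhere.
print(merge_pads([0, 0, 1, 2, 0, 0, 1, 2], [1, 1, 1, 1], x_rank=4))  # -> [2, 3, 2, 3]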

onnxscript/rewriter/fuse_pad_into_conv_test.py

Lines changed: 11 additions & 10 deletions
@@ -1,7 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import typing
+from __future__ import annotations
+
 import unittest
+from typing import Mapping, Sequence

 import numpy as np
 import onnx_ir as ir
@@ -26,7 +28,7 @@ class FusePadConvBaseTest(unittest.TestCase):
     def rng(self):
         return np.random.default_rng(20250522)

-    def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = None):
+    def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None):
         w = ir.tensor(self.rng.uniform(-0.5, 0.5, shape).astype("float32"), name="W")
         if tape is not None:
             w = tape.initializer(w)
@@ -36,11 +38,10 @@ def build_model(
         self,
         op_type: str,
         input_shape: ir.Shape,
-        weight_shape: typing.Sequence[int],
-        pad_inputs: typing.Sequence[ir.TensorProtocol | ir.Value | None],
-        pad_attributes: typing.Mapping[str, ir.Attr] | None = None,
-        conv_attributes: typing.Mapping[str, ir.Attr] | None = None,
-        opset_imports: typing.Mapping[str, int] = {"": 20},
+        weight_shape: Sequence[int],
+        pad_inputs: Sequence[ir.TensorProtocol | ir.Value | None],
+        pad_attributes: Mapping[str, ir.Attr] | None = None,
+        conv_attributes: Mapping[str, ir.Attr] | None = None,
     ) -> ir.Model:
         tape = ir.tape.Tape()
         inputs = []
@@ -78,10 +79,10 @@ def build_model(
                 outputs=[y],
                 nodes=tape.nodes,
                 initializers=tape.initializers,
-                opset_imports=opset_imports,
+                opset_imports={"": 20},
                 name="model",
             ),
-            ir_version=9,
+            ir_version=10,
         )
         onnx_checker.CheckerPass(True)(ir_model)
         ir_model = shape_inference.infer_shapes(ir_model)
@@ -218,7 +219,7 @@ def test_unsupported_fuse_pad_into_conv(


 class FusePadConvIntegerTest(FusePadConvBaseTest):
-    def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = None):
+    def get_conv_weights(self, shape: Sequence[int], tape: ir.tape.Tape = None):
         w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W")
         if tape is not None:
             w = tape.initializer(w)
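
As a usage sketch (not part of this commit), the rules touched here might be applied to a model along these lines. The `.rule()` factory, the `pattern_rewrite_rules` argument, and the model path are assumptions about the surrounding onnxscript rewriter API rather than something this diff defines; adjust to the actual entry points if they differ.

import onnx

from onnxscript import rewriter
from onnxscript.rewriter.fuse_pad_into_conv import (
    FusePadConv,
    FusePadConvInteger,
    NormalizePadFormatConv,
)

# Assumed wiring: RewriteRuleClassBase subclasses expose a .rule() factory and the
# resulting rules can be passed to rewriter.rewrite as pattern rewrite rules.
rules = [
    NormalizePadFormatConv.rule(),  # make auto_pad explicit first
    FusePadConv.rule(),
    FusePadConvInteger.rule(),
]

model = onnx.load("model_with_pad_conv.onnx")  # hypothetical input model
fused = rewriter.rewrite(model, pattern_rewrite_rules=rules)
onnx.save(fused, "model_fused.onnx")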
