Commit 19b0418

[Rewriter]: introduce normalize_pad_format_conv (#2301)
Convert the 'auto_pad' attribute into a list of explicit pads.
1 parent 2e35259 commit 19b0418
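For context, here is a minimal sketch (not part of this commit) of the arithmetic used to turn SAME_UPPER/SAME_LOWER into an explicit pads list; the helper name and the example shapes are illustrative only, mirroring the compute_pads logic in the diff below:

# Hypothetical helper mirroring NormalizePadFormatConv.compute_pads.
def explicit_pads(input_shape, output_shape, kernel_shape, strides, auto_pad):
    begin_pads, end_pads = [], []
    for x, y, k, s in zip(input_shape, output_shape, kernel_shape, strides):
        # Total padding needed on this spatial axis
        total = max(0, (y - 1) * s + k - x)
        small, big = total // 2, total - total // 2
        if auto_pad == "SAME_UPPER":
            # The extra pixel of padding goes at the end of the axis
            begin_pads.append(small)
            end_pads.append(big)
        else:  # SAME_LOWER: the extra pixel goes at the beginning
            begin_pads.append(big)
            end_pads.append(small)
    # ONNX pads order: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    return begin_pads + end_pads

# A 14x16 input with a 3x3 kernel and stride 2 gives a 7x8 output under SAME_UPPER:
print(explicit_pads((14, 16), (7, 8), (3, 3), (2, 2), "SAME_UPPER"))  # [0, 0, 1, 1]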

File tree

2 files changed: +206 -17 lines changed


onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 121 additions & 0 deletions
@@ -3,6 +3,8 @@
 """Fuses Pad nodes into preceding nodes. Supported fusion patterns:
 - Pad ∘ Conv -> Conv
 - Pad ∘ ConvInteger -> ConvInteger
+
+To make some rules possible, we implicitly transform the `auto_pad` attribute into an explicit list of pads.
 """

 import typing
@@ -29,6 +31,26 @@ def fill_pads_with_axes(
     return new_pads


+def read_conv_attributes(ir_conv: ir.Node) -> dict[str, typing.Sequence[int] | str]:
+    # Read attributes, falling back to spec defaults when absent
+    attributes = {}
+    if (kernel_shape := ir_conv.attributes.get("kernel_shape", None)) is not None:
+        attributes["kernel_shape"] = kernel_shape.as_ints()
+    else:
+        attributes["kernel_shape"] = ir_conv.inputs[1].shape[2:]
+    if (strides := ir_conv.attributes.get("strides", None)) is not None:
+        attributes["strides"] = strides.as_ints()
+    else:
+        attributes["strides"] = [1] * len(ir_conv.inputs[0].shape[2:])
+    if (auto_pad := ir_conv.attributes.get("auto_pad", None)) is not None:
+        attributes["auto_pad"] = auto_pad.as_string()
+    else:
+        attributes["auto_pad"] = "NOTSET"
+    if (pads := ir_conv.attributes.get("pads", None)) is not None:
+        attributes["pads"] = pads.as_ints()
+    return attributes
+
+
 class _FusePadConvBase(orp.RewriteRuleClassBase):
     """Interface for PadConv nodes fusion."""

@@ -145,6 +167,103 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
         )


+class _NormalizePadFormatBase(orp.RewriteRuleClassBase):
+    """Interface to normalize pad attributes in conv nodes."""
+
+    def compute_pads(
+        self,
+        input_shape: typing.Sequence[int],
+        output_shape: typing.Sequence[int],
+        attributes: dict[str, typing.Sequence[int] | str],
+    ) -> typing.Sequence[int]:
+        raise NotImplementedError("Subclasses must implement this function.")
+
+    def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
+        cnode = conv.producer()
+
+        # Read spatial dimensions and attributes
+        input_shape = cnode.inputs[0].shape[2:]
+        output_shape = cnode.outputs[0].shape[2:]
+        attributes = read_conv_attributes(cnode)
+
+        # Convert the auto_pad mode into an explicit list
+        pads = self.compute_pads(input_shape, output_shape, attributes)
+
+        # Replace auto_pad, forcing the explicit list
+        conv_attr: typing.Mapping[str, ir.Attr] = cnode.attributes.copy()
+        conv_attr["auto_pad"] = ir.convenience.convert_attribute("auto_pad", "NOTSET")
+        if any(x != 0 for x in pads):
+            conv_attr["pads"] = ir.convenience.convert_attribute("pads", pads)
+
+        return op.op(
+            cnode.op_type,
+            inputs=cnode.inputs,
+            attributes=conv_attr,
+            domain=cnode.domain,
+            name=cnode.name,
+        )
+
+    def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
+        del context
+        check_result = orp.MatchResult()
+
+        # Conv constraints: attributes
+        cnode = conv.producer()
+        auto_pad = cnode.attributes.get("auto_pad", None)
+        if auto_pad is None or auto_pad.as_string() == "NOTSET":
+            return check_result.fail(f"{cnode.name} auto_pad must be different from 'NOTSET'.")
+
+        # Conv constraints: inputs/outputs
+        if cnode.inputs[0].shape is None:
+            return check_result.fail(f"Input shapes are not defined on {cnode.name}.")
+        if cnode.outputs[0].shape is None:
+            return check_result.fail(f"Output shapes are not defined on {cnode.name}.")
+        return check_result
+
+
+class NormalizePadFormatConv(_NormalizePadFormatBase):
+    """Convert the auto_pad attribute into 'NOTSET' in Conv nodes."""
+
+    def compute_pads(
+        self,
+        input_shape: typing.Sequence[int],
+        output_shape: typing.Sequence[int],
+        attributes: dict[str, typing.Sequence[int] | str],
+    ) -> typing.Sequence[int]:
+        # Compute pads, following the auto_pad/pads attributes
+        if attributes["auto_pad"] in ["NOTSET", "VALID"]:
+            return attributes.get("pads", [0] * len(input_shape) * 2)
+
+        bottom_pads, top_pads = [], []
+        kernel_shape, strides = attributes["kernel_shape"], attributes["strides"]
+        for x, y, k, s in zip(input_shape, output_shape, kernel_shape, strides):
+            # Compute the total padding to apply on this axis
+            total_pads = max(0, (y - 1) * s + k - x)
+
+            # Depending on the mode, apply the extra padding to the upper or lower part
+            pad1 = total_pads // 2
+            pad2 = total_pads - pad1
+            if attributes["auto_pad"] == "SAME_UPPER":
+                bottom_pads.append(pad1)
+                top_pads.append(pad2)
+            else:
+                top_pads.append(pad1)
+                bottom_pads.append(pad2)
+        return bottom_pads + top_pads
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
+        return op.Conv(x, _allow_other_inputs=True, _outputs=["conv"])
+
+
+class NormalizePadFormatConvInteger(NormalizePadFormatConv):
+    """Convert the auto_pad attribute into 'NOTSET' in ConvInteger nodes."""
+
+    def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
+        return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"])
+
+
+normalize_pad_format_conv = NormalizePadFormatConv.rule()
+normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule()
 fuse_pad_into_conv = FusePadConv.rule()
 fuse_pad_into_conv_integer = FusePadConvInteger.rule()

@@ -159,6 +278,8 @@ def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:
     """
     return orp.RewriteRuleSet(
         [
+            normalize_pad_format_conv,
+            normalize_pad_format_conv_integer,
            fuse_pad_into_conv,
            fuse_pad_into_conv_integer,
        ]
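A usage sketch of the updated rule set, assuming a model loaded with onnx_ir (the file path is hypothetical); shape inference runs first because check() rejects Conv nodes whose input or output shapes are unknown:

import onnx_ir as ir
from onnx_ir.passes.common import shape_inference
from onnxscript.rewriter.fuse_pad_into_conv import fuse_pad_into_conv_rule_set

model = ir.load("model.onnx")  # hypothetical path
# The normalize rules need concrete shapes to compute explicit pads
model = shape_inference.infer_shapes(model)
count = fuse_pad_into_conv_rule_set().apply_to_model(model)
print(f"{count} rewrites applied")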

onnxscript/rewriter/fuse_pad_into_conv_test.py

Lines changed: 85 additions & 17 deletions
@@ -6,13 +6,14 @@
 import numpy as np
 import onnx_ir as ir
 import parameterized
-from onnx_ir.passes.common import onnx_checker
+from onnx_ir.passes.common import onnx_checker, shape_inference

 from onnxscript.rewriter import pattern as orp
 from onnxscript.rewriter import testing
 from onnxscript.rewriter.fuse_pad_into_conv import (
     fuse_pad_into_conv,
     fuse_pad_into_conv_rule_set,
+    normalize_pad_format_conv,
 )

@@ -83,22 +84,24 @@ def build_model(
             ir_version=9,
         )
         onnx_checker.CheckerPass(True)(ir_model)
+        ir_model = shape_inference.infer_shapes(ir_model)
         return ir_model


 class FusePadConvTest(FusePadConvBaseTest):
     @parameterized.parameterized.expand(
         [
-            (pad_pads, const_value, axes, conv_pads)
-            for pad_pads, axes, conv_pads in [
-                ([0, 0, 2, 2, 0, 0, 2, 2], None, None),
-                ([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0]),
-                ([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1]),
+            (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
+            for pad_pads, axes, conv_pads, conv_auto_pad in [
+                ([0, 0, 2, 2, 0, 0, 2, 2], None, None, None),
+                ([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0], None),
+                ([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1], None),
+                ([1, 3, 1, 3], ir.tensor([3, 2], name="axes"), None, "VALID"),
             ]
             for const_value in [None, 0.0]
         ]
     )
-    def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
+    def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_auto_pad):
         pad_inputs = [ir.tensor(pad_pads, name="pads")]
         if const_value is not None or axes is not None:
             pad_inputs.append(const_value)
@@ -109,15 +112,15 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
             input_shape=ir.Shape(("N", 32, 14, 16)),
             weight_shape=(10, 32, 3, 3),
             pad_inputs=pad_inputs,
-            conv_attributes={"pads": conv_pads},
+            conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad},
         )
         updated_model = _clone_model(base_model)

         # Apply rule
         count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)

         # Check that Pad was fused
-        self.assertEqual(count, 1)
+        self.assertEqual(count, 1 if conv_auto_pad is None else 2)
         self.assertEqual(updated_model.graph.num_nodes(), 1)
         onnx_checker.CheckerPass(True)(updated_model)

@@ -223,16 +226,19 @@ def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = Non

     @parameterized.parameterized.expand(
         [
-            (pad_pads, const_value, axes, conv_pads)
-            for pad_pads, axes, conv_pads in [
-                ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1]),
-                ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None),
-                ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1]),
+            (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
+            for pad_pads, axes, conv_pads, conv_auto_pad in [
+                ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1], None),
+                ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None, None),
+                ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1], None),
+                ([3, 3], ir.tensor([2], name="axes"), None, "SAME_UPPER"),
             ]
             for const_value in [None, ir.tensor(np.array([0], "uint8"), name="const_value")]
         ]
     )
-    def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads):
+    def test_fuse_pad_into_conv_integer(
+        self, pad_pads, const_value, axes, conv_pads, conv_auto_pad
+    ):
         pad_inputs = [ir.tensor(pad_pads, name="pads")]
         if const_value is not None or axes is not None:
             pad_inputs.append(const_value)
@@ -243,15 +249,15 @@ def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads
             input_shape=ir.Shape(("N", 24, 19, 23)),
             weight_shape=(8, 24, 3, 3),
             pad_inputs=pad_inputs,
-            conv_attributes={"pads": conv_pads},
+            conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad},
         )
         updated_model = _clone_model(base_model)

         # Apply rule
         count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)

         # Check that Pad was fused
-        self.assertEqual(count, 1)
+        self.assertEqual(count, 1 if conv_auto_pad is None else 2)
         self.assertEqual(updated_model.graph.num_nodes(), 1)
         onnx_checker.CheckerPass(True)(updated_model)

@@ -260,5 +266,67 @@ def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads
         testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)


+class NormalizePadFormatTest(FusePadConvBaseTest):
+    @parameterized.parameterized.expand(
+        [
+            (strides, kernel_shape, auto_pad)
+            for strides, kernel_shape in [((2, 3), (1, 4)), ((2, 1), (5, 2))]
+            for auto_pad in ["SAME_UPPER", "SAME_LOWER", "VALID"]
+        ]
+    )
+    def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
+        pad_inputs = [
+            ir.tensor([1, 1, 1, 1], name="pads"),
+            None,
+            ir.tensor([2, 3], name="axes"),
+        ]
+        base_model = self.build_model(
+            op_type="Conv",
+            input_shape=ir.Shape(("N", 32, 22, 27)),
+            weight_shape=(32, 32, *kernel_shape),
+            pad_inputs=pad_inputs,
+            conv_attributes={
+                "strides": strides,
+                "auto_pad": auto_pad,
+                "kernel_shape": kernel_shape,
+            },
+        )
+        updated_model = _clone_model(base_model)
+
+        # Apply rules
+        count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
+
+        # Check that auto_pad was normalized and Pad was fused
+        self.assertEqual(count, 2)
+        self.assertEqual(updated_model.graph.num_nodes(), 1)
+        onnx_checker.CheckerPass(True)(updated_model)
+
+        # Check inference
+        inputs = self.rng.random((1, 32, 22, 27), dtype="float32")
+        testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
+
+    def test_unsupported_normalize_pad_format(self):
+        base_model = self.build_model(
+            op_type="Conv",
+            input_shape=ir.Shape(("N", 32, 14)),
+            weight_shape=(32, 11, 4),
+            pad_inputs=[ir.tensor([0, 0, 0, 0, 0, 0], name="pads")],
+            conv_attributes={"auto_pad": "VALID"},
+        )
+        # Drop the Conv input shape (the Pad output)
+        base_model.graph[0].outputs[0].shape = None
+        onnx_checker.CheckerPass(True)(base_model)
+
+        # Apply the rule and check that it was not applied
+        tracer = orp.MatchingTracer()
+        count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, "Input shapes are not defined")
+
+
 if __name__ == "__main__":
     unittest.main()
