Skip to content

Commit 8fc4326

Browse files
committed
[Rewriter] minor changes (#2301)
1 parent 876a277 commit 8fc4326

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

onnxscript/rewriter/fuse_pad_into_conv.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,13 @@ def __init__(self, as_function: bool = False):
6868
def rewrite(
6969
self, op: ir.tape.Tape, x: ir.Value, pad: ir.Value, conv: ir.Value
7070
) -> ir.Value:
71-
pad_node = pad.producer()
7271
conv_node = conv.producer()
7372

7473
# Retrieve the padding and axes
7574
x_rank = len(x.shape)
76-
pad_pads = pad_node.inputs[1].const_value.numpy().tolist()
77-
if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None:
78-
axes = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
79-
else:
80-
axes = list(range(x_rank))
8175

82-
# Fulfill pad_pads in every dimension (filling with zero the other ones)
83-
pad_pads = fill_pads_with_axes(pad_pads, axes, x_rank)
76+
# Get computed pads in check()
77+
pad_pads = self._pads_list
8478

8579
# Get only spatial pads
8680
new_pads = pad_pads[2:x_rank] + pad_pads[x_rank + 2 :]
@@ -145,8 +139,9 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
145139
axes_list = list(range(x_rank))
146140

147141
# Pad constraints: values
148-
pads_list = fill_pads_with_axes(pads.const_value.numpy(), axes_list, x_rank)
149-
if np.any(pads_list[:2] + pads_list[x_rank : x_rank + 2]):
142+
self._pads_list = fill_pads_with_axes(pads.const_value.numpy(), axes_list, x_rank)
143+
if np.any(self._pads_list[:2] + self._pads_list[x_rank : x_rank + 2]):
144+
self._pads_list = None
150145
return check_result.fail(f"{pads.name} must be zero in non-spatial dimensions.")
151146

152147
return check_result
@@ -164,14 +159,12 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
164159

165160
def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
166161
check_result = super().check(context, x, pad, conv)
167-
if check_result.reason:
162+
if not check_result:
168163
return check_result
169164

170165
# Conv constraints: attributes
171166
conv_node = conv.producer()
172-
if (
173-
apad := conv_node.attributes.get("auto_pad", None)
174-
) and apad.as_string() != "NOTSET":
167+
if conv_node.attributes.get_string("auto_pad", "NOTSET") != "NOTSET":
175168
return check_result.fail(
176169
f"{conv_node.name} ({conv_node.op_type}) auto_pad must be 'NOTSET'."
177170
)

0 commit comments

Comments
 (0)