@@ -68,19 +68,13 @@ def __init__(self, as_function: bool = False):
68
68
def rewrite (
69
69
self , op : ir .tape .Tape , x : ir .Value , pad : ir .Value , conv : ir .Value
70
70
) -> ir .Value :
71
- pad_node = pad .producer ()
72
71
conv_node = conv .producer ()
73
72
74
73
# Retrieve the padding and axes
75
74
x_rank = len (x .shape )
76
- pad_pads = pad_node .inputs [1 ].const_value .numpy ().tolist ()
77
- if len (pad_node .inputs ) > 3 and (axes := pad_node .inputs [3 ]) is not None :
78
- axes = [x if x >= 0 else x_rank + x for x in axes .const_value .numpy ()]
79
- else :
80
- axes = list (range (x_rank ))
81
75
82
- # Fulfill pad_pads in every dimension (filling with zero the other ones )
83
- pad_pads = fill_pads_with_axes ( pad_pads , axes , x_rank )
76
+ # Get computed pads in check( )
77
+ pad_pads = self . _pads_list
84
78
85
79
# Get only spatial pads
86
80
new_pads = pad_pads [2 :x_rank ] + pad_pads [x_rank + 2 :]
@@ -145,8 +139,9 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
145
139
axes_list = list (range (x_rank ))
146
140
147
141
# Pad constraints: values
148
- pads_list = fill_pads_with_axes (pads .const_value .numpy (), axes_list , x_rank )
149
- if np .any (pads_list [:2 ] + pads_list [x_rank : x_rank + 2 ]):
142
+ self ._pads_list = fill_pads_with_axes (pads .const_value .numpy (), axes_list , x_rank )
143
+ if np .any (self ._pads_list [:2 ] + self ._pads_list [x_rank : x_rank + 2 ]):
144
+ self ._pads_list = None
150
145
return check_result .fail (f"{ pads .name } must be zero in non-spatial dimensions." )
151
146
152
147
return check_result
@@ -164,14 +159,12 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
164
159
165
160
def check (self , context , x : ir .Value , pad : ir .Value , conv : ir .Value ) -> orp .MatchResult :
166
161
check_result = super ().check (context , x , pad , conv )
167
- if check_result . reason :
162
+ if not check_result :
168
163
return check_result
169
164
170
165
# Conv constraints: attributes
171
166
conv_node = conv .producer ()
172
- if (
173
- apad := conv_node .attributes .get ("auto_pad" , None )
174
- ) and apad .as_string () != "NOTSET" :
167
+ if conv_node .attributes .get_string ("auto_pad" , "NOTSET" ) != "NOTSET" :
175
168
return check_result .fail (
176
169
f"{ conv_node .name } ({ conv_node .op_type } ) auto_pad must be 'NOTSET'."
177
170
)
0 commit comments