@@ -6,13 +6,14 @@
 import numpy as np
 import onnx_ir as ir
 import parameterized
-from onnx_ir.passes.common import onnx_checker
+from onnx_ir.passes.common import onnx_checker, shape_inference

 from onnxscript.rewriter import pattern as orp
 from onnxscript.rewriter import testing
 from onnxscript.rewriter.fuse_pad_into_conv import (
     fuse_pad_into_conv,
     fuse_pad_into_conv_rule_set,
+    normalize_pad_format_conv,
 )

@@ -83,22 +84,24 @@ def build_model(
             ir_version=9,
         )
         onnx_checker.CheckerPass(True)(ir_model)
+        ir_model = shape_inference.infer_shapes(ir_model)
         return ir_model


 class FusePadConvTest(FusePadConvBaseTest):
     @parameterized.parameterized.expand(
         [
-            (pad_pads, const_value, axes, conv_pads)
-            for pad_pads, axes, conv_pads in [
-                ([0, 0, 2, 2, 0, 0, 2, 2], None, None),
-                ([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0]),
-                ([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1]),
+            (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
+            for pad_pads, axes, conv_pads, conv_auto_pad in [
+                ([0, 0, 2, 2, 0, 0, 2, 2], None, None, None),
+                ([0, 2, 2, 0, 2, 2], ir.tensor([1, -2, -1], name="axes"), [2, 0, 2, 0], None),
+                ([1, 1, 1, 1], ir.tensor([-2, 3], name="axes"), [0, 1, 0, 1], None),
+                ([1, 3, 1, 3], ir.tensor([3, 2], name="axes"), None, "VALID"),
             ]
             for const_value in [None, 0.0]
         ]
     )
-    def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
+    def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_auto_pad):
         pad_inputs = [ir.tensor(pad_pads, name="pads")]
         if const_value is not None or axes is not None:
             pad_inputs.append(const_value)
@@ -109,15 +112,15 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
             input_shape=ir.Shape(("N", 32, 14, 16)),
             weight_shape=(10, 32, 3, 3),
             pad_inputs=pad_inputs,
-            conv_attributes={"pads": conv_pads},
+            conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad},
         )
         updated_model = _clone_model(base_model)

         # Apply rule
         count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)

         # Check that Pad was fused
-        self.assertEqual(count, 1)
+        self.assertEqual(count, 1 if conv_auto_pad is None else 2)
         self.assertEqual(updated_model.graph.num_nodes(), 1)
         onnx_checker.CheckerPass(True)(updated_model)

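Note for reviewers: the parameterized cases above encode the pads arithmetic the fusion is expected to perform. A minimal sketch of that arithmetic, assuming the standard ONNX `Pad`/`Conv` pad layouts (`fold_pads` is a hypothetical illustration, not onnxscript code):

```python
# Hypothetical helper illustrating the fusion arithmetic, not onnxscript code.
# ONNX Pad stores pads as [x1_begin, ..., xn_begin, x1_end, ..., xn_end] over
# `axes` (all axes when axes is None); Conv stores pads over spatial axes only.
def fold_pads(pad_pads, axes, conv_pads, rank=4):
    axes = list(range(rank)) if axes is None else [a % rank for a in axes]
    n = len(axes)
    begins, ends = [0] * rank, [0] * rank
    for i, axis in enumerate(axes):
        begins[axis], ends[axis] = pad_pads[i], pad_pads[i + n]
    # The rule can only fire when batch/channel axes are unpadded (and the Pad
    # is constant-mode with a zero constant), so padding commutes with Conv.
    assert begins[:2] == ends[:2] == [0, 0]
    conv_pads = conv_pads or [0] * (2 * (rank - 2))
    return [p + q for p, q in zip(conv_pads, begins[2:] + ends[2:])]

# Second parameterized case above: Pad over axes (1, -2, -1) plus Conv pads
# [2, 0, 2, 0] should fuse into a single Conv with pads [4, 2, 4, 2].
assert fold_pads([0, 2, 2, 0, 2, 2], [1, -2, -1], [2, 0, 2, 0]) == [4, 2, 4, 2]
```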
@@ -223,16 +226,19 @@ def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = None

     @parameterized.parameterized.expand(
         [
-            (pad_pads, const_value, axes, conv_pads)
-            for pad_pads, axes, conv_pads in [
-                ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1]),
-                ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None),
-                ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1]),
+            (pad_pads, const_value, axes, conv_pads, conv_auto_pad)
+            for pad_pads, axes, conv_pads, conv_auto_pad in [
+                ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1], None),
+                ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None, None),
+                ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1], None),
+                ([3, 3], ir.tensor([2], name="axes"), None, "SAME_UPPER"),
             ]
             for const_value in [None, ir.tensor(np.array([0], "uint8"), name="const_value")]
         ]
     )
-    def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads):
+    def test_fuse_pad_into_conv_integer(
+        self, pad_pads, const_value, axes, conv_pads, conv_auto_pad
+    ):
         pad_inputs = [ir.tensor(pad_pads, name="pads")]
         if const_value is not None or axes is not None:
             pad_inputs.append(const_value)
@@ -243,15 +249,15 @@ def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads
             input_shape=ir.Shape(("N", 24, 19, 23)),
             weight_shape=(8, 24, 3, 3),
             pad_inputs=pad_inputs,
-            conv_attributes={"pads": conv_pads},
+            conv_attributes={"pads": conv_pads, "auto_pad": conv_auto_pad},
         )
         updated_model = _clone_model(base_model)

         # Apply rule
         count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)

         # Check that Pad was fused
-        self.assertEqual(count, 1)
+        self.assertEqual(count, 1 if conv_auto_pad is None else 2)
         self.assertEqual(updated_model.graph.num_nodes(), 1)
         onnx_checker.CheckerPass(True)(updated_model)

@@ -260,5 +266,67 @@ def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads
         testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)


+class NormalizePadFormatTest(FusePadConvBaseTest):
+    @parameterized.parameterized.expand(
+        [
+            (strides, kernel_shape, auto_pad)
+            for strides, kernel_shape in [((2, 3), (1, 4)), ((2, 1), (5, 2))]
+            for auto_pad in ["SAME_UPPER", "SAME_LOWER", "VALID"]
+        ]
+    )
+    def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
+        pad_inputs = [
+            ir.tensor([1, 1, 1, 1], name="pads"),
+            None,
+            ir.tensor([2, 3], name="axes"),
+        ]
+        base_model = self.build_model(
+            op_type="Conv",
+            input_shape=ir.Shape(("N", 32, 22, 27)),
+            weight_shape=(32, 32, *kernel_shape),
+            pad_inputs=pad_inputs,
+            conv_attributes={
+                "strides": strides,
+                "auto_pad": auto_pad,
+                "kernel_shape": kernel_shape,
+            },
+        )
+        updated_model = _clone_model(base_model)
+
+        # Apply rule
+        count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)
+
+        # Check that Pad was fused
+        self.assertEqual(count, 2)
+        self.assertEqual(updated_model.graph.num_nodes(), 1)
+        onnx_checker.CheckerPass(True)(updated_model)
+
+        # Check inference
+        inputs = self.rng.random((1, 32, 22, 27), dtype="float32")
+        testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
+
+    def test_unsupported_normalize_pad_format(self):
+        base_model = self.build_model(
+            op_type="Conv",
+            input_shape=ir.Shape(("N", 32, 14)),
+            weight_shape=(32, 11, 4),
+            pad_inputs=[ir.tensor([0, 0, 0, 0, 0, 0], name="pads")],
+            conv_attributes={"auto_pad": "VALID"},
+        )
+        # Drop convolutional input shape
+        base_model.graph[0].outputs[0].shape = None
+        onnx_checker.CheckerPass(True)(base_model)
+
+        # Apply rule and check it was not applied
+        tracer = orp.MatchingTracer()
+        count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, "Input shapes are not defined")
+
+
 if __name__ == "__main__":
     unittest.main()
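The `NormalizePadFormatTest` cases above assume `auto_pad` gets rewritten into explicit `pads` before fusion, which is why `count` is 2 (one normalization rewrite plus one fusion) whenever an `auto_pad` string is set. A sketch of the standard ONNX `auto_pad` semantics that such a normalization has to reproduce; `explicit_pads` is illustrative only, not the rule's actual implementation:

```python
import math

# Illustrative only: compute explicit ONNX pads from auto_pad, per the ONNX
# Conv spec (output = ceil(input / stride) for SAME_*, no padding for VALID).
def explicit_pads(auto_pad, in_sizes, kernel_shape, strides, dilations=None):
    dilations = dilations or [1] * len(kernel_shape)
    if auto_pad == "VALID":
        return [0] * (2 * len(kernel_shape))
    begins, ends = [], []
    for x, k, s, d in zip(in_sizes, kernel_shape, strides, dilations):
        eff_k = (k - 1) * d + 1  # kernel extent after dilation
        out = math.ceil(x / s)
        total = max((out - 1) * s + eff_k - x, 0)
        # SAME_UPPER puts the odd pixel at the end, SAME_LOWER at the start.
        small, big = total // 2, total - total // 2
        begin, end = (small, big) if auto_pad == "SAME_UPPER" else (big, small)
        begins.append(begin)
        ends.append(end)
    return begins + ends  # ONNX layout: [x1_begin, ..., xn_begin, x1_end, ...]

# E.g. the test's 22x27 input with kernel (1, 4) and strides (2, 3):
print(explicit_pads("SAME_UPPER", [22, 27], (1, 4), (2, 3)))  # [0, 0, 0, 1]
```

Once the pads are explicit, the fusion proceeds exactly as in the float and integer tests, leaving a single Conv node.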