1
1
# Copyright (c) Microsoft Corporation.
2
2
# Licensed under the MIT License.
3
3
"""Fuses Pad nodes into preceding nodes. Supported fusion patterns:
4
- - Pad ∘ Conv -> Conv
5
- - Pad ∘ ConvInteger -> ConvInteger
4
+ - Conv ∘ Pad -> Conv
5
+ - ConvInteger ∘ Pad -> ConvInteger
6
6
7
7
To make some rules possible, we implicitly transform `auto_pad` attribute into its explicit list.
8
8
"""
9
9
10
- import typing
10
+ from __future__ import annotations
11
+
12
+ from typing import List , Sequence
11
13
12
14
import numpy as np
13
15
import onnx_ir as ir
14
16
15
17
from onnxscript .rewriter import pattern as orp
16
18
17
19
18
- def fill_pads_with_axes (
19
- pads : typing .Sequence [int ], axes : typing .Sequence [int ], rank : int
20
- ) -> typing .List [int ]:
20
+ def fill_pads_with_axes (pads : Sequence [int ], axes : Sequence [int ], rank : int ) -> List [int ]:
21
21
new_pads = [0 ] * 2 * rank
22
22
N = len (axes )
23
23
for start_idx , axis in enumerate (axes ):
@@ -26,43 +26,39 @@ def fill_pads_with_axes(
26
26
return new_pads
27
27
28
28
29
- def read_conv_attributes (ir_conv : ir .Node ) -> dict [str , typing . Sequence [int ] | str ]:
29
+ def read_conv_attributes (ir_conv : ir .Node ) -> dict [str , Sequence [int ] | str ]:
30
30
# Read attributes
31
31
attributes = {}
32
- if (kernel_shape := ir_conv .attributes .get ("kernel_shape" , None )) is not None :
33
- attributes ["kernel_shape" ] = kernel_shape .as_ints ()
34
- else :
35
- attributes ["kernel_shape" ] = ir_conv .inputs [1 ].shape [2 :]
36
- if (strides := ir_conv .attributes .get ("strides" , None )) is not None :
37
- attributes ["strides" ] = strides .as_ints ()
38
- else :
39
- attributes ["strides" ] = [1 ] * len (ir_conv .inputs [0 ].shape [2 :])
40
- if (auto_pad := ir_conv .attributes .get ("auto_pad" , None )) is not None :
41
- attributes ["auto_pad" ] = auto_pad .as_string ()
42
- else :
43
- attributes ["auto_pad" ] = "NOTSET"
44
- if (pads := ir_conv .attributes .get ("pads" , None )) is not None :
45
- attributes ["pads" ] = pads .as_ints ()
32
+ ir_attributes = ir_conv .attributes
33
+ attributes ["kernel_shape" ] = ir_attributes .get_ints (
34
+ "kernel_shape" , ir_conv .inputs [1 ].shape [2 :]
35
+ )
36
+ attributes ["strides" ] = ir_attributes .get_ints (
37
+ "strides" , [1 ] * len (ir_conv .inputs [0 ].shape [2 :])
38
+ )
39
+ attributes ["auto_pad" ] = ir_attributes .get_string ("auto_pad" , "NOTSET" )
40
+ if "pads" in ir_attributes :
41
+ attributes ["pads" ] = ir_attributes .get_ints ("pads" )
46
42
return attributes
47
43
48
44
49
45
class _FusePadConvBase (orp .RewriteRuleClassBase ):
50
46
"""Interface for PadConv nodes fusion."""
51
47
52
- def __init__ (self , name : str , as_function : bool = False ):
48
+ def __init__ (self , as_function : bool = False ):
53
49
# Remove nodes is set to False to remove unused nodes after the rewrite.
54
- super ().__init__ (name = name , remove_nodes = False , as_function = as_function )
50
+ super ().__init__ (remove_nodes = False , as_function = as_function )
55
51
56
52
def rewrite (
57
53
self , op : ir .tape .Tape , x : ir .Value , pad : ir .Value , conv : ir .Value
58
54
) -> ir .Value :
59
- pnode = pad .producer ()
60
- cnode = conv .producer ()
55
+ pad_node = pad .producer ()
56
+ conv_node = conv .producer ()
61
57
62
58
# Retrieve the padding and axes
63
59
x_rank = len (x .shape )
64
- pad_pads = pnode .inputs [1 ].const_value .numpy ().tolist ()
65
- if len (pnode .inputs ) > 3 and (axes := pnode .inputs [3 ]) is not None :
60
+ pad_pads = pad_node .inputs [1 ].const_value .numpy ().tolist ()
61
+ if len (pad_node .inputs ) > 3 and (axes := pad_node .inputs [3 ]) is not None :
66
62
axes = [x if x >= 0 else x_rank + x for x in axes .const_value .numpy ()]
67
63
else :
68
64
axes = list (range (x_rank ))
@@ -74,41 +70,40 @@ def rewrite(
74
70
new_pads = pad_pads [2 :x_rank ] + pad_pads [x_rank + 2 :]
75
71
76
72
# Replace conv pads = new + old
77
- conv_attr : typing . Mapping [ str , ir . Attr ] = cnode .attributes .copy ()
73
+ conv_attr = conv_node .attributes .copy ()
78
74
if "pads" in conv_attr :
79
75
new_pads = [x + y for x , y in zip (conv_attr ["pads" ].as_ints (), new_pads )]
80
- conv_attr [ "pads" ] = ir .convenience . convert_attribute ("pads" , new_pads )
76
+ conv_attr . add ( ir .AttrInt64s ("pads" , new_pads ) )
81
77
82
78
return op .op (
83
- cnode .op_type ,
84
- inputs = (x , * cnode .inputs [1 :]),
79
+ conv_node .op_type ,
80
+ inputs = (x , * conv_node .inputs [1 :]),
85
81
attributes = conv_attr ,
86
- domain = cnode .domain ,
87
- name = cnode .name ,
82
+ domain = conv_node .domain ,
83
+ name = conv_node .name ,
88
84
)
89
85
90
86
def check (self , context , x : ir .Value , pad : ir .Value , conv : ir .Value ) -> orp .MatchResult :
91
87
del context # Unused
92
88
check_result = orp .MatchResult ()
93
- pnode = pad .producer ()
89
+ pad_node = pad .producer ()
94
90
x_rank = len (x .shape )
95
91
96
92
# Pad constraints: attributes
97
- if (mode := pnode .attributes .get ("mode" , None )) and mode .as_string () != "constant" :
98
- return check_result .fail (f"{ pnode .name } mode must be 'constant'." )
93
+ if (mode := pad_node .attributes .get ("mode" , None )) and mode .as_string () != "constant" :
94
+ return check_result .fail (f"{ pad_node .name } mode must be 'constant'." )
99
95
100
96
# Pad constraints: inputs
101
- if (pads := pnode .inputs [1 ]).const_value is None :
97
+ if (pads := pad_node .inputs [1 ]).const_value is None :
102
98
return check_result .fail (f"{ pads .name } is not a constant/initializer." )
103
- if len (pnode .inputs ) > 2 and (constant_value := pnode .inputs [2 ]) is not None :
99
+ if len (pad_node .inputs ) > 2 and (constant_value := pad_node .inputs [2 ]) is not None :
104
100
if constant_value .const_value is None :
105
101
return check_result .fail (
106
102
f"{ constant_value .name } is not a constant/initializer."
107
103
)
108
104
elif constant_value .const_value .numpy ().item () != 0 :
109
105
return check_result .fail (f"{ constant_value .name } must be equal to 0." )
110
- axes = list (range (x_rank ))
111
- if len (pnode .inputs ) > 3 and (axes := pnode .inputs [3 ]) is not None :
106
+ if len (pad_node .inputs ) > 3 and (axes := pad_node .inputs [3 ]) is not None :
112
107
if axes .const_value is None :
113
108
return check_result .fail (f"{ axes .name } is not a constant/initializer." )
114
109
axes_list = [x if x >= 0 else x_rank + x for x in axes .const_value .numpy ()]
@@ -126,9 +121,6 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
126
121
class FusePadConv (_FusePadConvBase ):
127
122
"""Replaces ``Pad(Conv(x))`` with ``Conv(x)``."""
128
123
129
- def __init__ (self , as_function : bool = False ):
130
- super ().__init__ (name = "FusePadConv" , as_function = as_function )
131
-
132
124
def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
133
125
return op .Conv (
134
126
op .Pad (x , _allow_other_inputs = True , _outputs = ["pad" ]),
@@ -142,18 +134,17 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
142
134
return check_result
143
135
144
136
# Conv constraints: attributes
145
- cnode = conv .producer ()
146
- if (apad := cnode .attributes .get ("auto_pad" , None )) and apad .as_string () != "NOTSET" :
147
- return check_result .fail (f"{ cnode .name } auto_pad must be 'NOTSET'." )
137
+ conv_node = conv .producer ()
138
+ if (
139
+ apad := conv_node .attributes .get ("auto_pad" , None )
140
+ ) and apad .as_string () != "NOTSET" :
141
+ return check_result .fail (f"{ conv_node .name } auto_pad must be 'NOTSET'." )
148
142
return check_result
149
143
150
144
151
145
class FusePadConvInteger (FusePadConv ):
152
146
"""Replaces ``Pad(ConvInteger(x))`` with ``ConvInteger(x)``."""
153
147
154
- def __init__ (self , as_function : bool = False ):
155
- super (FusePadConv , self ).__init__ (name = "FusePadConvInteger" , as_function = as_function )
156
-
157
148
def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
158
149
return op .ConvInteger (
159
150
op .Pad (x , _allow_other_inputs = True , _outputs = ["pad" ]),
@@ -165,66 +156,68 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
165
156
class _NormalizePadFormatBase (orp .RewriteRuleClassBase ):
166
157
"""Interface to normalize pad attributes in conv nodes."""
167
158
159
+ @staticmethod
168
160
def compute_pads (
169
- self ,
170
- input_shape : typing .Sequence [int ],
171
- output_shape : typing .Sequence [int ],
172
- attributes : dict [str , typing .Sequence [int ] | str ],
173
- ) -> typing .Sequence [int ]:
161
+ input_shape : Sequence [int ],
162
+ output_shape : Sequence [int ],
163
+ attributes : dict [str , Sequence [int ] | str ],
164
+ ) -> Sequence [int ]:
174
165
raise NotImplementedError ("Child have to implement this function" )
175
166
176
167
def rewrite (self , op : ir .tape .Tape , conv : ir .Value , ** __ ) -> ir .Value :
177
- cnode = conv .producer ()
168
+ conv_node = conv .producer ()
178
169
179
170
# Read spatial dimensions and attributes
180
- input_shape = cnode .inputs [0 ].shape [2 :]
181
- output_shape = cnode .outputs [0 ].shape [2 :]
182
- attributes = read_conv_attributes (cnode )
171
+ input_shape = conv_node .inputs [0 ].shape [2 :]
172
+ output_shape = conv_node .outputs [0 ].shape [2 :]
173
+ attributes = read_conv_attributes (conv_node )
183
174
184
175
# Convert auto_pad mode into an explicit list
185
176
pads = self .compute_pads (input_shape , output_shape , attributes )
186
177
187
178
# Replace auto_pad, forcing to the explicit list
188
- conv_attr : typing . Mapping [ str , ir . Attr ] = cnode .attributes .copy ()
189
- conv_attr [ "auto_pad" ] = ir .convenience . convert_attribute ("auto_pad" , "NOTSET" )
179
+ conv_attr = conv_node .attributes .copy ()
180
+ conv_attr . add ( ir .AttrString ("auto_pad" , "NOTSET" ) )
190
181
if any (x != 0 for x in pads ):
191
- conv_attr [ "pads" ] = ir .convenience . convert_attribute ("pads" , pads )
182
+ conv_attr . add ( ir .AttrInt64s ("pads" , pads ) )
192
183
193
184
return op .op (
194
- cnode .op_type ,
195
- inputs = cnode .inputs ,
185
+ conv_node .op_type ,
186
+ inputs = conv_node .inputs ,
196
187
attributes = conv_attr ,
197
- domain = cnode .domain ,
198
- name = cnode .name ,
188
+ domain = conv_node .domain ,
189
+ name = conv_node .name ,
199
190
)
200
191
201
192
def check (self , context , conv : ir .Value , ** __ ) -> orp .MatchResult :
202
193
del context
203
194
check_result = orp .MatchResult ()
204
195
205
196
# Conv constraints: attributes
206
- cnode = conv .producer ()
207
- auto_pad = cnode .attributes .get ("auto_pad" , None )
197
+ conv_node = conv .producer ()
198
+ auto_pad = conv_node .attributes .get ("auto_pad" , None )
208
199
if auto_pad is None or auto_pad .as_string () == "NOTSET" :
209
- return check_result .fail (f"{ cnode .name } auto_pad must be different to 'NOTSET'." )
200
+ return check_result .fail (
201
+ f"{ conv_node .name } auto_pad must be different to 'NOTSET'."
202
+ )
210
203
211
204
# Conv constraints: inputs/outputs
212
- if cnode .inputs [0 ].shape is None :
213
- return check_result .fail (f"Input shapes are not defined on { cnode .name } ." )
214
- if cnode .outputs [0 ].shape is None :
215
- return check_result .fail (f"Output shapes are not defined on { cnode .name } ." )
205
+ if conv_node .inputs [0 ].shape is None :
206
+ return check_result .fail (f"Input shapes are not defined on { conv_node .name } ." )
207
+ if conv_node .outputs [0 ].shape is None :
208
+ return check_result .fail (f"Output shapes are not defined on { conv_node .name } ." )
216
209
return check_result
217
210
218
211
219
212
class NormalizePadFormatConv (_NormalizePadFormatBase ):
220
213
"""Convert auto_pad attribute into 'NOTSET' in Conv nodes ."""
221
214
215
+ @staticmethod
222
216
def compute_pads (
223
- self ,
224
- input_shape : typing .Sequence [int ],
225
- output_shape : typing .Sequence [int ],
226
- attributes : dict [str , typing .Sequence [int ] | str ],
227
- ) -> typing .Sequence [int ]:
217
+ input_shape : Sequence [int ],
218
+ output_shape : Sequence [int ],
219
+ attributes : dict [str , Sequence [int ] | str ],
220
+ ) -> Sequence [int ]:
228
221
# Compute pads, following auto_pad/pads attributes
229
222
if attributes ["auto_pad" ] in ["NOTSET" , "VALID" ]:
230
223
return attributes .get ("pads" , [0 ] * len (input_shape ) * 2 )
0 commit comments