@@ -38,25 +38,71 @@ class ConvConfig:
38
38
output_dtype : str
39
39
40
40
def get_name(self) -> str:
    """Return the identifier string for this convolution configuration.

    The name has the form
    ``<OP>_<N>x<H>x<W>x<C>x<P>x<Q>x<F>_<in>x<in>x<out>_stride<S>``,
    where ``<in>`` is ``input_dtype`` and ``<out>`` is ``output_dtype``.

    Returns:
        The assembled name.
    """
    dims = f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}"
    # input_dtype is emitted twice on purpose (presumably once per input
    # operand -- confirm against callers that parse this name).
    dtypes = f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}"
    return f"{self.OP}_{dims}_{dtypes}_stride{self.S}"
50
+
43
51
def get_img_shape(self) -> str:
    """Return the input-image shape string, dimensions joined by ``x``.

    The input extents are derived from the configuration as
    ``in_h = H * S + P - 1`` and ``in_w = W * S + Q - 1``.

    Layout is chosen by substring match on ``self.OP``:
      * contains ``"nhwc"`` -> ``N x in_h x in_w x C x input_dtype``
      * contains ``"nchw"`` -> ``N x C x in_h x in_w x input_dtype``
    ``"nhwc"`` is checked first, as in the original code.

    Returns:
        The shape string with ``input_dtype`` as the last component.

    Raises:
        ValueError: if ``self.OP`` names neither layout. The previous
            code fell off the end and returned ``None`` here, which
            contradicts the ``-> str`` annotation.
    """
    # Both layouts use the same spatial extents, so compute them once.
    in_h = self.H * self.S + self.P - 1
    in_w = self.W * self.S + self.Q - 1

    if "nhwc" in self.OP:
        return f"{self.N}x{in_h}x{in_w}x{self.C}x{self.input_dtype}"
    if "nchw" in self.OP:
        return f"{self.N}x{self.C}x{in_h}x{in_w}x{self.input_dtype}"
    raise ValueError(
        f"unsupported convolution layout in OP {self.OP!r}: "
        "expected 'nhwc' or 'nchw'"
    )
80
+
54
81
def get_kernel_shape(self) -> str:
    """Return the filter (kernel) shape string, dimensions joined by ``x``.

    Layout is chosen by substring match on ``self.OP``:
      * contains ``"nhwc"`` -> ``P x Q x C x F x input_dtype``
      * contains ``"nchw"`` -> ``F x C x P x Q x input_dtype``
    ``"nhwc"`` is checked first, as in the original code.

    Returns:
        The shape string with ``input_dtype`` as the last component.

    Raises:
        ValueError: if ``self.OP`` names neither layout. The previous
            code fell off the end and returned ``None`` here, which
            contradicts the ``-> str`` annotation.
    """
    if "nhwc" in self.OP:
        return f"{self.P}x{self.Q}x{self.C}x{self.F}x{self.input_dtype}"
    if "nchw" in self.OP:
        return f"{self.F}x{self.C}x{self.P}x{self.Q}x{self.input_dtype}"
    raise ValueError(
        f"unsupported convolution layout in OP {self.OP!r}: "
        "expected 'nhwc' or 'nchw'"
    )
60
106
61
107
def get_byte_count (self ) -> int :
62
108
dtype_bits_map = {
@@ -80,7 +126,13 @@ def get_byte_count(self) -> int:
80
126
k_height = self .P
81
127
byte_count = (
82
128
(batch * input_channels * in_w * in_h * bytes_per_input )
83
- + (batch * output_channels * output_width * output_height * bytes_per_output )
129
+ + (
130
+ batch
131
+ * output_channels
132
+ * output_width
133
+ * output_height
134
+ * bytes_per_output
135
+ )
84
136
+ (k_width * k_height * input_channels * output_channels * bytes_per_input )
85
137
)
86
138
return byte_count
@@ -100,6 +152,7 @@ def get_flops(self) -> int:
100
152
flops = operation_per_pixel * output_pixels_per_batch * batch
101
153
return flops
102
154
155
+
103
156
def generate_mlir (config : ConvConfig ):
104
157
n = config .N
105
158
h = config .H
@@ -116,17 +169,77 @@ def generate_mlir(config: ConvConfig):
116
169
in_w = str (int (w ) * int (stride ) + int (q ) - 1 )
117
170
if "nhwc" in operation :
118
171
conv_type = "nhwc_hwcf"
119
- lhs = str (n ) + "x" + str (in_h ) + "x" + str (in_w ) + "x" + str (c ) + "x" + str (elem_types [0 ])
120
- rhs = str (p ) + "x" + str (q ) + "x" + str (c ) + "x" + str (f ) + "x" + str (elem_types [1 ])
121
- out = str (n ) + "x" + str (h ) + "x" + str (w ) + "x" + str (f ) + "x" + str (elem_types [2 ])
172
+ lhs = (
173
+ str (n )
174
+ + "x"
175
+ + str (in_h )
176
+ + "x"
177
+ + str (in_w )
178
+ + "x"
179
+ + str (c )
180
+ + "x"
181
+ + str (elem_types [0 ])
182
+ )
183
+ rhs = (
184
+ str (p )
185
+ + "x"
186
+ + str (q )
187
+ + "x"
188
+ + str (c )
189
+ + "x"
190
+ + str (f )
191
+ + "x"
192
+ + str (elem_types [1 ])
193
+ )
194
+ out = (
195
+ str (n )
196
+ + "x"
197
+ + str (h )
198
+ + "x"
199
+ + str (w )
200
+ + "x"
201
+ + str (f )
202
+ + "x"
203
+ + str (elem_types [2 ])
204
+ )
122
205
if "nchw" in operation :
123
206
conv_type = "nchw_fchw"
124
- lhs = str (n ) + "x" + str (c ) + "x" + str (in_h ) + "x" + str (in_w ) + "x" + str (elem_types [0 ])
125
- rhs = str (f ) + "x" + str (c ) + "x" + str (p ) + "x" + str (q ) + "x" + str (elem_types [1 ])
126
- out = str (n ) + "x" + str (f ) + "x" + str (h ) + "x" + str (w ) + "x" + str (elem_types [2 ])
207
+ lhs = (
208
+ str (n )
209
+ + "x"
210
+ + str (c )
211
+ + "x"
212
+ + str (in_h )
213
+ + "x"
214
+ + str (in_w )
215
+ + "x"
216
+ + str (elem_types [0 ])
217
+ )
218
+ rhs = (
219
+ str (f )
220
+ + "x"
221
+ + str (c )
222
+ + "x"
223
+ + str (p )
224
+ + "x"
225
+ + str (q )
226
+ + "x"
227
+ + str (elem_types [1 ])
228
+ )
229
+ out = (
230
+ str (n )
231
+ + "x"
232
+ + str (f )
233
+ + "x"
234
+ + str (h )
235
+ + "x"
236
+ + str (w )
237
+ + "x"
238
+ + str (elem_types [2 ])
239
+ )
127
240
one = "1"
128
241
zero = "0"
129
- if ( elem_types [0 ][0 ] == "f" ) :
242
+ if elem_types [0 ][0 ] == "f" :
130
243
one = "1.0"
131
244
zero = "0.0"
132
245
conv_template = CONV
0 commit comments