Skip to content

Commit 7b5fd7c

Browse files
author
Prashant Kumar
committed
Black format the files.
Black format with default option, i.e., black file_name.py.
1 parent 0aa226d commit 7b5fd7c

File tree

2 files changed

+139
-19
lines changed

2 files changed

+139
-19
lines changed

convbench/conv_utils.py

+130-17
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,71 @@ class ConvConfig:
3838
output_dtype: str
3939

4040
def get_name(self) -> str:
41-
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S)
42-
41+
return (
42+
self.OP
43+
+ "_"
44+
+ f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}"
45+
+ "_"
46+
+ f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}"
47+
+ "_stride"
48+
+ str(self.S)
49+
)
50+
4351
def get_img_shape(self) -> str:
4452
if "nhwc" in self.OP:
4553
in_h = self.H * self.S + self.P - 1
4654
in_w = self.W * self.S + self.Q - 1
47-
return str(self.N) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(self.C) + "x" + self.input_dtype
55+
return (
56+
str(self.N)
57+
+ "x"
58+
+ str(in_h)
59+
+ "x"
60+
+ str(in_w)
61+
+ "x"
62+
+ str(self.C)
63+
+ "x"
64+
+ self.input_dtype
65+
)
4866
if "nchw" in self.OP:
4967
in_h = self.H * self.S + self.P - 1
5068
in_w = self.W * self.S + self.Q - 1
51-
return str(self.N) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype
52-
53-
69+
return (
70+
str(self.N)
71+
+ "x"
72+
+ str(self.C)
73+
+ "x"
74+
+ str(in_h)
75+
+ "x"
76+
+ str(in_w)
77+
+ "x"
78+
+ self.input_dtype
79+
)
80+
5481
def get_kernel_shape(self) -> str:
5582
if "nhwc" in self.OP:
56-
return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype
83+
return (
84+
str(self.P)
85+
+ "x"
86+
+ str(self.Q)
87+
+ "x"
88+
+ str(self.C)
89+
+ "x"
90+
+ str(self.F)
91+
+ "x"
92+
+ self.input_dtype
93+
)
5794
if "nchw" in self.OP:
58-
return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype
59-
95+
return (
96+
str(self.F)
97+
+ "x"
98+
+ str(self.C)
99+
+ "x"
100+
+ str(self.P)
101+
+ "x"
102+
+ str(self.Q)
103+
+ "x"
104+
+ self.input_dtype
105+
)
60106

61107
def get_byte_count(self) -> int:
62108
dtype_bits_map = {
@@ -80,7 +126,13 @@ def get_byte_count(self) -> int:
80126
k_height = self.P
81127
byte_count = (
82128
(batch * input_channels * in_w * in_h * bytes_per_input)
83-
+ (batch * output_channels * output_width * output_height * bytes_per_output)
129+
+ (
130+
batch
131+
* output_channels
132+
* output_width
133+
* output_height
134+
* bytes_per_output
135+
)
84136
+ (k_width * k_height * input_channels * output_channels * bytes_per_input)
85137
)
86138
return byte_count
@@ -100,6 +152,7 @@ def get_flops(self) -> int:
100152
flops = operation_per_pixel * output_pixels_per_batch * batch
101153
return flops
102154

155+
103156
def generate_mlir(config: ConvConfig):
104157
n = config.N
105158
h = config.H
@@ -116,17 +169,77 @@ def generate_mlir(config: ConvConfig):
116169
in_w = str(int(w) * int(stride) + int(q) - 1)
117170
if "nhwc" in operation:
118171
conv_type = "nhwc_hwcf"
119-
lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0])
120-
rhs = str(p) + "x" + str(q) + "x" + str(c) + "x" + str(f) + "x" + str(elem_types[1])
121-
out = str(n) + "x" + str(h) + "x" + str(w) + "x" + str(f) + "x" + str(elem_types[2])
172+
lhs = (
173+
str(n)
174+
+ "x"
175+
+ str(in_h)
176+
+ "x"
177+
+ str(in_w)
178+
+ "x"
179+
+ str(c)
180+
+ "x"
181+
+ str(elem_types[0])
182+
)
183+
rhs = (
184+
str(p)
185+
+ "x"
186+
+ str(q)
187+
+ "x"
188+
+ str(c)
189+
+ "x"
190+
+ str(f)
191+
+ "x"
192+
+ str(elem_types[1])
193+
)
194+
out = (
195+
str(n)
196+
+ "x"
197+
+ str(h)
198+
+ "x"
199+
+ str(w)
200+
+ "x"
201+
+ str(f)
202+
+ "x"
203+
+ str(elem_types[2])
204+
)
122205
if "nchw" in operation:
123206
conv_type = "nchw_fchw"
124-
lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
125-
rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
126-
out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
207+
lhs = (
208+
str(n)
209+
+ "x"
210+
+ str(c)
211+
+ "x"
212+
+ str(in_h)
213+
+ "x"
214+
+ str(in_w)
215+
+ "x"
216+
+ str(elem_types[0])
217+
)
218+
rhs = (
219+
str(f)
220+
+ "x"
221+
+ str(c)
222+
+ "x"
223+
+ str(p)
224+
+ "x"
225+
+ str(q)
226+
+ "x"
227+
+ str(elem_types[1])
228+
)
229+
out = (
230+
str(n)
231+
+ "x"
232+
+ str(f)
233+
+ "x"
234+
+ str(h)
235+
+ "x"
236+
+ str(w)
237+
+ "x"
238+
+ str(elem_types[2])
239+
)
127240
one = "1"
128241
zero = "0"
129-
if (elem_types[0][0] == "f"):
242+
if elem_types[0][0] == "f":
130243
one = "1.0"
131244
zero = "0.0"
132245
conv_template = CONV

convbench/shark_conv.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,21 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
2626
type=str.upper,
2727
help="Set the logging level",
2828
)
29-
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
29+
parser.add_argument(
30+
"--device",
31+
help="The IREE device to execute benchmarks on",
32+
type=str,
33+
default="hip",
34+
)
3035
parser.add_argument(
3136
"--roofline",
3237
help="Comma seperated csv file list to generate roofline plot with",
3338
default=None,
3439
)
3540
parser.add_argument("--plot", help="location to save plot", default=None)
36-
parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
41+
parser.add_argument(
42+
"--batch", help="roofline on certain batch", type=int, default=None
43+
)
3744
parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
3845
parser.add_argument("--model", help="roofline on certain model", default=None)
3946

0 commit comments

Comments
 (0)