Skip to content

Commit 0bdd955

Browse files
committed
Add support for 3D and 2D grouped conolutions
1 parent 32fb07d commit 0bdd955

File tree

2 files changed

+77
-49
lines changed

2 files changed

+77
-49
lines changed

convbench/conv_utils.py

+39-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
CONV_Q = r"""%c0_i32 = arith.constant 0 : i32
1414
%11 = linalg.conv_2d_{CONV_TYPE}_q {{dilations = dense<1> : vector<2xi64>, strides = dense<{STRIDE}> : vector<2xi64>}} ins(%arg0, %arg1, %c0_i32, %c0_i32 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>, i32, i32) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""
1515

16+
CONV_3D = r"""%11 = linalg.conv_3d_{CONV_TYPE} {dilations = dense<1> : tensor<3xi64>, strides = dense<{STRIDE}> : tensor<3xi64>} ins (%arg0, %arg1: tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>>) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""
17+
1618
TEST = r"""util.func public @{FUNC_NAME}({FUNC_ARGS}) -> tensor<{OUT_TYPE}> {{{CONSTANT_INPUTS}
1719
%cst = arith.constant {ZERO} : {OUT_ELEM_TYPE}
1820
%9 = tensor.empty() : tensor<{OUT_TYPE}>
@@ -33,30 +35,36 @@ class ConvConfig:
3335
Q: int
3436
F: int
3537
S: int
38+
is_grouped_conv: bool
39+
G: int # group count
40+
is_3D_conv: bool
41+
D: int # input depth
42+
R: int # filter depth
43+
S_D: int # stride along depth
3644
OP: str
3745
input_dtype: str
3846
output_dtype: str
3947

4048
def get_name(self) -> str:
41-
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S)
49+
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S) + "_groupcount" + str(self.G)
4250

4351
def get_img_shape(self) -> str:
52+
in_h = self.H * self.S + self.P - 1
53+
in_w = self.W * self.S + self.Q - 1
4454
if "nhwc" in self.OP:
45-
in_h = self.H * self.S + self.P - 1
46-
in_w = self.W * self.S + self.Q - 1
4755
return str(self.N) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(self.C) + "x" + self.input_dtype
4856
if "nchw" in self.OP:
49-
in_h = self.H * self.S + self.P - 1
50-
in_w = self.W * self.S + self.Q - 1
5157
return str(self.N) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype
52-
58+
if "ngchw" in operation:
59+
return str(self.N) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype
5360

5461
def get_kernel_shape(self) -> str:
5562
if "nhwc" in self.OP:
5663
return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype
5764
if "nchw" in self.OP:
5865
return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype
59-
66+
if "ngchw" in operation:
67+
return str(self.F) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype
6068

6169
def get_byte_count(self) -> int:
6270
dtype_bits_map = {
@@ -73,15 +81,16 @@ def get_byte_count(self) -> int:
7381
in_h = self.H * self.S + self.P - 1
7482
in_w = self.W * self.S + self.Q - 1
7583
input_channels = self.C
84+
group_count = self.G
7685
output_channels = self.F
7786
output_width = self.W
7887
output_height = self.H
7988
k_width = self.Q
8089
k_height = self.P
8190
byte_count = (
82-
(batch * input_channels * in_w * in_h * bytes_per_input)
83-
+ (batch * output_channels * output_width * output_height * bytes_per_output)
84-
+ (k_width * k_height * input_channels * output_channels * bytes_per_input)
91+
(batch * group_count * input_channels * in_w * in_h * bytes_per_input)
92+
+ (batch * group_count * output_channels * output_width * output_height * bytes_per_output)
93+
+ (group_count * k_width * k_height * input_channels * output_channels * bytes_per_input)
8594
)
8695
return byte_count
8796

@@ -90,14 +99,15 @@ def get_flops(self) -> int:
9099
in_h = self.H * self.S + self.P - 1
91100
in_w = self.W * self.S + self.Q - 1
92101
input_channels = self.C
102+
group_count = self.G
93103
output_channels = self.F
94104
output_width = self.W
95105
output_height = self.H
96106
k_width = self.Q
97107
k_height = self.P
98108
operation_per_pixel = k_width * k_height * input_channels * 2
99109
output_pixels_per_batch = output_width * output_height * output_channels
100-
flops = operation_per_pixel * output_pixels_per_batch * batch
110+
flops = operation_per_pixel * output_pixels_per_batch * group_count * batch
101111
return flops
102112

103113
def generate_mlir(config: ConvConfig):
@@ -109,11 +119,16 @@ def generate_mlir(config: ConvConfig):
109119
q = config.Q
110120
f = config.F
111121
stride = config.S
122+
g = config.G
123+
d = config.D
124+
r = config.R
125+
s_d = config.S_D
112126
operation = config.OP
113127
dtypes = f"{config.input_dtype}x{config.input_dtype}x{config.output_dtype}"
114128
elem_types = dtypes.split("x")
115129
in_h = str(int(h) * int(stride) + int(p) - 1)
116130
in_w = str(int(w) * int(stride) + int(q) - 1)
131+
in_d = str(int(d) * int(s_d) + int(r) - 1)
117132
if "nhwc" in operation:
118133
conv_type = "nhwc_hwcf"
119134
lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0])
@@ -124,6 +139,17 @@ def generate_mlir(config: ConvConfig):
124139
lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
125140
rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
126141
out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
142+
if "ngchw" in operation:
143+
conv_type = "ngchw_fgchw"
144+
lhs = str(n) + "x" + str(g) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
145+
rhs = str(f) + "x" + str(g) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
146+
out = str(n) + "x" + str(g) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
147+
if "ncdhw" in operation:
148+
conv_type = "ncdhw_fcdhw"
149+
lhs = str(n) + "x" + str(c) + "x" + str(in_d) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
150+
rhs = str(f) + "x" + str(c) + "x" + str(r) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
151+
out = str(n) + "x" + str(f) + "x" + str(d) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
152+
127153
one = "1"
128154
zero = "0"
129155
if (elem_types[0][0] == "f"):
@@ -132,6 +158,8 @@ def generate_mlir(config: ConvConfig):
132158
conv_template = CONV
133159
if "q" in operation:
134160
conv_template = CONV_Q
161+
if config.is_3D_conv:
162+
conv_template = CONV_3D
135163
operation = conv_template.format(
136164
INPUT_TYPE=lhs,
137165
FILTER_TYPE=rhs,

convbench/problems.py

+38-38
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,49 @@
44
def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
55
configs = []
66
for B in [1, 2, 4, 8]:
7-
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
8-
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
9-
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
10-
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
11-
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
12-
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
13-
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
14-
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
15-
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
16-
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
17-
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
18-
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
19-
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
20-
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
21-
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
22-
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
23-
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
24-
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
25-
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
26-
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
27-
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
28-
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
29-
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
30-
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
31-
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
32-
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
33-
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
7+
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
8+
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
9+
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
10+
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
11+
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
12+
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
13+
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
14+
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
15+
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
16+
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
17+
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
18+
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
19+
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
20+
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
21+
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
22+
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
23+
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
24+
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
25+
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
26+
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
27+
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
28+
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
29+
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
30+
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
31+
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
32+
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
33+
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
3434
return configs
3535

3636
def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
3737
configs = []
3838
for B in [1, 2, 4, 8]:
39-
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
40-
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
41-
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
42-
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, op, input_dtype, output_dtype))
43-
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, op, input_dtype, output_dtype))
44-
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, op, input_dtype, output_dtype))
45-
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, op, input_dtype, output_dtype))
46-
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, op, input_dtype, output_dtype))
47-
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, op, input_dtype, output_dtype))
48-
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, op, input_dtype, output_dtype))
49-
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, op, input_dtype, output_dtype))
39+
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
40+
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
41+
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
42+
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
43+
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
44+
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
45+
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
46+
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
47+
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
48+
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
49+
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
5050
return configs
5151

5252
def get_conv_configs() -> list[tuple[str, ConvConfig]]:

0 commit comments

Comments
 (0)