Skip to content

Commit a31f546

Browse files
committed
Add support for 3D and 2D grouped conolutions
1 parent 32fb07d commit a31f546

File tree

2 files changed

+66
-38
lines changed

2 files changed

+66
-38
lines changed

convbench/conv_utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
CONV_Q = r"""%c0_i32 = arith.constant 0 : i32
1414
%11 = linalg.conv_2d_{CONV_TYPE}_q {{dilations = dense<1> : vector<2xi64>, strides = dense<{STRIDE}> : vector<2xi64>}} ins(%arg0, %arg1, %c0_i32, %c0_i32 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>, i32, i32) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""
1515

16+
CONV_3D = r"""%11 = linalg.conv_3d_{CONV_TYPE} {dilations = dense<1> : tensor<3xi64>, strides = dense<{STRIDE}> : tensor<3xi64>} ins (%arg0, %arg1: tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>>) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""
17+
1618
TEST = r"""util.func public @{FUNC_NAME}({FUNC_ARGS}) -> tensor<{OUT_TYPE}> {{{CONSTANT_INPUTS}
1719
%cst = arith.constant {ZERO} : {OUT_ELEM_TYPE}
1820
%9 = tensor.empty() : tensor<{OUT_TYPE}>
@@ -33,6 +35,12 @@ class ConvConfig:
3335
Q: int
3436
F: int
3537
S: int
38+
G: int # group count
39+
D: int # input depth
40+
R: int # filter depth
41+
P_D: int # padding along depth
42+
S_D: int # stride along depth
43+
D_D: int # dilation along depth
3644
OP: str
3745
input_dtype: str
3846
output_dtype: str
@@ -109,11 +117,18 @@ def generate_mlir(config: ConvConfig):
109117
q = config.Q
110118
f = config.F
111119
stride = config.S
120+
g = config.G
121+
d = config.D
122+
r = config.R
123+
p_d = config.P_D
124+
s_d = config.S_D
125+
d_d = config.D_D
112126
operation = config.OP
113127
dtypes = f"{config.input_dtype}x{config.input_dtype}x{config.output_dtype}"
114128
elem_types = dtypes.split("x")
115129
in_h = str(int(h) * int(stride) + int(p) - 1)
116130
in_w = str(int(w) * int(stride) + int(q) - 1)
131+
in_d = str(int(d) * int(s_d) + int(r) - 1)
117132
if "nhwc" in operation:
118133
conv_type = "nhwc_hwcf"
119134
lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0])
@@ -124,6 +139,17 @@ def generate_mlir(config: ConvConfig):
124139
lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
125140
rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
126141
out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
142+
if "ngchw" in operation:
143+
conv_type = "ngchw_fgchw"
144+
lhs = str(n) + "x" + str(g) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
145+
rhs = str(f) + "x" + str(g) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
146+
out = str(n) + "x" + str(g) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
147+
if "ncdhw" in operation:
148+
conv_type = "ncdhw_fcdhw"
149+
lhs = str(n) + "x" + str(c) + "x" + str(in_d) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
150+
rhs = str(f) + "x" + str(c) + "x" + str(r) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
151+
out = str(n) + "x" + str(f) + "x" + str(d) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
152+
127153
one = "1"
128154
zero = "0"
129155
if (elem_types[0][0] == "f"):
@@ -132,6 +158,8 @@ def generate_mlir(config: ConvConfig):
132158
conv_template = CONV
133159
if "q" in operation:
134160
conv_template = CONV_Q
161+
if "ncdhw" in operation:
162+
conv_template = CONV_3D
135163
operation = conv_template.format(
136164
INPUT_TYPE=lhs,
137165
FILTER_TYPE=rhs,

convbench/problems.py

+38-38
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,49 @@
44
def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
55
configs = []
66
for B in [1, 2, 4, 8]:
7-
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
8-
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
9-
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
10-
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
11-
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
12-
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
13-
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
14-
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
15-
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
16-
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
17-
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
18-
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
19-
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
20-
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
21-
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
22-
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
23-
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
24-
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
25-
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
26-
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
27-
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
28-
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
29-
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
30-
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
31-
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
32-
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
33-
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
7+
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
8+
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
9+
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
10+
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
11+
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
12+
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
13+
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
14+
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
15+
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
16+
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
17+
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
18+
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
19+
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
20+
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
21+
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
22+
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
23+
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
24+
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
25+
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
26+
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
27+
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
28+
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
29+
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
30+
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
31+
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
32+
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
33+
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
3434
return configs
3535

3636
def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
3737
configs = []
3838
for B in [1, 2, 4, 8]:
39-
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
40-
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
41-
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
42-
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, op, input_dtype, output_dtype))
43-
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, op, input_dtype, output_dtype))
44-
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, op, input_dtype, output_dtype))
45-
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, op, input_dtype, output_dtype))
46-
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, op, input_dtype, output_dtype))
47-
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, op, input_dtype, output_dtype))
48-
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, op, input_dtype, output_dtype))
49-
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, op, input_dtype, output_dtype))
39+
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
40+
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
41+
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
42+
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
43+
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
44+
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
45+
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
46+
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
47+
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
48+
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
49+
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, 1, -1, -1, -1, -1, -1, op, input_dtype, output_dtype))
5050
return configs
5151

5252
def get_conv_configs() -> list[tuple[str, ConvConfig]]:

0 commit comments

Comments
 (0)