Add support for 3D and 2D grouped convolutions #33

Open: wants to merge 1 commit into base: main
50 changes: 39 additions & 11 deletions convbench/conv_utils.py
@@ -13,6 +13,8 @@
CONV_Q = r"""%c0_i32 = arith.constant 0 : i32
%11 = linalg.conv_2d_{CONV_TYPE}_q {{dilations = dense<1> : vector<2xi64>, strides = dense<{STRIDE}> : vector<2xi64>}} ins(%arg0, %arg1, %c0_i32, %c0_i32 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>, i32, i32) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""

CONV_3D = r"""%11 = linalg.conv_3d_{CONV_TYPE} {{dilations = dense<1> : tensor<3xi64>, strides = dense<{STRIDE}> : tensor<3xi64>}} ins(%arg0, %arg1 : tensor<{INPUT_TYPE}>, tensor<{FILTER_TYPE}>) outs(%10 : tensor<{OUTPUT_TYPE}>) -> tensor<{OUTPUT_TYPE}>"""

TEST = r"""util.func public @{FUNC_NAME}({FUNC_ARGS}) -> tensor<{OUT_TYPE}> {{{CONSTANT_INPUTS}
%cst = arith.constant {ZERO} : {OUT_ELEM_TYPE}
%9 = tensor.empty() : tensor<{OUT_TYPE}>
@@ -33,30 +35,36 @@ class ConvConfig:
Q: int
F: int
S: int
is_grouped_conv: bool
G: int # group count
is_3D_conv: bool
Contributor:
Possible to make another class for conv3d instead? That way, every time we try to add a problem, we don't need the False, -1, -1, -1 part.
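
A minimal sketch of that option, assuming ConvConfig stays 2D-only (the Conv3DConfig name and field layout here are illustrative, not from the PR):

from dataclasses import dataclass

# Hypothetical: a separate class for 3D problems, so 2D call sites
# never carry the extra False/-1 arguments.
@dataclass
class Conv3DConfig:
    N: int
    C: int
    D: int    # input depth
    H: int
    W: int
    F: int
    R: int    # filter depth
    P: int
    Q: int
    S: int    # stride along height/width
    S_D: int  # stride along depth
    G: int    # group count
    OP: str
    input_dtype: str
    output_dtype: str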

Contributor:
I'll suggest another option, which would be to give these parameters default values and put them at the end of the __init__ function. Then we don't need an extra class, and the conv2d configs don't need to add these fields.

It's probably a good idea to get rid of these classes in favor of using python bindings at some point anyway, so I see the config classes as a temporary implementation. Linalg ops are relatively stable, so this has not caused us any maintenance issues yet, but the bindings are much more resilient to changes to IR assembly format. We are making the same transition on the tuner side right now, because it is very dependent on codegen dialects (which are very much in flux at the moment).

Eventually we can get rid of these separate config classes and have a single config class used across all kernels (gemm, attention, conv, etc.) that has functions to build the desired kernel types and tracks things like peak flops and arithmetic intensity. Then we can move everything to a shared benchmarking implementation.
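
For concreteness, a sketch of that defaulted-fields approach, assuming ConvConfig is a dataclass (the defaults shown are illustrative):

from dataclasses import dataclass

@dataclass
class ConvConfig:
    N: int
    H: int
    W: int
    C: int
    P: int
    Q: int
    F: int
    S: int
    OP: str
    input_dtype: str
    output_dtype: str
    # New fields default to the ungrouped 2D case, so existing 2D
    # call sites such as ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op,
    # input_dtype, output_dtype) keep working unchanged.
    is_grouped_conv: bool = False
    G: int = 1       # group count
    is_3D_conv: bool = False
    D: int = -1      # input depth
    R: int = -1      # filter depth
    S_D: int = -1    # stride along depth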

D: int # input depth
R: int # filter depth
S_D: int # stride along depth
OP: str
input_dtype: str
output_dtype: str

def get_name(self) -> str:
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S)
return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S) + "_groupcount" + str(self.G)

def get_img_shape(self) -> str:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
if "nhwc" in self.OP:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
return str(self.N) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(self.C) + "x" + self.input_dtype
if "nchw" in self.OP:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
return str(self.N) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype

if "ngchw" in operation:
return str(self.N) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype

def get_kernel_shape(self) -> str:
if "nhwc" in self.OP:
return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype
if "nchw" in self.OP:
return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype

if "ngchw" in operation:
return str(self.F) + "x" + str(self.G) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype

def get_byte_count(self) -> int:
dtype_bits_map = {
@@ -73,15 +81,16 @@ def get_byte_count(self) -> int:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
input_channels = self.C
group_count = self.G
output_channels = self.F
output_width = self.W
output_height = self.H
k_width = self.Q
k_height = self.P
byte_count = (
(batch * input_channels * in_w * in_h * bytes_per_input)
+ (batch * output_channels * output_width * output_height * bytes_per_output)
+ (k_width * k_height * input_channels * output_channels * bytes_per_input)
(batch * group_count * input_channels * in_w * in_h * bytes_per_input)
+ (batch * group_count * output_channels * output_width * output_height * bytes_per_output)
+ (group_count * k_width * k_height * input_channels * output_channels * bytes_per_input)
)
return byte_count

@@ -90,14 +99,15 @@ def get_flops(self) -> int:
in_h = self.H * self.S + self.P - 1
in_w = self.W * self.S + self.Q - 1
input_channels = self.C
group_count = self.G
output_channels = self.F
output_width = self.W
output_height = self.H
k_width = self.Q
k_height = self.P
operation_per_pixel = k_width * k_height * input_channels * 2
output_pixels_per_batch = output_width * output_height * output_channels
flops = operation_per_pixel * output_pixels_per_batch * batch
flops = operation_per_pixel * output_pixels_per_batch * group_count * batch
return flops
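
As a quick sanity check on the grouped update above (a worked example with made-up sizes, not from the PR): with C and F read as per-group channels, the total is (P * Q * C * 2) * (H * W * F) * G * N.

# Illustrative numbers: N=1, G=8, per-group C=F=64, output H=W=28, P=Q=3.
ops_per_pixel = 3 * 3 * 64 * 2        # k_height * k_width * input_channels * 2 = 1152
pixels_per_batch = 28 * 28 * 64       # output_height * output_width * output_channels = 50176
print(ops_per_pixel * pixels_per_batch * 8 * 1)  # 462422016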

def generate_mlir(config: ConvConfig):
@@ -109,11 +119,16 @@ def generate_mlir(config: ConvConfig):
q = config.Q
f = config.F
stride = config.S
g = config.G
d = config.D
r = config.R
s_d = config.S_D
operation = config.OP
dtypes = f"{config.input_dtype}x{config.input_dtype}x{config.output_dtype}"
elem_types = dtypes.split("x")
in_h = str(int(h) * int(stride) + int(p) - 1)
in_w = str(int(w) * int(stride) + int(q) - 1)
in_d = str(int(d) * int(s_d) + int(r) - 1)
if "nhwc" in operation:
conv_type = "nhwc_hwcf"
lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0])
@@ -124,6 +139,17 @@
lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
if "ngchw" in operation:
conv_type = "ngchw_fgchw"
lhs = str(n) + "x" + str(g) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(g) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(g) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])
if "ncdhw" in operation:
conv_type = "ncdhw_fcdhw"
lhs = str(n) + "x" + str(c) + "x" + str(in_d) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0])
rhs = str(f) + "x" + str(c) + "x" + str(r) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1])
out = str(n) + "x" + str(f) + "x" + str(d) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2])

one = "1"
zero = "0"
if (elem_types[0][0] == "f"):
@@ -132,6 +158,8 @@
conv_template = CONV
if "q" in operation:
conv_template = CONV_Q
if config.is_3D_conv:
conv_template = CONV_3D
operation = conv_template.format(
INPUT_TYPE=lhs,
FILTER_TYPE=rhs,
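For reference, here is roughly what the fixed CONV_3D template expands to for an ncdhw problem (the shapes below are illustrative; the doubled {{ }} braces in the template are what let str.format pass the dilations/strides attributes through literally):

# Hypothetical expansion: N=1, C=16, output D=H=W=8, R=P=Q=3, F=32, stride 1,
# so input depth/height/width are 8*1 + 3 - 1 = 10.
print(CONV_3D.format(
    CONV_TYPE="ncdhw_fcdhw",
    STRIDE=1,
    INPUT_TYPE="1x16x10x10x10xf16",
    FILTER_TYPE="32x16x3x3x3xf16",
    OUTPUT_TYPE="1x32x8x8x8xf32",
))
# %11 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>,
#   strides = dense<1> : tensor<3xi64>} ins(%arg0, %arg1 : tensor<1x16x10x10x10xf16>,
#   tensor<32x16x3x3x3xf16>) outs(%10 : tensor<1x32x8x8x8xf32>) -> tensor<1x32x8x8x8xf32>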
76 changes: 38 additions & 38 deletions convbench/problems.py
@@ -4,49 +4,49 @@
def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
return configs

def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 512, 1, 1, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 1024, 1, 1, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 14, 14, 256, 3, 3, 256, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 2048, 1, 1, 1024, 2, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 7, 7, 512, 3, 3, 512, 1, False, 1, False, -1, -1, -1, op, input_dtype, output_dtype))
return configs
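
For reference, a grouped 2D entry under the widened constructor would look something like this inside one of the sweeps above (sizes and op name are illustrative, not part of the PR):

# Hypothetical grouped problem: 8 groups with 64 input/output channels per group.
configs.append(ConvConfig(B, 28, 28, 64, 3, 3, 64, 1, True, 8, False, -1, -1, -1,
                          "conv_2d_ngchw_fgchw", input_dtype, output_dtype))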

def get_conv_configs() -> list[tuple[str, ConvConfig]]: