
Commit 4445aa2

bertmaher authored and facebook-github-bot committed
Add layout options to gemm
Summary: We were only benchmarking `row-major x row-major` gemms (also called `TT` or `transpose-transpose`, because FORTRAN), which is not actually the common case: `nn.Linear` uses a column-major layout for its weight, so `TN` is much more common in practice.

Reviewed By: adamomainz

Differential Revision: D63714661

fbshipit-source-id: 735c25c59ddeb6596afd9b19f463af92036a830b
1 parent d512e67 commit 4445aa2
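
A note on the naming (not part of the commit message): in BLAS convention, `T`/`N` indicate whether an operand is transposed, and since FORTRAN BLAS assumes column-major storage, a row-major matrix reads as transposed — hence `TT` for row-major x row-major. `nn.Linear` stores its weight as `(N, K)` and applies it as `x @ weight.T`, which makes the weight operand effectively column-major. A minimal sketch of why that yields `TN`, assuming standard PyTorch semantics:

```python
# Sketch (not from the commit): why nn.Linear gemms are "TN".
# The activation is row-major ("T" through column-major BLAS eyes),
# while the weight is consumed as weight.T, a column-major ("N") view.
import torch

x = torch.randn(8, 16)            # (M, K), row-major
linear = torch.nn.Linear(16, 32)  # weight stored as (N, K) = (32, 16)
w_t = linear.weight.T             # (K, N) view with column-major strides

print(x.stride())    # (16, 1) -> rows are contiguous
print(w_t.stride())  # (1, 16) -> columns are contiguous
print(torch.allclose(linear(x), x @ w_t + linear.bias))  # True
```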

File tree

3 files changed (+37, -37 lines)


torchbenchmark/operators/gemm/data_io.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

torchbenchmark/operators/gemm/operator.py

Lines changed: 37 additions & 1 deletion
```diff
@@ -9,6 +9,8 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from torchbenchmark import REPO_PATH
+
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -19,7 +21,6 @@
     register_x_val,
 )
 
-from .data_io import parse_args, read_shapes_from_csv
 from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
 from .persistent_matmul import (
@@ -88,6 +89,35 @@
 ]
 
 
+def parse_args(args: List[str]) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="TorchBench Gemm operator Benchmark")
+    parser.add_argument("--m", type=int)
+    parser.add_argument("--k", type=int)
+    parser.add_argument("--n", type=int)
+    parser.add_argument("--bias", type=int)
+    parser.add_argument("--input", type=str)
+    parser.add_argument("--splitk", action="store_true", default=False)
+    parser.add_argument("--llama", action="store_true", default=False)
+    parser.add_argument("--layout", type=str, default="tn")
+    args = parser.parse_args(args)
+    return args
+
+
+def read_shapes_from_csv(csv_path: str) -> List[List[int]]:
+    input_file_path = os.path.join(
+        REPO_PATH, "torchbenchmark", "operators", "gemm", csv_path
+    )
+    shapes = []
+    with open(input_file_path, "r") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            shape = [
+                int(row.get(f)) if row.get(f) else None for f in ("M", "N", "K", "Bias")
+            ]
+            shapes.append(shape)
+    return shapes
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["speedup", "tflops"]
     DEFAULT_PRECISION = "fp16"
@@ -98,6 +128,7 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.use_cuda_graphs = False
         gemm_args = parse_args(self.extra_args)
+        self.layout = gemm_args.layout
         if gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
         elif gemm_args.splitk:
@@ -261,6 +292,11 @@ def get_input_iter(self) -> Generator:
             w = self._scaled_randn(
                 (k, n), scale=k, device=self.device, dtype=self.dtype
             )
+            # Convert inputs to column-major if layout is "n" (non-transposed)
+            if self.layout[0] == "n":
+                a = a.T.contiguous().T
+            if self.layout[1] == "n":
+                w = w.T.contiguous().T
             if not bias == None:
                 bias = torch.randn(
                     (bias), device=self.device, dtype=self.dtype
```
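
The `a.T.contiguous().T` idiom in the hunk above is how the benchmark materializes a column-major tensor: the first `.T` swaps the strides, `.contiguous()` copies the data in that transposed order, and the trailing `.T` restores the original shape while keeping the swapped strides. An illustrative sketch (not from the commit):

```python
# Sketch of the column-major conversion used in get_input_iter above.
import torch

a = torch.randn(4, 6)        # row-major: shape (4, 6), strides (6, 1)
a_col = a.T.contiguous().T   # same shape and values, strides (1, 4)

assert torch.equal(a, a_col)                     # values unchanged
print(a.stride(), a_col.stride())                # (6, 1) (1, 4)
print(a.is_contiguous(), a_col.is_contiguous())  # True False
```

With the default `--layout tn`, only `w` is converted, mirroring the `nn.Linear` case from the summary; `--layout nn` would convert both operands.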

torchbenchmark/operators/gemm/triton_matmul.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -201,8 +201,6 @@ def leaky_relu(x):
 def matmul(a, b, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
-    assert a.is_contiguous(), "Matrix A must be contiguous"
-    assert b.is_contiguous(), "Matrix B must be contiguous"
     M, K = a.shape
     K, N = b.shape
     # Allocates output.
```
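
These two asserts had to go because the new `n` layouts intentionally produce non-contiguous operands, and the kernel does not need contiguity anyway: the tutorial-style Triton matmul computes pointers from explicit stride arguments. An illustrative check (hypothetical shapes, not from the commit):

```python
# The deleted asserts would now reject valid column-major inputs.
import torch

b = torch.randn(64, 32)
b_col = b.T.contiguous().T     # column-major copy, as built in operator.py
print(b_col.is_contiguous())   # False -> "Matrix B must be contiguous" would fire
# The kernel is launched with b.stride(0) and b.stride(1) and indexes
# through them, so column-major operands work without an extra copy.
```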
