1 change: 1 addition & 0 deletions .gitignore
@@ -79,3 +79,4 @@ uv.lock
.cache/
# vim
*.swp
results/
29 changes: 29 additions & 0 deletions CMakeLists.txt
@@ -244,6 +244,35 @@ if(WIN32)
endif()

if(MLX_BUILD_CPU)
# ----------------------------- x86 SIMD --------------------------------
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64|i[3-9]86")
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-mavx2" HAS_AVX2)
check_cxx_compiler_flag("-mfma" HAS_FMA)
check_cxx_compiler_flag("-mf16c" HAS_F16C)

if(HAS_AVX2
AND HAS_FMA
AND HAS_F16C)
message(
STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX2 SIMD backend")
target_compile_options(mlx PRIVATE -mavx2 -mfma -mf16c)
target_compile_definitions(mlx PRIVATE MLX_USE_AVX2)
else()
message(
STATUS "Missing required x86 SIMD support - using base SIMD backend")
if(NOT HAS_AVX2)
message(STATUS " Missing: AVX2")
endif()
if(NOT HAS_FMA)
message(STATUS " Missing: FMA")
endif()
if(NOT HAS_F16C)
message(STATUS " Missing: F16C")
endif()
endif()
endif()

find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
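
As context for the new MLX_USE_AVX2 definition above, here is a minimal sketch of how a preprocessor-guarded AVX2 kernel is typically structured. This is an illustration only; the kernel name and scalar fallback are hypothetical and this is not the MLX SIMD backend code touched by this change:

// Hypothetical illustration: the build defines MLX_USE_AVX2 only when the
// compiler accepts -mavx2/-mfma/-mf16c, so AVX2 intrinsics stay behind the
// preprocessor guard and the base path remains the fallback.
#ifdef MLX_USE_AVX2
#include <immintrin.h>

// Fused multiply-add over 8 floats per step; assumes n is a multiple of 8.
void axpy_avx2(const float* x, const float* y, float* out, float a, int n) {
  __m256 va = _mm256_set1_ps(a);
  for (int i = 0; i < n; i += 8) {
    __m256 vx = _mm256_loadu_ps(x + i);
    __m256 vy = _mm256_loadu_ps(y + i);
    _mm256_storeu_ps(out + i, _mm256_fmadd_ps(va, vx, vy));
  }
}
#else
// Base backend fallback: plain scalar loop when the flags are unavailable.
void axpy_avx2(const float* x, const float* y, float* out, float a, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = a * x[i] + y[i];
  }
}
#endif
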
162 changes: 129 additions & 33 deletions benchmarks/python/blas/bench_gemm.py
@@ -10,18 +10,44 @@
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
try:
device_name = (
subprocess.check_output(
["sysctl", "-n", "machdep.cpu.brand_string"], stderr=subprocess.DEVNULL
)
.decode("utf-8")
.strip()
)
except (subprocess.CalledProcessError, FileNotFoundError):
device_name = "unknown"

if torch.backends.mps.is_available():
torch_device = "mps"
torch_sync = torch.mps.synchronize
elif torch.cuda.is_available():
torch_device = "cuda"
torch_sync = torch.cuda.synchronize
else:
torch_device = "cpu"
torch_sync = lambda: None

N_warmup = 8
N_iter_bench = 80
N_iter_func = 5
FULL_WARMUP = 8
FULL_ITER_BENCH = 80
FULL_ITER_FUNC = 5

QUICK_WARMUP = 2
QUICK_ITER_BENCH = 10
QUICK_ITER_FUNC = 5

N_warmup = FULL_WARMUP
N_iter_bench = FULL_ITER_BENCH
N_iter_func = FULL_ITER_FUNC


def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
torch_sync()

s = time.perf_counter_ns()
for i in range(N_iter_bench):
@@ -72,7 +98,7 @@ def gemm_nn_torch(a, b):
for i in range(N_iter_func):
y = a @ b
ys.append(y)
torch.mps.synchronize()
torch_sync()
return ys


@@ -82,7 +108,7 @@ def gemm_nt_torch(a, b):
for i in range(N_iter_func):
y = a @ b.transpose(-1, -2)
ys.append(y)
torch.mps.synchronize()
torch_sync()
return ys


@@ -92,7 +118,7 @@ def gemm_tn_torch(a, b):
for i in range(N_iter_func):
y = a.transpose(-1, -2) @ b
ys.append(y)
torch.mps.synchronize()
torch_sync()
return ys


@@ -102,11 +128,11 @@ def gemm_tt_torch(a, b):
for i in range(N_iter_func):
y = a.transpose(-1, -2) @ b.transpose(-1, -2)
ys.append(y)
torch.mps.synchronize()
torch_sync()
return ys


def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
def bench_shape(B, M, N, K, np_dtype, transpose="nn", max_torch_ops=None):
shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)

@@ -116,10 +142,10 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)

a_pt = torch.from_numpy(a_np).to("mps")
b_pt = torch.from_numpy(b_np).to("mps")
a_pt = torch.from_numpy(a_np).to(torch_device)
b_pt = torch.from_numpy(b_np).to(torch_device)

torch.mps.synchronize()
torch_sync()

f_mx = {
"nn": gemm_nn_mlx,
@@ -135,7 +161,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
"tt": gemm_tt_torch,
}[transpose]

time_torch = bench(f_pt, a_pt, b_pt)
gemm_ops = B * M * N * K
time_torch = None
if max_torch_ops is None or gemm_ops <= max_torch_ops:
time_torch = bench(f_pt, a_pt, b_pt)

time_mlx = bench(f_mx, a_mx, b_mx)

t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
@@ -158,34 +188,100 @@ def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
def main():
global N_warmup, N_iter_bench, N_iter_func

parser = argparse.ArgumentParser(description="Run gemm benchmarks")
parser.add_argument(
"--quick",
action="store_true",
help="Run fewer iterations and a reduced shape set.",
)
parser.add_argument(
"--max-torch-ops",
type=int,
default=None,
help="Skip PyTorch timing for cases where B*M*N*K exceeds this value.",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Print per-shape timing results.",
)
parser.add_argument(
"--single-threaded",
action="store_true",
help="Set OMP_NUM_THREADS=1 and OPENBLAS_NUM_THREADS=1 for single-threaded PyTorch/NumPy comparison.",
)
args = parser.parse_args()

if args.single_threaded:
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

if args.quick:
N_warmup = QUICK_WARMUP
N_iter_bench = QUICK_ITER_BENCH
N_iter_func = QUICK_ITER_FUNC
else:
N_warmup = FULL_WARMUP
N_iter_bench = FULL_ITER_BENCH
N_iter_func = FULL_ITER_FUNC

dtypes = ("float32", "float16", "complex64")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
)
if args.quick:
shapes = (
(16, 234, 768, 3072),
(1, 1024, 1024, 2048),
)
else:
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
)

if args.verbose:
print(
f"{'B':>3}, {'M':>4}, {'N':>4}, {'K':>4}, {'dtype':<9}, {'t':<2}, torch_gf, mlx_gf, diff"
)
print("-" * 66)

for dtype in dtypes:
for transpose in transposes:
for B, M, N, K in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)
time_mlx, time_torch = bench_shape(
B,
M,
N,
K,
np_dtype,
transpose,
args.max_torch_ops,
)

gflop_count = get_gflop_count(B, M, N, K)
gflops_mx = gflop_count / (time_mlx)
gflops_pt = gflop_count / (time_torch)
diff = gflops_mx / gflops_pt - 1.0
if args.verbose:
if time_torch is None:
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, skipped, {gflops_mx:05.3f}, n/a"
)
else:
gflops_pt = gflop_count / (time_torch)
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

if __name__ == "__main__":
main()
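
With these options in place, a fast local check of the benchmark might look like: python benchmarks/python/blas/bench_gemm.py --quick --verbose, and --max-torch-ops or --single-threaded can be added to limit the PyTorch comparison on slower hosts.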