[CPU - Linux] AVX SIMD backend for fp16 and bf16 matmul #3502

Open
acsweet wants to merge 4 commits into
ml-explore:main from acsweet:simd-backend-avx

Conversation

@acsweet commented May 9, 2026

Proposed changes

This PR adds an AVX SIMD backend for fp16 and bf16 matmul (GEMM and GEMV) on CPU for Linux. It follows from the discussion in #2037 and is a precursor to adding the full set of AVX SIMD instructions in a follow-up PR. Let me know what you think; I'd appreciate any feedback, including adjustments to the benchmarking methodology.
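For context on why a dedicated backend helps here: most x86 CPUs have no native fp16/bf16 arithmetic, so a SIMD kernel widens each 16-bit element to fp32 (F16C's VCVTPH2PS for fp16, a 16-bit left shift for bf16) and accumulates with fp32 FMA. Below is a scalar sketch of the bf16 path with hypothetical helper names, illustrating the operation an AVX2 kernel performs 8 lanes at a time; it is not this PR's actual code.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 fp32 value,
// so widening is just a shift; AVX2 does the same 8 lanes at once.
inline float bf16_to_float(uint16_t b) {
  uint32_t bits = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

inline uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1); // round to nearest even
  return uint16_t(bits >> 16);
}

// Scalar reference for one GEMV row: dot(A[i,:], x), widening each
// bf16 element to fp32 before the multiply-add (hypothetical helper).
float bf16_dot(const uint16_t* a, const uint16_t* x, size_t n) {
  float acc = 0.0f;
  for (size_t i = 0; i < n; ++i)
    acc += bf16_to_float(a[i]) * bf16_to_float(x[i]);
  return acc;
}
```

The fp16 path is analogous but uses the F16C conversion instructions instead of a shift, since fp16 has a different exponent/mantissa split.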

I modified bench_gemm.py and bench_gemv.py in benchmarks/python/blas so they'd complete in a reasonable amount of time, then ran them with a build of mlx from this PR and against the official mlx release for comparison. Note that I left the other dtypes out of the results printed below due to potential build differences (this could be an error on my part). I built mlx with:

CMAKE_ARGS="-DMLX_BUILD_CPU=ON -DMLX_BUILD_CUDA=OFF -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas" CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .

Bench setup

  • OS: Arch Linux, kernel 6.18.9-arch1-2 x86_64
  • CPU: Intel Core i7-10700K
  • mlx baseline: mlx-cpu==0.31.2
  • torch comparison: torch==2.5.1+cpu
  • benchmark commands:
python benchmarks/python/blas/bench_gemm.py --quick --verbose --single-threaded
python benchmarks/python/blas/bench_gemv.py --quick --verbose --single-threaded
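The diff column in the tables below is consistent with a simple relative comparison of the two throughput numbers. A sketch of that formula as I read it from the printed output (not necessarily the scripts' exact code):

```cpp
#include <cassert>
#include <cmath>

// Relative speed difference of mlx vs torch, in percent:
// positive means mlx is faster, negative means slower.
double rel_diff_pct(double mlx, double torch) {
  return (mlx - torch) / torch * 100.0;
}
```

For example, the first GEMV row (mlx 8.92 GB/s vs torch 1.66 GB/s) works out to roughly +437%, matching the printed +436.2% up to rounding of the displayed throughputs.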

Bench results

GEMM - branch (this PR)

B M N K dtype t torch_gf mlx_gf diff
16 234 768 3072 float16 nn 1.510 93.479 +6091.42%
1 1024 1024 2048 float16 nn 1.103 82.306 +7362.47%
16 234 768 3072 float16 nt 2.319 91.380 +3840.00%
1 1024 1024 2048 float16 nt 2.318 81.883 +3431.84%
16 234 768 3072 float16 tn 4.056 95.101 +2244.59%
1 1024 1024 2048 float16 tn 4.073 83.882 +1959.25%

GEMM - mlx-cpu==0.31.2

B M N K dtype t torch_gf mlx_gf diff
16 234 768 3072 float16 nn 1.623 3.641 +124.27%
1 1024 1024 2048 float16 nn 1.511 3.645 +141.26%
16 234 768 3072 float16 nt 2.320 3.884 +67.40%
1 1024 1024 2048 float16 nt 2.317 3.878 +67.42%
16 234 768 3072 float16 tn 4.067 3.532 -13.15%
1 1024 1024 2048 float16 tn 4.080 3.459 -15.22%

GEMV - branch (this PR)

============================================================
gemv | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   8.92 GB/s, torch=   1.66 GB/s, diff=+436.2%
  in=  512, out= 1024, mlx=  21.82 GB/s, torch=   1.89 GB/s, diff=+1055.4%
  in=  512, out= 4096, mlx=  26.47 GB/s, torch=   1.88 GB/s, diff=+1306.6%
  in=  512, out=11008, mlx=  16.17 GB/s, torch=   1.63 GB/s, diff=+892.2%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=  26.23 GB/s, torch=   1.83 GB/s, diff=+1330.2%
  in= 2048, out= 1024, mlx=  30.29 GB/s, torch=   1.53 GB/s, diff=+1882.1%
  in= 2048, out= 4096, mlx=  20.94 GB/s, torch=   1.85 GB/s, diff=+1033.7%
  in= 2048, out=11008, mlx=  18.57 GB/s, torch=   1.59 GB/s, diff=+1070.8%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=  12.20 GB/s, torch=   1.23 GB/s, diff=+895.1%
  in= 1024, out=  512, mlx=  29.48 GB/s, torch=   1.75 GB/s, diff=+1585.6%
  in= 4096, out=  512, mlx=  25.17 GB/s, torch=   1.66 GB/s, diff=+1413.1%
  in=11008, out=  512, mlx=  39.61 GB/s, torch=   1.75 GB/s, diff=+2167.0%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=  20.01 GB/s, torch=   2.21 GB/s, diff=+803.5%
  in= 1024, out= 2048, mlx=  34.55 GB/s, torch=   2.29 GB/s, diff=+1410.6%
  in= 4096, out= 2048, mlx=  16.50 GB/s, torch=   2.08 GB/s, diff=+692.2%
  in=11008, out= 2048, mlx=  19.15 GB/s, torch=   1.91 GB/s, diff=+900.7%


============================================================
gemv_t | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   6.97 GB/s, torch=   1.77 GB/s, diff=+294.4%
  in=  512, out= 1024, mlx=  13.19 GB/s, torch=   0.95 GB/s, diff=+1290.3%
  in=  512, out= 4096, mlx=  15.76 GB/s, torch=   0.81 GB/s, diff=+1839.1%
  in=  512, out=11008, mlx=  12.32 GB/s, torch=   0.95 GB/s, diff=+1193.0%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=  10.15 GB/s, torch=   0.85 GB/s, diff=+1099.1%
  in= 2048, out= 1024, mlx=  13.90 GB/s, torch=   0.87 GB/s, diff=+1499.0%
  in= 2048, out= 4096, mlx=  12.03 GB/s, torch=   0.55 GB/s, diff=+2090.4%
  in= 2048, out=11008, mlx=  16.71 GB/s, torch=   1.43 GB/s, diff=+1066.2%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=  11.21 GB/s, torch=   2.00 GB/s, diff=+460.8%
  in= 1024, out=  512, mlx=  14.76 GB/s, torch=   1.17 GB/s, diff=+1161.1%
  in= 4096, out=  512, mlx=  17.52 GB/s, torch=   1.16 GB/s, diff=+1412.8%
  in=11008, out=  512, mlx=  15.96 GB/s, torch=   1.07 GB/s, diff=+1385.2%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=  16.26 GB/s, torch=   1.27 GB/s, diff=+1180.0%
  in= 1024, out= 2048, mlx=  22.60 GB/s, torch=   1.41 GB/s, diff=+1502.3%
  in= 4096, out= 2048, mlx=  16.50 GB/s, torch=   0.58 GB/s, diff=+2748.3%
  in=11008, out= 2048, mlx=  18.33 GB/s, torch=   0.43 GB/s, diff=+4210.5%

GEMV - mlx-cpu==0.31.2

============================================================
gemv | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   1.06 GB/s, torch=   2.13 GB/s, diff=-50.2%
  in=  512, out= 1024, mlx=   1.15 GB/s, torch=   2.23 GB/s, diff=-48.5%
  in=  512, out= 4096, mlx=   1.17 GB/s, torch=   2.24 GB/s, diff=-48.1%
  in=  512, out=11008, mlx=   1.11 GB/s, torch=   2.20 GB/s, diff=-49.6%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=   0.94 GB/s, torch=   2.22 GB/s, diff=-57.7%
  in= 2048, out= 1024, mlx=   0.97 GB/s, torch=   2.25 GB/s, diff=-57.0%
  in= 2048, out= 4096, mlx=   0.97 GB/s, torch=   2.20 GB/s, diff=-56.2%
  in= 2048, out=11008, mlx=   0.96 GB/s, torch=   2.20 GB/s, diff=-56.5%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=   1.06 GB/s, torch=   2.01 GB/s, diff=-47.3%
  in= 1024, out=  512, mlx=   1.06 GB/s, torch=   2.25 GB/s, diff=-53.0%
  in= 4096, out=  512, mlx=   0.83 GB/s, torch=   2.26 GB/s, diff=-63.5%
  in=11008, out=  512, mlx=   0.58 GB/s, torch=   2.21 GB/s, diff=-73.6%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=   1.18 GB/s, torch=   2.15 GB/s, diff=-45.1%
  in= 1024, out= 2048, mlx=   1.08 GB/s, torch=   2.25 GB/s, diff=-51.8%
  in= 4096, out= 2048, mlx=   0.83 GB/s, torch=   2.25 GB/s, diff=-63.3%
  in=11008, out= 2048, mlx=   0.58 GB/s, torch=   2.20 GB/s, diff=-73.5%

============================================================
gemv_t | float16 | device: cpu
============================================================
--- sweep out_vec_len (fixed in_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  512, out=  128, mlx=   1.01 GB/s, torch=   1.70 GB/s, diff=-40.6%
  in=  512, out= 1024, mlx=   0.90 GB/s, torch=   1.42 GB/s, diff=-36.4%
  in=  512, out= 4096, mlx=   0.92 GB/s, torch=   1.57 GB/s, diff=-41.2%
  in=  512, out=11008, mlx=   0.96 GB/s, torch=   1.90 GB/s, diff=-49.3%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in= 2048, out=  128, mlx=   0.87 GB/s, torch=   1.40 GB/s, diff=-38.1%
  in= 2048, out= 1024, mlx=   0.81 GB/s, torch=   1.42 GB/s, diff=-43.1%
  in= 2048, out= 4096, mlx=   0.56 GB/s, torch=   0.68 GB/s, diff=-18.4%
  in= 2048, out=11008, mlx=   0.78 GB/s, torch=   1.55 GB/s, diff=-49.6%
--- sweep in_vec_len (fixed out_vec_len) ---
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out=  512, mlx=   1.04 GB/s, torch=   2.09 GB/s, diff=-50.0%
  in= 1024, out=  512, mlx=   0.87 GB/s, torch=   1.44 GB/s, diff=-39.1%
  in= 4096, out=  512, mlx=   0.70 GB/s, torch=   1.42 GB/s, diff=-50.4%
  in=11008, out=  512, mlx=   0.39 GB/s, torch=   1.18 GB/s, diff=-66.7%
     in,   out,   mlx_GB/s,  trc_GB/s,    diff
  in=  128, out= 2048, mlx=   0.92 GB/s, torch=   1.53 GB/s, diff=-40.0%
  in= 1024, out= 2048, mlx=   0.88 GB/s, torch=   1.57 GB/s, diff=-44.2%
  in= 4096, out= 2048, mlx=   0.49 GB/s, torch=   0.58 GB/s, diff=-14.9%
  in=11008, out= 2048, mlx=   0.31 GB/s, torch=   0.41 GB/s, diff=-25.2%

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@zcbenz (Collaborator) commented May 10, 2026

@dhiltgen I remember ollama was doing something similar? Can you please check whether this can coexist with your work?

@dhiltgen (Contributor)

I have an older PR #3019 which I've been meaning to break up into smaller chunks.

I'll add inline comments on this PR with some suggestions on how this could become a partial precursor to that broader implementation.

@@ -0,0 +1,432 @@
// Copyright © 2025 Apple Inc.
#pragma once
@dhiltgen (Contributor):
Since these are largely GEMM-oriented helpers, they could move to mlx/backend/cpu/gemms/avx2_gemm_simd.h and use a GEMM-private namespace rather than mlx::core::simd. That lets a future broad AVX2 SIMD layer land without colliding with this PR.

Comment thread: CMakeLists.txt
message(
STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX SIMD backend")
add_compile_options(-mavx2 -mfma -mf16c)
add_compile_definitions(MLX_USE_AVX)
@dhiltgen (Contributor):
I would suggest MLX_USE_AVX2 for clarity, and to set up for future MLX_USE_AVX512 or other variations.
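To illustrate how that naming could compose with a later AVX-512 path, here is a compile-time dispatch sketch. The MLX_USE_AVX512 / MLX_USE_AVX2 macros and the per-ISA kernel names are hypothetical; only the scalar fallback below actually builds when neither flag is defined.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical dispatch keyed on per-ISA macros set by CMake.
// The AVX kernels referenced in the guarded branches are not shown;
// the scalar path is what compiles without either flag.
float sum_f32(const float* x, size_t n) {
#if defined(MLX_USE_AVX512)
  return sum_f32_avx512(x, n); // hypothetical 16-lane kernel
#elif defined(MLX_USE_AVX2)
  return sum_f32_avx2(x, n); // hypothetical 8-lane kernel
#else
  float acc = 0.0f;
  for (size_t i = 0; i < n; ++i)
    acc += x[i];
  return acc;
#endif
}
```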

Comment thread: CMakeLists.txt
AND HAS_F16C)
message(
STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX SIMD backend")
add_compile_options(-mavx2 -mfma -mf16c)
@dhiltgen (Contributor):
I would scope this to the mlx target or the CPU backend rather than applying it globally, e.g. with target_compile_options(mlx PRIVATE -mavx2 -mfma -mf16c).
