[CPU - Linux] AVX SIMD backend for fp16 and bf16 matmul #3502
@dhiltgen I remember ollama was doing something similar? Can you please check whether this can coexist with your work?

I have an older PR, #3019, which I've been meaning to break up into smaller chunks. I'll add inline comments on this PR with some suggestions on how it could become a partial precursor to that broader implementation.
@@ -0,0 +1,432 @@
// Copyright © 2025 Apple Inc.
#pragma once
Since these are largely GEMM-oriented helpers, this file could move to `mlx/backend/cpu/gemms/avx2_gemm_simd.h` and use a GEMM-private namespace rather than `mlx::core::simd`. That lets a future broad AVX2 SIMD layer land without colliding with this PR.
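One way to read this suggestion, as a rough sketch rather than the PR's actual code: the helpers live in a GEMM-private namespace, so a later general-purpose `mlx::core::simd` AVX2 layer can define its own symbols without clashing. The namespace name `gemm_avx2` and the helper `dot_f32` are hypothetical, and the body is a scalar placeholder for whatever AVX2 kernel would sit there.

```cpp
// Hypothetical contents for mlx/backend/cpu/gemms/avx2_gemm_simd.h.
// (A real header would begin with #pragma once.)
#include <cassert>
#include <cstddef>

namespace mlx::core::gemm_avx2 { // assumed name, not taken from the PR

// Placeholder helper: a scalar dot product standing in for an AVX2 kernel.
inline float dot_f32(const float* a, const float* b, std::size_t n) {
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    sum += a[i] * b[i];
  }
  return sum;
}

} // namespace mlx::core::gemm_avx2
```

Because nothing is declared in `mlx::core::simd`, a follow-up PR could introduce that namespace freely and only the GEMM sources would need to include this header.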
message(
  STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX SIMD backend")
add_compile_options(-mavx2 -mfma -mf16c)
add_compile_definitions(MLX_USE_AVX)
I would suggest MLX_USE_AVX2 for clarity, and to set up for future MLX_USE_AVX512 or other variations.
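To illustrate why a per-ISA name helps, here is a minimal sketch of how source code might branch on such definitions. `MLX_USE_AVX2` is the name suggested above; `MLX_USE_AVX512` is only a possible future definition, and `cpu_simd_backend` is a hypothetical helper, none of which exist in this PR as written.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: reports which SIMD path was selected at compile time.
// MLX_USE_AVX512 is an assumed future definition, MLX_USE_AVX2 the suggested one.
inline std::string cpu_simd_backend() {
#if defined(MLX_USE_AVX512)
  return "avx512";
#elif defined(MLX_USE_AVX2)
  return "avx2";
#else
  return "scalar";
#endif
}
```

With a single generic `MLX_USE_AVX` definition, the first two branches could not be told apart without renaming the macro later.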
AND HAS_F16C)
message(
  STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX SIMD backend")
add_compile_options(-mavx2 -mfma -mf16c)
I would scope this to the mlx target, or CPU backend.
Proposed changes
This PR adds an AVX SIMD backend for fp16 and bf16 matmul (GEMM and GEMV) on CPU for Linux. It follows from the discussion in #2037 and is a precursor to adding the full set of AVX SIMD instructions in a follow-up PR. Let me know what you think; I'd appreciate any feedback, including adjustments to the benchmarking methodology.
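For readers unfamiliar with the technique, the following is an illustrative sketch of a bf16 dot product accumulated in fp32, the kind of inner loop a GEMV kernel builds on. It is not the code from this PR: `bf16_to_f32` and `bf16_dot` are hypothetical names, bf16 values are assumed to be stored as raw `uint16_t` bit patterns, and the AVX2 path is only compiled when the compiler defines `__AVX2__` and `__FMA__` (otherwise the scalar loop handles everything).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#if defined(__AVX2__) && defined(__FMA__)
#include <immintrin.h>
#endif

// bf16 is the upper 16 bits of an IEEE-754 float32, so widening is a shift.
inline float bf16_to_f32(std::uint16_t h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Hypothetical helper: dot product of two bf16 vectors with fp32 accumulation.
inline float bf16_dot(
    const std::uint16_t* a, const std::uint16_t* b, std::size_t n) {
  std::size_t i = 0;
  float sum = 0.0f;
#if defined(__AVX2__) && defined(__FMA__)
  __m256 acc = _mm256_setzero_ps();
  for (; i + 8 <= n; i += 8) {
    // Load 8 bf16 values, zero-extend to 32 bits, shift into the high half.
    __m128i ra = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + i));
    __m128i rb = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b + i));
    __m256 fa = _mm256_castsi256_ps(
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(ra), 16));
    __m256 fb = _mm256_castsi256_ps(
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(rb), 16));
    acc = _mm256_fmadd_ps(fa, fb, acc);
  }
  // Horizontal sum of the 8 fp32 lanes.
  __m128 s = _mm_add_ps(
      _mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
  s = _mm_hadd_ps(s, s);
  s = _mm_hadd_ps(s, s);
  sum = _mm_cvtss_f32(s);
#endif
  // Scalar tail, and the full loop when AVX2/FMA are unavailable.
  for (; i < n; ++i) {
    sum += bf16_to_f32(a[i]) * bf16_to_f32(b[i]);
  }
  return sum;
}
```

An fp16 variant would typically replace the shift with the F16C conversion `_mm256_cvtph_ps`, which is why the build checks for F16C alongside AVX2 and FMA.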
I modified `bench_gemm.py` and `bench_gemv.py` in `benchmarks/python/blas` so they'd complete in a reasonable amount of time. I ran them with a build of mlx from this PR and against the official mlx release for comparison. Note I left out the other dtypes from the benchmarked results printed below due to potential build differences (could be an error on my part). I built mlx with:
Bench setup
6.18.9-arch1-2 x86_64
mlx-cpu==0.31.2
torch==2.5.1+cpu
Bench results
GEMM - branch (this PR)
GEMM - mlx-cpu==0.31.2
GEMV - branch (this PR)
GEMV - mlx-cpu==0.31.2
Checklist
Put an `x` in the boxes that apply.
`pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes