42 commits
6be1412
add template to support more dtypes
jiqing-feng Oct 28, 2025
252ac0f
update cmake list
jiqing-feng Oct 28, 2025
f98c9e5
fix typo
jiqing-feng Oct 28, 2025
902bf35
fix compile cpu
jiqing-feng Oct 28, 2025
fef8459
make different dtypes work
jiqing-feng Oct 29, 2025
55cbaa0
use bf16 on CPU
jiqing-feng Oct 29, 2025
bbef95b
fix state2 dtype
jiqing-feng Oct 29, 2025
e842513
remove torch
jiqing-feng Oct 30, 2025
d4473fa
rm torch
jiqing-feng Oct 30, 2025
dea8dd6
enable float to bf16
jiqing-feng Oct 30, 2025
e9bb4fe
rm dequantizeBlockwise4bitCpu
jiqing-feng Oct 30, 2025
cdc8d5e
fix check
jiqing-feng Oct 30, 2025
baacfac
enable dequant 4bit kernel
jiqing-feng Oct 30, 2025
eec3521
fix typo
jiqing-feng Oct 30, 2025
d7cc1c5
fix typo
jiqing-feng Oct 30, 2025
124b754
fix dequantize
jiqing-feng Oct 30, 2025
0f918c7
fix
jiqing-feng Oct 30, 2025
e1a8b20
fix
jiqing-feng Oct 30, 2025
eab45c8
test
jiqing-feng Oct 30, 2025
d9f5dd8
fix
jiqing-feng Oct 30, 2025
070f8a0
fix
jiqing-feng Oct 30, 2025
a84addf
fix
jiqing-feng Oct 30, 2025
c4bb660
fix
jiqing-feng Oct 30, 2025
4ba13fd
fix
jiqing-feng Oct 30, 2025
c0d05ec
change input param
jiqing-feng Oct 31, 2025
62a16a6
fix typo
jiqing-feng Oct 31, 2025
d9ad828
fix input param
jiqing-feng Oct 31, 2025
09ed6cb
split 8bit and 4bit
jiqing-feng Oct 31, 2025
a3f7b61
fix typo
jiqing-feng Oct 31, 2025
4708470
fix typo
jiqing-feng Oct 31, 2025
1dfe9f7
fix input params
jiqing-feng Oct 31, 2025
00289c4
fix input params
jiqing-feng Oct 31, 2025
a2578ba
fix
jiqing-feng Oct 31, 2025
72033dc
fix typo
jiqing-feng Oct 31, 2025
1c20ae8
enable dequant4bit
jiqing-feng Oct 31, 2025
7552fe2
fix
jiqing-feng Oct 31, 2025
8b32a39
fix
jiqing-feng Oct 31, 2025
8f1cc36
fix reverse
jiqing-feng Oct 31, 2025
49d242a
fix dequant 4bit fallback path
jiqing-feng Nov 3, 2025
4a9a6dc
fix fp4 dequant
jiqing-feng Nov 3, 2025
6bcd19e
Merge branch 'main' into cpu_kernel
jiqing-feng Nov 4, 2025
d7e981d
rm _Float16
jiqing-feng Nov 5, 2025
22 changes: 22 additions & 0 deletions CMakeLists.txt
@@ -78,9 +78,16 @@ else()
    set(BUILD_HIP OFF)
    set(BUILD_MPS OFF)
    set(BUILD_XPU OFF)
    set(BUILD_CPU ON)
endif()


if (BUILD_CPU)
    set(CMAKE_CXX_STANDARD 17)
    set(CMAKE_CXX_STANDARD_REQUIRED ON)
    find_package(OpenMP)
endif()

if(BUILD_CUDA)
    # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+.
    # Workaround: use --allow-unsupported-compiler
@@ -262,6 +269,21 @@ add_library(bitsandbytes SHARED ${SRC_FILES})
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
target_include_directories(bitsandbytes PUBLIC csrc include)

if (BUILD_CPU)
    target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
    include(CheckCXXCompilerFlag)

    check_cxx_compiler_flag(-mavx512f HAS_AVX512F)
    check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16)

    if(HAS_AVX512F)
        target_compile_options(bitsandbytes PRIVATE -mavx512f)
    endif()

    if(HAS_AVX512BF16)
        target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
    endif()
endif()

if(BUILD_CUDA)
    target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
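The compiler-flag checks above only establish what the build compiler can emit; whether the CPU that eventually runs the library supports AVX-512 is a separate runtime question. As a hedged illustration, a runtime probe using PyTorch's torch.backends.cpu.get_cpu_capability(), which reports the SIMD level PyTorch dispatches to on the current host:

import torch

# Coarse runtime probe of the host CPU's SIMD level, e.g. "AVX2" or "AVX512".
# If this reports below AVX512, the vectorized CPU paths compiled with
# -mavx512f/-mavx512bf16 cannot benefit this machine.
print(torch.backends.cpu.get_cpu_capability())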
3 changes: 3 additions & 0 deletions bitsandbytes/autograd/_functions.py
@@ -433,6 +433,9 @@ def matmul_4bit(
    bias: Optional[torch.Tensor] = None,
):
    assert quant_state is not None
    # On CPU, match the dequantization dtype to the activation dtype
    if A.device.type == "cpu":
        quant_state.dtype = A.dtype

    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
        if A.shape[-1] % quant_state.blocksize != 0:
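The effect of the CPU branch above is easiest to see end to end. A minimal sketch, assuming quantize_4bit from the public bitsandbytes.functional API accepts CPU tensors:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(32, 64)                       # fp32 weight
A = torch.randn(2, 64, dtype=torch.bfloat16)  # bf16 activations on CPU
qB, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

out = bnb.matmul_4bit(A, qB.t(), quant_state=state)
assert state.dtype == torch.bfloat16  # coerced to A.dtype by the branch above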
150 changes: 147 additions & 3 deletions bitsandbytes/backends/cpu/ops.py
@@ -1,10 +1,12 @@
from collections.abc import Sequence
import ctypes as ct
import logging

import torch

from bitsandbytes.functional import get_ptr

from ..utils import CODE
from ..._ops import register_kernel
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib

@@ -76,10 +78,8 @@ def _(
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

    out = torch.empty_like(A, dtype=dtype)
    if dtype == torch.float32:
        lib.cdequantize_blockwise_cpu_fp32(
            get_ptr(code),
            get_ptr(A),
@@ -88,6 +88,24 @@
            ct.c_longlong(blocksize),
            ct.c_longlong(A.numel()),
        )
    elif dtype == torch.bfloat16:
        lib.cdequantize_blockwise_cpu_bf16(
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_longlong(blocksize),
            ct.c_longlong(A.numel()),
        )
    elif dtype == torch.float16:
        lib.cdequantize_blockwise_cpu_fp16(
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_longlong(blocksize),
            ct.c_longlong(A.numel()),
        )
    else:
        out = code[A.reshape(-1).int()]
        blocks = out.shape[-1] // blocksize
@@ -99,3 +117,129 @@
        out = out.reshape(A.shape)

    return out
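
As a sanity check, the new bf16 path can be exercised through the registered op. A minimal sketch, assuming quantize_blockwise from the public API runs on CPU tensors and that the op schema is (A, absmax, code, blocksize, dtype) as defined in bitsandbytes._ops:

import torch
import bitsandbytes.functional as F

W = torch.randn(4, 1024)  # fp32 source tensor on CPU
packed, state = F.quantize_blockwise(W, blocksize=256)

# dtype=torch.bfloat16 routes into the new cdequantize_blockwise_cpu_bf16 kernel
out = torch.ops.bitsandbytes.dequantize_blockwise(
    packed, state.absmax, state.code, 256, torch.bfloat16
)
assert out.shape == W.shape and out.dtype == torch.bfloat16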

@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # Allow non-uint8 storage dtypes by reinterpreting the packed data as uint8
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)

    # TODO: support half-precision absmax
    if absmax.dtype != torch.float32:
        absmax = absmax.float()

    A = A.reshape(shape[0], shape[1] // 2)
    out = torch.empty(shape, dtype=dtype, device=A.device)
    if quant_type == "fp4":
        if dtype == torch.float32:
            lib.cdequantize_blockwise_cpu_fp4_fp32(
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(shape[0]),
                ct.c_longlong(shape[1]),
            )
        elif dtype == torch.bfloat16:
            lib.cdequantize_blockwise_cpu_fp4_bf16(
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(shape[0]),
                ct.c_longlong(shape[1]),
            )
        elif dtype == torch.float16:
            lib.cdequantize_blockwise_cpu_fp4_fp16(
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(shape[0]),
                ct.c_longlong(shape[1]),
            )
    elif quant_type == "nf4":
        if dtype == torch.float32:
            lib.cdequantize_blockwise_cpu_nf4_fp32(
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(shape[0]),
                ct.c_longlong(shape[1]),
            )
        elif dtype == torch.bfloat16:
            lib.cdequantize_blockwise_cpu_nf4_bf16(
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(shape[0]),
                ct.c_longlong(shape[1]),
            )
        elif dtype == torch.float16:
            lib.cdequantize_blockwise_cpu_nf4_fp16(
                get_ptr(A),
                get_ptr(absmax),
                get_ptr(out),
                ct.c_longlong(blocksize),
                ct.c_longlong(shape[0]),
                ct.c_longlong(shape[1]),
            )
    else:
        raise ValueError(f"quant_type must be nf4 or fp4, got {quant_type}")

    return out

def dequantize_nf4_test(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
):
    # Unpack two 4-bit codes per byte and map them through the quantization
    # table to values in [-1, 1]
    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
    n = out_dq.numel()
    out_dq[1::2] = A & 0xF
    out_dq[::2] = A >> 4
    # code is fp32; cast to dtype to avoid a dtype mismatch during the multiply
    code = CODE[quant_type].to(dtype).to(A.device)
    out_dq = code[out_dq]

    # Apply per-block scales
    if out_dq.numel() != n:
        assert out_dq.numel() == n + 1
        out_dq = torch.narrow(out_dq, 0, 0, n)
    blocks = n // blocksize
    blocks += 1 if n % blocksize > 0 else 0
    rem = n % blocksize
    has_rem = rem > 0

    if has_rem:
        out = torch.empty(n, dtype=dtype, device=A.device)  # allocate before filling slices
        out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
        out[n - rem :] = out_dq[n - rem :] * absmax[-1]
    else:
        out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)

    out = out.reshape(-1, *shape[1:]).to(dtype)

    return out
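
For completeness, the registered 4-bit op can be driven directly with synthetic data. A hedged sketch; the packed layout follows the A.reshape(shape[0], shape[1] // 2) convention used by the kernel above:

import torch

packed = torch.randint(0, 256, (2, 32), dtype=torch.uint8)  # 128 packed nf4 nibbles
absmax = torch.rand(2) + 0.5  # one fp32 scale per 64-element block
out = torch.ops.bitsandbytes.dequantize_4bit(
    packed, absmax, 64, "nf4", (2, 64), torch.bfloat16
)
assert out.shape == (2, 64) and out.dtype == torch.bfloat16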
