XPU backend support 8bit optimizer #1565

Open · wants to merge 9 commits into base: multi-backend-refactor
3 changes: 3 additions & 0 deletions bitsandbytes/backends/cpu.py
@@ -35,6 +35,9 @@ class CPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def device_synchronize(self):
pass

def int8_double_quant(
self,
A: torch.Tensor,
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cpu_xpu_common.py
@@ -60,7 +60,7 @@ def _ipex_xpu_version_prereq(major, minor):

def _maybe_torch_compile(func):
# torch.compile requires g++ and pytorch >= 2.0
-    if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
+    if gxx_available and _torch_version_prereq(2, 0) and ipex_cpu_only:
options = {}
# fx_graph_cache requires pytorch >= 2.2
if _torch_version_prereq(2, 2):
3 changes: 3 additions & 0 deletions bitsandbytes/backends/cuda.py
@@ -97,6 +97,9 @@


class CUDABackend(Backend):
def device_synchronize(self):
torch.cuda.synchronize()

def transform(
self,
A: torch.Tensor,
3 changes: 3 additions & 0 deletions bitsandbytes/backends/mps.py
@@ -8,6 +8,9 @@


class MPSBackend(Backend):
def device_synchronize(self):
torch.mps.synchronize()

def double_quant(
self,
A: torch.Tensor,
3 changes: 3 additions & 0 deletions bitsandbytes/backends/npu.py
@@ -29,6 +29,9 @@ def assert_on_npu(tensors):


class NPUBackend(Backend):
def device_synchronize(self):
torch.npu.synchronize()

def int8_double_quant(
self,
A: torch.Tensor,
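Each backend gains a device_synchronize() implementation (the XPU one appears below). The matching base-class declaration is not part of the hunks shown here; a minimal sketch of what it would presumably look like, assuming the Backend ABC lives in bitsandbytes/backends/base.py (an assumption, since that file is not in this diff):

```python
# Hedged sketch only: the abstract declaration that the per-backend
# device_synchronize() implementations in this PR would satisfy. The base-class
# file and exact signature are assumptions, not part of the diff shown.
from abc import ABC, abstractmethod


class Backend(ABC):
    @abstractmethod
    def device_synchronize(self) -> None:
        """Block until all pending work on the backend's device has completed."""
        raise NotImplementedError
```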
77 changes: 75 additions & 2 deletions bitsandbytes/backends/xpu.py
@@ -12,11 +12,28 @@
int8_linear_matmul_impl,
int8_mm_dequant_impl,
quantize_4bit_impl,
_ipex_xpu_version_prereq
)
try:
import intel_extension_for_pytorch as ipex
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_xpu = None

Tensor = torch.Tensor


str2optimizer8bit_blockwise = {}
if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7):
str2optimizer8bit_blockwise = {
"adam": (
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32,
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16,
ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16,
),
}


def assert_on_xpu(tensors):
on_xpu = True
for t in tensors:
@@ -35,6 +52,9 @@ class XPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def device_synchronize(self):
torch.xpu.synchronize()

def int8_double_quant(
self,
A: torch.Tensor,
@@ -185,7 +205,19 @@ def dequantize_blockwise(
blocksize: int = 4096,
nested=False,
) -> torch.Tensor:
raise NotImplementedError
if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.")

# void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
if out.dtype == torch.float16:
ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
elif out.dtype == torch.bfloat16:
ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
elif out.dtype == torch.float32:
ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
Comment on lines +212 to +219
@matthewdouglas (Member) commented on Mar 27, 2025:


This will be useful when porting over to the new custom ops as an implementation for bitsandbytes::dequantize_blockwise.out(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()

A contributor replied:

Hard to understand. Could you please supply more details or instructions? Thanks!
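For context, registering this XPU path under the new custom-ops design could look roughly like the sketch below. The op schema is the one quoted in the comment above; the torch.library decorator usage, dispatch key, and function name here are assumptions for illustration, not the actual bitsandbytes implementation.

```python
# Illustrative sketch only. Assumes bitsandbytes::dequantize_blockwise.out has
# already been defined with the schema quoted above, and that
# intel_extension_for_pytorch >= 2.7 provides the kernels used in this PR.
import torch
import intel_extension_for_pytorch as ipex


@torch.library.impl("bitsandbytes::dequantize_blockwise.out", "XPU")
def _dequantize_blockwise_xpu(A, absmax, code, blocksize, dtype, out):
    # Pick the ipex kernel matching the requested output dtype.
    kernels = {
        torch.float16: ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16,
        torch.bfloat16: ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16,
        torch.float32: ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32,
    }
    if dtype not in kernels:
        raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}")
    kernels[dtype](code, A, absmax, out, blocksize, A.numel())
```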



def quantize_blockwise(
self,
@@ -220,7 +252,48 @@ def optimizer_update_8bit_blockwise(
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
raise NotImplementedError
optim_func = None
if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.")

assert_on_xpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])

if g.dtype == torch.float32 and state1.dtype == torch.uint8:
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
elif (
g.dtype == torch.bfloat16
and state1.dtype == torch.uint8
and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
):
optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
else:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
)
optim_func(
p,
g,
state1,
state2,
beta1,
beta2,
beta3,
alpha,
eps,
step,
lr,
qmap1,
qmap2,
absmax1,
absmax2,
weight_decay,
gnorm_scale,
skip_zeros,
g.numel()
)


def optimizer_update_32bit(
self,
11 changes: 10 additions & 1 deletion bitsandbytes/functional.py
@@ -859,7 +859,16 @@ def dequantize_blockwise(
if out is None:
out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)

if A.device.type != "cpu":
if A.device.type == "xpu":
backends[A.device.type].dequantize_blockwise(
A=A,
quant_state=quant_state,
absmax=absmax,
code=quant_state.code,
out=out,
blocksize=blocksize,
nested=quant_state.nested,
)
elif A.device.type != "cpu":
code = quant_state.code.to(A.device)
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
# Some AMD GPUs have warpsize 64
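With this dispatch in place, dequantize_blockwise on an XPU tensor routes to the backend implementation above. A rough usage sketch, assuming an XPU-enabled PyTorch build, intel_extension_for_pytorch >= 2.7, and that quantize_blockwise is likewise available on the XPU backend in this branch:

```python
# Hedged usage sketch, not a test from the PR. Requires an XPU device and
# this branch of bitsandbytes with intel_extension_for_pytorch >= 2.7.
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, dtype=torch.float16, device="xpu")
q, state = F.quantize_blockwise(A, blocksize=256)     # uint8 blocks + quant state
out = F.dequantize_blockwise(q, quant_state=state)    # routed to the XPU backend path above
print((A - out).abs().max())                          # small blockwise quantization error expected
```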
5 changes: 3 additions & 2 deletions bitsandbytes/optim/optimizer.py
@@ -10,6 +10,7 @@
import torch

import bitsandbytes.functional as F
from bitsandbytes.backends import backends


class MockArgs:
@@ -289,11 +290,11 @@ def step(self, closure=None):

self.prefetch_state(p)
self.update_step(group, p, gindex, pindex)
-    torch.cuda.synchronize()
+    backends[p.device.type].device_synchronize()
if self.is_paged:
# all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state
-    torch.cuda.synchronize()
+    backends[p.device.type].device_synchronize()

return loss

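Taken together, these changes let the existing 8-bit optimizers run on XPU: the blockwise update kernels come from ipex, and the optimizer's step() now synchronizes through the backend instead of calling torch.cuda.synchronize() directly. A rough end-to-end sketch, assuming an XPU-enabled PyTorch build and intel_extension_for_pytorch >= 2.7:

```python
# Hedged end-to-end sketch, not part of the PR's test suite.
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024, dtype=torch.float16, device="xpu")
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, dtype=torch.float16, device="xpu")
model(x).sum().backward()
opt.step()        # 8-bit blockwise Adam update via the ipex.xpu.bitsandbytes kernels
opt.zero_grad()
```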