
XPU backend support 8bit optimizer #1565


Conversation

@Liangliang-Ma commented Mar 14, 2025:

This PR adds support for the 8-bit optimizers on the XPU backend.
The backend kernels are now integrated in Intel Extension for PyTorch (IPEX).
We have verified full-path accuracy with blockwise 8-bit Adam.

It also adds a device synchronize function to every backend class to avoid hardcoding CUDA.
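As a rough illustration of the device-synchronize idea (class and method names here are hypothetical, not the exact bitsandbytes backend API):

import torch

# Hypothetical sketch: a per-backend synchronize hook so callers never have to
# hardcode torch.cuda.synchronize(); names are illustrative only.
class Backend:
    def device_synchronize(self) -> None:
        raise NotImplementedError

class CUDABackend(Backend):
    def device_synchronize(self) -> None:
        torch.cuda.synchronize()

class XPUBackend(Backend):
    def device_synchronize(self) -> None:
        # torch.xpu becomes available after importing intel_extension_for_pytorch
        # (or natively in recent PyTorch builds with XPU support).
        torch.xpu.synchronize()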

@jiqing-feng @matthewdouglas @Titus-von-Koeller

@matthewdouglas added the Intel and Optimizers (issues or feature requests relating to optimizers) labels on Mar 18, 2025.
@jiqing-feng (Contributor) commented Mar 26, 2025:

Once I have verified it on IPEX 2.7, we can add XPU tests to test_optim.

@matthewdouglas (Member) commented:

Thanks!

Optimizer support isn't addressed yet in the new custom ops interface that we've mainlined, but we can keep developing it in this branch until that's ready.

Is there a plan to support any other optimizers? Completely understandable if not; just curious!


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +212 to +219
if out.dtype == torch.float16:
    ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
elif out.dtype == torch.bfloat16:
    ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
elif out.dtype == torch.float32:
    ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
@matthewdouglas (Member) commented Mar 27, 2025:


This will be useful when porting over to the new custom ops as an implementation for bitsandbytes::dequantize_blockwise.out(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()

@jiqing-feng (Contributor) replied:

Hard to understand. Could you please supply more details or instructions? Thanks!

@matthewdouglas (Member) replied:

@jiqing-feng
What I meant is that in the new interface we define a custom op for the 8-bit dynamic quantization used by the optimizers and the nested absmax. Since an optimized implementation of this exact op now exists in ipex.xpu, we can simply wrap it during our port.
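As a rough sketch of what such a wrap could look like for the bitsandbytes::dequantize_blockwise.out op quoted above (assuming the op has already been defined via torch.library.define; the exact names and registration mechanics in the mainlined interface may differ):

import torch
import intel_extension_for_pytorch as ipex

# Sketch only: registers the IPEX kernels as the XPU implementation of the
# already-defined custom op; arguments follow the op signature shown earlier.
@torch.library.impl("bitsandbytes::dequantize_blockwise.out", "XPU")
def _dequantize_blockwise_xpu(A, absmax, code, blocksize, dtype, out):
    if dtype == torch.float16:
        ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
    elif dtype == torch.bfloat16:
        ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
    elif dtype == torch.float32:
        ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {dtype}")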

@jiqing-feng (Contributor) replied:

> Is there a plan to support any other optimizers? Completely understandable if not; just curious!

Currently there is no plan to enable other optimizers.

@matthewdouglas merged commit 5c48b33 into bitsandbytes-foundation:multi-backend-refactor on Apr 15, 2025 (1 of 2 checks passed).
@Titus-von-Koeller (Collaborator) commented Apr 15, 2025:

The code looks good, thanks for your work on this!

@jiqing-feng @Liangliang-Ma

Please see this short update about the multi-backend refactor #1596.

Regarding the Intel backend, as discussed in parallel with Ke Ding, PRs migrating existing work from multi-backend-refactor should target the new bitsandbytes-intel repo rather than main.

However, some of the pure-torch ops and generic CPU functionality still make more sense in the main branch of bitsandbytes if they don't depend on Intel IPEX. Please align with @matthewdouglas and me on those; it's probably best to discuss that in our shared Slack channel.

@Titus-von-Koeller (Collaborator) commented:

@Liangliang-Ma I invited you to our bitsandbytes-intel Slack channel. Could you join there to discuss whether you're planning to support BNB's PagedOptimizers?

The paged memory feature is what we have in functional.py:get_paged(), which uses cudaMallocManaged under the hood.
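For context, a minimal ctypes sketch of the underlying CUDA unified-memory call (illustrative only; the actual get_paged() path goes through the bitsandbytes native library and may differ):

import ctypes

# Illustrative only: shows the cudaMallocManaged call that backs "paged" buffers.
cudart = ctypes.CDLL("libcudart.so")

def malloc_managed(num_bytes: int) -> ctypes.c_void_p:
    ptr = ctypes.c_void_p()
    flags = ctypes.c_uint(1)  # cudaMemAttachGlobal: accessible from host and any device
    err = cudart.cudaMallocManaged(ctypes.byref(ptr), ctypes.c_size_t(num_bytes), flags)
    if err != 0:
        raise RuntimeError(f"cudaMallocManaged failed with CUDA error {err}")
    return ptr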

@Liangliang-Ma (Author) commented Apr 24, 2025:

@Titus-von-Koeller Due to a change in my work assignments, I will not be doing related work in the near future; another colleague of mine will take over. Thanks for the invitation though :)
