Add support for bitnets to ORT WebGPU EP #25587
base: main
Conversation
// Step 1: Quantize input A using BitLinearQuantizeProgram
const uint32_t quantize_output_size = (M * (K_PADDED - (K_PADDED / kWeightsPerByte)) / 4);  // skipping every 5th, packed into u32
const uint32_t quantize_5th_output_size = M * K_PADDED / kQuantizationBlockSize;  // every 5th element packed into u32
 public:
  BitLinearQuantizeProgram(uint32_t k, uint32_t k_padded) : Program{"BitLinearQuantize"}, K_(k), K_PADDED_(k_padded) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;
I'm going to work on tests in the next couple of days. Sharing this PR now for early review of the operator shape and any other feedback.
Description
This change introduces support for BitNet models—specifically microsoft/bitnet-b1.58-2B-4T-bf16—by adding a new operator, BitLinear, which handles the unique weight format used by BitNets for matrix multiplication.
Converted ONNX model for testing purposes: https://huggingface.co/sushraja/bitnet-b1.58-2B-4T-fp16-onnx
Motivation and Context
BitNets significantly reduce memory usage due to their compact parameter representation, making them well-suited for client-side inference scenarios.
BitLinear Operator
BitNets encode matrix weights as ternary values (+1, 0, -1) along with a scale factor. These ternary values are represented in base-3 and packed such that 5 weights fit into a single uint8.
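As a minimal CPU-side sketch of this base-3 packing (the digit mapping weight + 1 → {0, 1, 2} is an assumption here; the model's actual encoding may order digits or map values differently):

```cpp
#include <array>
#include <cstdint>

// Pack five ternary weights (-1, 0, +1) into one uint8 as base-3 digits.
// Assumed digit mapping: weight + 1 -> {0, 1, 2}.
uint8_t PackTernary(const std::array<int8_t, 5>& w) {
  uint8_t packed = 0;
  for (int i = 4; i >= 0; --i) {
    packed = static_cast<uint8_t>(packed * 3 + (w[i] + 1));
  }
  return packed;  // max value is 3^5 - 1 = 242, so it fits in a byte
}

// Recover the five ternary weights from a packed byte.
std::array<int8_t, 5> UnpackTernary(uint8_t packed) {
  std::array<int8_t, 5> w{};
  for (int i = 0; i < 5; ++i) {
    w[i] = static_cast<int8_t>(packed % 3) - 1;
    packed /= 3;
  }
  return w;
}
```

Because 3^5 = 243 ≤ 256, this packing stores 5 weights per byte, versus 4 per byte for a naive 2-bit encoding.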
Inference Workflow
The inference process involves two main steps:
Step 1: Quantization of Input A
Step 2: Multiplication with Weights B
The quantized activations and decoded weights are then multiplied using the DP4A instruction, leveraging shared memory for an efficient cooperative matmul.
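A scalar reference of the two-step flow might look like the sketch below. It assumes symmetric absmax int8 quantization of the activations and a single weight scale, which are assumptions for illustration; the actual kernel runs on the GPU, feeding DP4A four int8 lanes at a time from shared-memory tiles.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference for one output element of BitLinear.
// b_ternary holds decoded weights in {-1, 0, +1}.
float BitLinearDot(const std::vector<float>& a,
                   const std::vector<int8_t>& b_ternary,
                   float weight_scale) {
  // Step 1: quantize input A to int8 with an absmax scale.
  float absmax = 0.0f;
  for (float v : a) absmax = std::max(absmax, std::fabs(v));
  const float a_scale = absmax > 0.0f ? absmax / 127.0f : 1.0f;

  // Step 2: integer dot product (what DP4A computes 4 lanes at a time).
  int32_t acc = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    const int8_t qa = static_cast<int8_t>(std::lround(a[i] / a_scale));
    acc += static_cast<int32_t>(qa) * b_ternary[i];
  }

  // Rescale the integer accumulator back to float.
  return static_cast<float>(acc) * a_scale * weight_scale;
}
```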
Key Notes
The BitLinear operator does not enforce a specific layout for B.
Weights are stored using ternary packing (5 weights per uint8), and decompression is handled dynamically at runtime.
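Because 3^5 = 243 fits in a byte, any one of the five weights can be decoded in place with a divide and a modulo, without unpacking the whole byte. A sketch, assuming the digit mapping weight + 1 → {0, 1, 2} (hypothetical names; the kernel's actual decode may differ):

```cpp
#include <cstdint>

// Extract weight n (0..4) from a packed ternary byte: divide by 3^n,
// take mod 3, and shift the digit range {0,1,2} back to {-1,0,+1}.
// Precomputed powers of 3 avoid any runtime pow().
int8_t ExtractWeight(uint8_t packed, int n) {
  static const uint8_t kPow3[5] = {1, 3, 9, 27, 81};
  return static_cast<int8_t>((packed / kPow3[n]) % 3) - 1;
}
```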