
Add support for bitnets to ORT WebGPU EP #25587


Open
wants to merge 4 commits into base: main
Conversation

@sushraja-msft (Contributor) commented Jul 30, 2025

Description

This change introduces support for BitNet models (specifically microsoft/bitnet-b1.58-2B-4T-bf16) by adding a new operator, BitLinear, which handles the unique weight format BitNets use for matrix multiplication.

A converted ONNX model for testing is available at https://huggingface.co/sushraja/bitnet-b1.58-2B-4T-fp16-onnx.

Motivation and Context

BitNets significantly reduce memory usage thanks to their compact parameter representation, making them well suited for client-side inference. For example, at 5 ternary weights per byte, the weights of a 2B-parameter model take roughly 0.4 GB, versus about 4 GB in fp16.

BitLinear Operator

BitNets encode matrix weights as ternary values (+1, 0, -1) along with a scale factor. The ternary values are treated as base-3 digits and packed so that 5 weights fit into a single uint8; since 3^5 = 243 ≤ 256, five ternary digits always fit in one byte.
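As a rough illustration, here is a minimal host-side sketch of this packing scheme in C++ (the helper names and digit order are illustrative assumptions, not the PR's actual code):

#include <array>
#include <cstdint>

// Pack five ternary weights (-1, 0, +1) into one byte as a base-3 number.
// Since 3^5 = 243 <= 256, the result always fits in a uint8_t.
uint8_t PackFiveTernary(const std::array<int8_t, 5>& w) {
  uint8_t packed = 0;
  for (int i = 0; i < 5; ++i) {
    packed = packed * 3 + static_cast<uint8_t>(w[i] + 1);  // map {-1,0,+1} -> {0,1,2}
  }
  return packed;
}

// Unpack one byte back into five ternary weights.
std::array<int8_t, 5> UnpackFiveTernary(uint8_t packed) {
  std::array<int8_t, 5> w{};
  for (int i = 4; i >= 0; --i) {
    w[i] = static_cast<int8_t>(packed % 3) - 1;  // map {0,1,2} -> {-1,0,+1}
    packed /= 3;
  }
  return w;
}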

Inference Workflow

The inference process involves two main steps:

Step 1: Quantization of Input A

  • Input tensor A (in fp16) is quantized to int8 with a single scale per token.
  • Four int8 values are packed into a u32.
  • Every 5th value of A is extracted and stored in a separate tensor (A5) to align with the BitNet weight packing.
  • Result: for every 20 values of A, you get:
    • a vec4 of u32s (4 × 4 packed int8 values)
    • one u32 holding the four 5th values
      (a host-side sketch of this layout follows this list)
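A host-side sketch of this quantization and layout, assuming a per-token absmax scale and that "every 5th value" means indices 4, 9, 14, ... (both assumptions; the kernel does this in WGSL on fp16 input, while float is used here for simplicity):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantize one token (one row of A, length K, K a multiple of 20) to int8
// with a single scale, pack 4 int8 values per u32, and route every 5th
// value into a separate stream to mirror the 5-per-byte weight packing.
void QuantizeToken(const std::vector<float>& a,
                   std::vector<uint32_t>& packed_main,   // 16 of every 20 values
                   std::vector<uint32_t>& packed_fifth,  // every 5th value (A5)
                   float& scale) {
  float absmax = 0.0f;
  for (float v : a) absmax = std::max(absmax, std::fabs(v));
  scale = (absmax == 0.0f) ? 1.0f : absmax / 127.0f;  // one scale per token

  std::vector<int8_t> main_vals, fifth_vals;
  for (size_t i = 0; i < a.size(); ++i) {
    int8_t q = static_cast<int8_t>(std::lround(a[i] / scale));
    (i % 5 == 4 ? fifth_vals : main_vals).push_back(q);
  }
  auto pack4 = [](const std::vector<int8_t>& v, std::vector<uint32_t>& out) {
    for (size_t i = 0; i < v.size(); i += 4) {
      uint32_t u = 0;
      for (size_t j = 0; j < 4; ++j)
        u |= static_cast<uint32_t>(static_cast<uint8_t>(v[i + j])) << (8 * j);
      out.push_back(u);
    }
  };
  pack4(main_vals, packed_main);    // one vec4 (4 u32s) per 20 inputs
  pack4(fifth_vals, packed_fifth);  // one u32 per 20 inputs
}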

Step 2: Multiplication with Weights B

  • Weights B are stored transposed as a stream of uint8s, each encoding 5 ternary weights.
  • Each uint8 is unpacked using a lookup table into:
    • A u32 containing 4 packed weights
    • 1 extra weight (the 5th), which is collected across 4 uint8s into a separate u32
  • This results in:
    • A vec4 from the packed weights
    • One u32 from the extra weights
  • These are then multiplied against the packed A values using the DP4A instruction, leveraging shared memory for an efficient cooperative matmul (a scalar sketch follows this list).
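A scalar C++ sketch of this decode-and-dot step (the real kernel does it in WGSL with the DP4A intrinsic and workgroup shared memory; the table layout and names below are assumptions):

#include <array>
#include <cstdint>

// Lookup table: for each byte value 0..242, 'packed4' holds the first four
// ternary weights as int8 lanes of a u32 and 'fifth' holds the remaining
// weight. (Illustrative layout; the PR's actual table may differ.)
struct LutEntry {
  uint32_t packed4;
  int8_t fifth;
};

std::array<LutEntry, 243> BuildLut() {
  std::array<LutEntry, 243> lut{};
  for (uint32_t b = 0; b < 243; ++b) {
    uint32_t v = b;
    int8_t w[5];
    for (int i = 4; i >= 0; --i) {  // decode five base-3 digits
      w[i] = static_cast<int8_t>(v % 3) - 1;
      v /= 3;
    }
    uint32_t p = 0;
    for (int i = 0; i < 4; ++i)
      p |= static_cast<uint32_t>(static_cast<uint8_t>(w[i])) << (8 * i);
    lut[b] = {p, w[4]};
  }
  return lut;
}

// Scalar stand-in for the DP4A instruction: dot product of four signed
// 8-bit lanes packed into each u32, accumulated as i32.
int32_t Dp4a(uint32_t a, uint32_t b) {
  int32_t acc = 0;
  for (int i = 0; i < 4; ++i) {
    acc += static_cast<int32_t>(static_cast<int8_t>((a >> (8 * i)) & 0xFF)) *
           static_cast<int32_t>(static_cast<int8_t>((b >> (8 * i)) & 0xFF));
  }
  return acc;
}

For every 20 elements of K, this amounts to four DP4A calls on the packed lanes plus one on the collected 5th values, with the A and B scales applied to the i32 accumulator at the end.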

Key Notes

The BitLinear operator does not enforce a specific layout for B.
Weights are stored using ternary packing (5 weights per uint8), and decompression is handled dynamically at runtime.

@sushraja-msft requested review from guschmue and qjia7 on July 30, 2025 01:51
@github-actions bot left a comment

You can commit the suggested changes from lintrunner.


// Step 1: Quantize input A using BitLinearQuantizeProgram
const uint32_t quantize_output_size = (M * (K_PADDED - (K_PADDED / kWeightsPerByte)) / 4); // skipping every 5th, packed into u32
const uint32_t quantize_5th_output_size = M * K_PADDED / kQuantizationBlockSize; // every 5th element packed into u32
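As a sanity check on these formulas (assuming kWeightsPerByte = 5 and kQuantizationBlockSize = 20, consistent with the 20-value grouping described above): for M = 1 and K_PADDED = 20, quantize_output_size = (20 - 20/5) / 4 = 4 u32s, i.e. the vec4 of packed values, and quantize_5th_output_size = 20 / 20 = 1 u32 holding the four 5th values, matching the per-20-value layout from the description.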
public:
BitLinearQuantizeProgram(uint32_t k, uint32_t k_padded) : Program{"BitLinearQuantize"}, K_(k), K_PADDED_(k_padded) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -0,0 +1,131 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning (Code scanning / lintrunner): CLANGFORMAT/format. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.

@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning (Code scanning / lintrunner): CLANGFORMAT/format. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.
@sushraja-msft (Contributor, Author) commented Jul 30, 2025

Going to work on tests in the next couple of days. Sharing this PR for early review of the operator shape and any feedback.

@guschmue added the ep:WebGPU (ort-web webgpu provider) label on Jul 30, 2025