-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[webgpu] Add support for bitnets to ORT WebGPU EP #25587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1bc39ec
3321b44
cfde2bb
b6f1cbb
7927b5d
fbb19c0
5a4fb16
9c72834
7086425
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "contrib_ops/webgpu/quantization/bitlinear.h" | ||
|
||
#include <string> | ||
#include "core/providers/cpu/math/matmul_helper.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_supported_types.h" | ||
#include "core/providers/webgpu/webgpu_utils.h" | ||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
// Registers the BitLinear kernel (domain kMSDomain, version 1) with the WebGPU execution provider. | ||
// Type constraint T1 accepts the float types returned by WebGpuSupportedFloatTypes(); | ||
// type constraint T2 is restricted to uint8 tensors. | ||
ONNX_OPERATOR_KERNEL_EX( | ||
BitLinear, | ||
kMSDomain, | ||
1, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.TypeConstraint("T1", WebGpuSupportedFloatTypes()) | ||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()), | ||
BitLinear); | ||
|
||
// Declares the bindings of the quantization pass and instantiates its WGSL template.
//
// Bindings, in declaration order: one input (`input_a`) followed by three outputs
// (`output`, `output_5th`, `scales`). The template receives K_ / 4 as `K4` and
// K_PADDED_ / 4 as `K_PADDED_4`.
Status BitLinearQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderUsage value_and_element_alias =
      ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias;

  // Input binding.
  shader.AddInput("input_a", value_and_element_alias);

  // Output bindings.
  shader.AddOutput("output", ShaderUsage::UseValueTypeAlias);
  shader.AddOutput("output_5th", ShaderUsage::UseValueTypeAlias);
  shader.AddOutput("scales", ShaderUsage::UseElementTypeAlias);

  return WGSL_TEMPLATE_APPLY(
      shader,
      "quantization/bitlinear_quantize.wgsl.template",
      WGSL_TEMPLATE_PARAMETER(K4, K_ / 4),
      WGSL_TEMPLATE_PARAMETER(K_PADDED_4, K_PADDED_ / 4));
}
|
||
// Declares the bindings of the tiled multiply pass and instantiates its WGSL template.
//
// Bindings, in declaration order: four inputs (`input_a`, `input_a5`, `scales_a`, `input_b`)
// followed by one output (`output`). The template takes no parameters.
Status BitLinearMultiplyProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderUsage value_and_element_alias =
      ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias;

  // Input bindings.
  shader.AddInput("input_a", value_and_element_alias);
  shader.AddInput("input_a5", ShaderUsage::UseValueTypeAlias);
  shader.AddInput("scales_a", ShaderUsage::UseElementTypeAlias);
  shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias);

  // Output binding.
  shader.AddOutput("output", value_and_element_alias);

  return WGSL_TEMPLATE_APPLY(shader, "quantization/bitlinear_multiply.wgsl.template");
}
|
||
// Declares the bindings of the single-row (M == 1) multiply pass and instantiates its WGSL template.
//
// Bindings, in declaration order: four inputs (`input_a`, `input_a5`, `scales_a`, `input_b`)
// followed by one output (`output`). `tile_size_` and `tile_size_k_` are forwarded to the
// template as `tile_size` and `tile_size_k`.
//
// Returns a failure Status when tile_size_k_ is zero, does not evenly divide the workgroup
// size in X, or is not a multiple of 4.
Status BitLinearMultiplySingleMProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Validate before touching the shader. The zero test must come first so the modulo below
  // can never divide by zero.
  ORT_RETURN_IF_NOT(tile_size_k_ != 0 && WorkgroupSizeX() % tile_size_k_ == 0 && tile_size_k_ % 4 == 0,
                    "tile_size_k_ must evenly divide workgroup size X and be divisible by 4");

  shader.AddInput("input_a", ShaderUsage::UseUniform);
  shader.AddInput("input_a5", ShaderUsage::UseUniform);
  shader.AddInput("scales_a", ShaderUsage::UseUniform);
  shader.AddInput("input_b", ShaderUsage::UseUniform);
  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

  // This algorithm processes K in chunks for efficient computation with BitLinear's ternary quantization.
  // Each workgroup handles one row of matrix A and tile_size rows of matrix B.
  // Uses the BitLinear-specific packing where 5 ternary weights are packed per uint8.
  return WGSL_TEMPLATE_APPLY(shader, "quantization/bitlinear_multiply_small_m.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
                             WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k_));
}
|
||
// Runs the BitLinear contrib op on the WebGPU execution provider.
//
// Input 0 (A) holds the activations. Input 1 (B) holds the packed uint8 weights: five weights
// per byte, with K padded up to a multiple of 20, i.e. shape [N, K_PADDED / 5].
// The work is issued as two GPU programs:
//   1. BitLinearQuantizeProgram reads A and writes `quantized_a`, `quantized_a5` and one
//      scale per row of A (`scales_a`).
//   2. A multiply program consumes those three tensors together with B and writes output 0.
//      BitLinearMultiplySingleMProgram is used when M == 1, BitLinearMultiplyProgram otherwise.
//
// Returns a failure Status when the input shapes are not supported or a GPU program fails.
Status BitLinear::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  constexpr uint32_t kQuantizationBlockSize = 20;  // weights are padded to a multiple of this
  constexpr uint32_t kWeightsPerByte = 5;          // ternary weights packed into one uint8
  constexpr uint32_t kVec4Components = 4;
  constexpr uint32_t kU32Components = 4;           // uint8 values viewed as one u32
  constexpr uint32_t kValuesPerVec4U32 = 16;       // quantized values held by one vec4<u32>
  // Per block of 20 values, 16 go to `quantized_a` (4 per u32) and 4 go to `quantized_a5` (one u32).
  constexpr uint32_t kU32PerBlock =
      (kQuantizationBlockSize - kQuantizationBlockSize / kWeightsPerByte) / 4;

  const Tensor* a = context.Input(0);
  const Tensor* b = context.Input(1);

  // Derive M/N/K and the output shape from A and the logical (unpacked) weight shape [N_, K_].
  TensorShape b_shape({N_, K_});
  MatMulComputeHelper helper;
  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
  auto* y = context.Output(0, helper.OutputShape());

  const uint32_t data_size = onnxruntime::narrow<uint32_t>(y->Shape().Size());
  if (data_size == 0) {
    return Status::OK();
  }

  const uint32_t M = onnxruntime::narrow<uint32_t>(helper.M());
  const uint32_t N = onnxruntime::narrow<uint32_t>(helper.N());
  const uint32_t K = onnxruntime::narrow<uint32_t>(helper.K());

  // Input A is bound with 4 components per value and the quantize shader is generated with K / 4.
  ORT_RETURN_IF_NOT(K % kVec4Components == 0,
                    "K must be divisible by ", kVec4Components, ". K: ", K);

  // Validate input B shape more specifically.
  // When K is not divisible by kQuantizationBlockSize, weights are padded to fit kQuantizationBlockSize.
  // During quantization of A we also pad the resulting output to K_PADDED to match the weights.
  const uint32_t K_PADDED = ((K + (kQuantizationBlockSize - 1)) / kQuantizationBlockSize) * kQuantizationBlockSize;
  const uint32_t blocks_per_row = K_PADDED / kQuantizationBlockSize;
  TensorShape expected_b_shape({N, K_PADDED / kWeightsPerByte});
  ORT_RETURN_IF_NOT(b->Shape() == expected_b_shape,
                    "Unexpected input B shape: ", b->Shape().ToString(),
                    ", expected: ", expected_b_shape.ToString());

  // Step 1: Quantize input A using BitLinearQuantizeProgram.
  // Sizes are expressed per block so the intermediate products stay as small as possible;
  // they equal M * (K_PADDED - K_PADDED / 5) / 4 and M * K_PADDED / 20 respectively.
  const uint32_t quantize_output_size = M * blocks_per_row * kU32PerBlock;  // skipping every 5th, packed into u32
  const uint32_t quantize_5th_output_size = M * blocks_per_row;             // every 5th element packed into u32

  TensorShape quantize_output_shape({quantize_output_size});
  TensorShape quantize_5th_output_shape({quantize_5th_output_size});
  TensorShape scales_output_shape({M});

  auto quantized_a = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), quantize_output_shape);
  auto quantized_a5 = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), quantize_5th_output_shape);
  auto scales_a = context.CreateGPUTensor(a->DataType(), scales_output_shape);

  {
    BitLinearQuantizeProgram quantize_program(K, K_PADDED);
    // TODO(review): one workgroup is dispatched per row of A, so generation (M == 1) runs a
    // single workgroup. A dedicated quantize shader for vector input may be more efficient.
    // NOTE(review): K is baked into the shader and therefore part of the cache key; a reviewer
    // suggested passing it as a uniform instead (the review remark is truncated) - confirm.
    quantize_program
        .SetWorkgroupSize(128)
        .SetDispatchGroupSize(M)
        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
        .AddOutputs({{&quantized_a, ProgramTensorMetadataDependency::None},
                     {&quantized_a5, ProgramTensorMetadataDependency::None},
                     {&scales_a, ProgramTensorMetadataDependency::None}})
        .CacheHint(K);
    ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
  }

  // Step 2: Matrix multiplication using appropriate program based on M size.
  if (M == 1) {
    // Use small M optimized program for generation mode.
    constexpr uint32_t tile_size_k = 32;
    constexpr uint32_t tile_size = 4;

    BitLinearMultiplySingleMProgram multiply_program(tile_size_k, tile_size);
    const uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
    multiply_program
        .SetWorkgroupSize(128)
        .SetDispatchGroupSize(num_N_tile)
        .AddInputs({{&quantized_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                    {&quantized_a5, ProgramTensorMetadataDependency::TypeAndRank},
                    {&scales_a, ProgramTensorMetadataDependency::TypeAndRank},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kU32Components)}})
        .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank}})
        .AddUniformVariables({{M}, {N}, {K_PADDED}, {blocks_per_row}, {scale_b_}, {num_N_tile}})
        .CacheHint(tile_size_k, tile_size);
    ORT_RETURN_IF_ERROR(context.RunProgram(multiply_program));
  } else {
    // Use original tiled program for larger batch sizes.
    BitLinearMultiplyProgram multiply_program;

    // input_a is vectorized as vec4<u32> which gives a packing of 16 elements per value.
    // Support for cases where (K_PADDED - (K_PADDED / 5)) is not divisible by 16, is not implemented.
    // K_PADDED is a multiple of 20 here, so the remaining count is 16 per block and this holds
    // by construction; the check guards against future changes to the padding rule.
    const uint32_t values_without_5th = K_PADDED - (K_PADDED / kWeightsPerByte);
    ORT_RETURN_IF_NOT(values_without_5th % kValuesPerVec4U32 == 0,
                      "K_PADDED must be divisible by 16 after skipping every 5th element. K_PADDED: ", K_PADDED);
    const uint32_t input_a_stride = values_without_5th / kValuesPerVec4U32;

    // The output is written as vec4 values with N / 4 entries per row, so N must be a multiple of 4.
    ORT_RETURN_IF_NOT(N % kVec4Components == 0,
                      "N must be divisible by ", kVec4Components, " when M > 1. N: ", N);

    constexpr uint32_t kTileSize = 64;
    TensorShape reshaped_y_shape{1, M, N / kVec4Components};
    const uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
    const uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
    multiply_program
        .SetWorkgroupSize(256)
        .SetDispatchGroupSize(num_M_tile * num_N_tile)
        .AddInputs({{&quantized_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                    {&quantized_a5, ProgramTensorMetadataDependency::TypeAndRank},
                    {&scales_a, ProgramTensorMetadataDependency::TypeAndRank},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kU32Components)}})
        .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)}})
        .AddUniformVariables({{M}, {N}, {K_PADDED}, {input_a_stride}, {scale_b_}, {num_N_tile}});
    ORT_RETURN_IF_ERROR(context.RunProgram(multiply_program));
  }

  return Status::OK();
}
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,76 @@ | ||||||||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||||||||
|
||||||||
// Licensed under the MIT License. | ||||||||
|
||||||||
#pragma once | ||||||||
|
||||||||
#include "core/providers/webgpu/program.h" | ||||||||
#include "core/providers/webgpu/webgpu_kernel.h" | ||||||||
|
||||||||
namespace onnxruntime { | ||||||||
namespace contrib { | ||||||||
namespace webgpu { | ||||||||
|
||||||||
using namespace onnxruntime::webgpu; | ||||||||
Check warning on line 13 in onnxruntime/contrib_ops/webgpu/quantization/bitlinear.h
|
||||||||
|
||||||||
// WebGPU program for the quantization pass of BitLinear (program name "BitLinearQuantize").
// Stores K and the padded K so GenerateShaderCode can use them when generating the shader.
class BitLinearQuantizeProgram final : public Program<BitLinearQuantizeProgram> {
 public:
  // `k` is the unpadded reduction length; `k_padded` is the padded length.
  BitLinearQuantizeProgram(uint32_t k, uint32_t k_padded)
      : Program{"BitLinearQuantize"},
        K_{k},
        K_PADDED_{k_padded} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

 private:
  uint32_t K_;
  uint32_t K_PADDED_;
};
|
||||||||
// WebGPU program for the tiled multiply pass of BitLinear (program name "BitLinearMultiply").
//
// Uniforms, in declaration order: M, N, K, InputAStride, scale_B (float) and num_N_tile.
// All are uint32 except scale_B.
class BitLinearMultiplyProgram final : public Program<BitLinearMultiplyProgram> {
 public:
  BitLinearMultiplyProgram()
      : Program{"BitLinearMultiply"} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"M", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32},
      {"InputAStride", ProgramUniformVariableDataType::Uint32},
      {"scale_B", ProgramUniformVariableDataType::Float32},
      {"num_N_tile", ProgramUniformVariableDataType::Uint32});
};
|
||||||||
// WebGPU program for the single-row multiply pass of BitLinear
// (program name "BitLinearMultiplySingleM").
//
// Uniforms, in declaration order: M, N, K, K20, scale_B (float) and num_N_tile.
// All are uint32 except scale_B.
class BitLinearMultiplySingleMProgram final : public Program<BitLinearMultiplySingleMProgram> {
 public:
  // `tile_size_k` and `tile_size` are stored for use during shader generation.
  BitLinearMultiplySingleMProgram(uint32_t tile_size_k, uint32_t tile_size)
      : Program{"BitLinearMultiplySingleM"},
        tile_size_k_{tile_size_k},
        tile_size_{tile_size} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"M", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32},
      {"K20", ProgramUniformVariableDataType::Uint32},
      {"scale_B", ProgramUniformVariableDataType::Float32},
      {"num_N_tile", ProgramUniformVariableDataType::Uint32});

 private:
  uint32_t tile_size_k_;
  uint32_t tile_size_;
};
|
||||||||
class BitLinear final : public WebGpuKernel { | ||||||||
public: | ||||||||
BitLinear(const OpKernelInfo& info) : WebGpuKernel(info) { | ||||||||
K_ = info.GetAttr<int64_t>("K"); | ||||||||
N_ = info.GetAttr<int64_t>("N"); | ||||||||
scale_b_ = info.GetAttr<float>("scale"); | ||||||||
} | ||||||||
|
||||||||
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; | ||||||||
|
||||||||
private: | ||||||||
int64_t K_; | ||||||||
int64_t N_; | ||||||||
float scale_b_ = 1.0f; | ||||||||
}; | ||||||||
|
||||||||
} // namespace webgpu | ||||||||
} // namespace contrib | ||||||||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Number of entries in dequantize_LUT. | ||
const lut_size : u32 = 86; | ||
|
||
// Statically generated lookup table that unpacks a uint8 into a 4xI8 vector packed in one u32. | ||
// Each byte of an entry is 0xFF (-1 as i8), 0x00 (0) or 0x01 (+1). | ||
// Since every 3 consecutive uint8 values map to the same entry, index with (value / 3). | ||
// NOTE(review): entries 81..85 repeat the patterns of entries 27 and 28; presumably they keep | ||
// uint8 values >= 243 inside the table - confirm against the code that packs the weights. | ||
const dequantize_LUT = array<u32, lut_size>( | ||
0xFFFFFFFF, | ||
0x00FFFFFF, | ||
0x01FFFFFF, | ||
0xFF00FFFF, | ||
0x0000FFFF, | ||
0x0100FFFF, | ||
0xFF01FFFF, | ||
0x0001FFFF, | ||
0x0101FFFF, | ||
0xFFFF00FF, | ||
0x00FF00FF, | ||
0x01FF00FF, | ||
0xFF0000FF, | ||
0x000000FF, | ||
0x010000FF, | ||
0xFF0100FF, | ||
0x000100FF, | ||
0x010100FF, | ||
0xFFFF01FF, | ||
0x00FF01FF, | ||
0x01FF01FF, | ||
0xFF0001FF, | ||
0x000001FF, | ||
0x010001FF, | ||
0xFF0101FF, | ||
0x000101FF, | ||
0x010101FF, | ||
0xFFFFFF00, | ||
0x00FFFF00, | ||
0x01FFFF00, | ||
0xFF00FF00, | ||
0x0000FF00, | ||
0x0100FF00, | ||
0xFF01FF00, | ||
0x0001FF00, | ||
0x0101FF00, | ||
0xFFFF0000, | ||
0x00FF0000, | ||
0x01FF0000, | ||
0xFF000000, | ||
0x00000000, | ||
0x01000000, | ||
0xFF010000, | ||
0x00010000, | ||
0x01010000, | ||
0xFFFF0100, | ||
0x00FF0100, | ||
0x01FF0100, | ||
0xFF000100, | ||
0x00000100, | ||
0x01000100, | ||
0xFF010100, | ||
0x00010100, | ||
0x01010100, | ||
0xFFFFFF01, | ||
0x00FFFF01, | ||
0x01FFFF01, | ||
0xFF00FF01, | ||
0x0000FF01, | ||
0x0100FF01, | ||
0xFF01FF01, | ||
0x0001FF01, | ||
0x0101FF01, | ||
0xFFFF0001, | ||
0x00FF0001, | ||
0x01FF0001, | ||
0xFF000001, | ||
0x00000001, | ||
0x01000001, | ||
0xFF010001, | ||
0x00010001, | ||
0x01010001, | ||
0xFFFF0101, | ||
0x00FF0101, | ||
0x01FF0101, | ||
0xFF000101, | ||
0x00000101, | ||
0x01000101, | ||
0xFF010101, | ||
0x00010101, | ||
0x01010101, | ||
0xFFFFFF00, | ||
0xFFFFFF00, | ||
0xFFFFFF00, | ||
0x00FFFF00, | ||
0x00FFFF00); |
Uh oh!
There was an error while loading. Please reload this page.