Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/bitlinear.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/webgpu/quantization/bitlinear.h"

#include <string>
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Registers the BitLinear kernel class as the implementation of the "BitLinear"
// operator (kMSDomain, version 1) for the WebGPU execution provider.
// Type constraints: T1 is limited to the WebGPU-supported float types and T2 to
// uint8 tensors.
// NOTE(review): presumably T1 is the activation/output type and T2 the packed
// weight type - confirm against the operator schema.
ONNX_OPERATOR_KERNEL_EX(
BitLinear,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", WebGpuSupportedFloatTypes())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
BitLinear);

// Generates the shader for the activation-quantization pass.
//
// Declares one input ("input_a") and three outputs ("output", "output_5th",
// "scales"), then instantiates bitlinear_quantize.wgsl.template with two
// template parameters: K4 = K_ / 4 and K_PADDED_4 = K_PADDED_ / 4 (both are
// unsigned integer divisions, so any remainder is dropped).
//
// NOTE(review): the order of the AddInput/AddOutput calls presumably has to
// match the order of AddInputs/AddOutputs at the call site - confirm before
// reordering.
Status BitLinearQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddOutput("output", ShaderUsage::UseValueTypeAlias);
shader.AddOutput("output_5th", ShaderUsage::UseValueTypeAlias);
shader.AddOutput("scales", ShaderUsage::UseElementTypeAlias);

// K_ and K_PADDED_ are baked into the generated shader text as template parameters.
return WGSL_TEMPLATE_APPLY(shader, "quantization/bitlinear_quantize.wgsl.template",
WGSL_TEMPLATE_PARAMETER(K4, K_ / 4),
WGSL_TEMPLATE_PARAMETER(K_PADDED_4, K_PADDED_ / 4));
}

// Generates the shader for the tiled BitLinear multiply.
//
// Declares four inputs ("input_a", "input_a5", "scales_a", "input_b") and one
// output ("output"), then instantiates bitlinear_multiply.wgsl.template, which
// takes no template parameters.
//
// NOTE(review): the declaration order presumably has to match the order of
// AddInputs/AddOutputs at the call site - confirm before reordering.
Status BitLinearMultiplyProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("input_a5", ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales_a", ShaderUsage::UseElementTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias);
shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

return WGSL_TEMPLATE_APPLY(shader, "quantization/bitlinear_multiply.wgsl.template");
}

// Generates the shader for the single-row (M == 1) BitLinear multiply.
//
// Declares four inputs ("input_a", "input_a5", "scales_a", "input_b") and one
// output ("output"), then instantiates bitlinear_multiply_small_m.wgsl.template
// with the template parameters tile_size = tile_size_ and
// tile_size_k = tile_size_k_.
//
// Fails via ORT_ENFORCE unless tile_size_k_ evenly divides the workgroup size
// in X and is itself a multiple of 4.
Status BitLinearMultiplySingleMProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform);
shader.AddInput("input_a5", ShaderUsage::UseUniform);
shader.AddInput("scales_a", ShaderUsage::UseUniform);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

// Reject tile configurations the template is not written for before generating any code.
ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_ == 0 && tile_size_k_ % 4 == 0,
"tile_size_k_ must evenly divide workgroup size X and be divisible by 4");
// This algorithm processes K in chunks for efficient computation with BitLinear's ternary quantization.
// Each workgroup handles one row of matrix A and tile_size rows of matrix B.
// Uses the BitLinear-specific packing where 5 ternary weights are packed per uint8.
return WGSL_TEMPLATE_APPLY(shader, "quantization/bitlinear_multiply_small_m.wgsl.template",
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k_));
}

// Runs the BitLinear operator on WebGPU.
//
// Inputs:  0 = A (activations), 1 = B (packed weights, 5 ternary weights per uint8).
// Output:  0 = Y, shaped by MatMulComputeHelper from A and the logical [N_, K_] weight shape.
//
// The computation is done in two GPU passes:
//   1. BitLinearQuantizeProgram quantizes A into two packed u32 buffers (every 5th element is
//      stored separately from the other four) plus one scale per row of A.
//   2. A multiply program consumes the quantized A, the scales and B and writes Y.
//      BitLinearMultiplySingleMProgram is used when M == 1, BitLinearMultiplyProgram otherwise.
//
// Returns a non-OK Status if shape inference or a program run fails. Throws (via ORT_ENFORCE /
// narrow) if K is not a multiple of 4, if B does not have the expected packed shape, if N is
// not a multiple of 4 on the tiled path, or if an intermediate buffer size does not fit uint32.
Status BitLinear::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* a = context.Input(0);
  const Tensor* b = context.Input(1);

  // Validate input shapes. B is described by its logical (unpacked) shape [N_, K_] taken from
  // the operator attributes, not by the shape of the packed tensor.
  TensorShape b_shape({N_, K_});
  MatMulComputeHelper helper;
  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
  auto* y = context.Output(0, helper.OutputShape());

  const uint32_t data_size = onnxruntime::narrow<uint32_t>(y->Shape().Size());
  if (data_size == 0) {
    return Status::OK();
  }

  const uint32_t M = onnxruntime::narrow<uint32_t>(helper.M());
  const uint32_t N = onnxruntime::narrow<uint32_t>(helper.N());
  const uint32_t K = onnxruntime::narrow<uint32_t>(helper.K());

  constexpr uint32_t kQuantizationBlockSize = 20;
  constexpr uint32_t kWeightsPerByte = 5;
  constexpr uint32_t kVec4Components = 4;
  constexpr uint32_t kU32Components = 4;

  // Input A is bound with kVec4Components components and the quantize shader is generated with
  // K / 4, so a K that is not a multiple of 4 cannot be addressed correctly.
  ORT_ENFORCE(K % kVec4Components == 0,
              "K must be divisible by 4 because input A is read as vec4 values. K: ", K);

  // Validate input B shape more specifically.
  // When K is not divisible by kQuantizationBlockSize, weights are padded to fit kQuantizationBlockSize.
  // During quantization of A we also pad the resulting output to K_PADDED to match the weights.
  const uint32_t K_PADDED = ((K + (kQuantizationBlockSize - 1)) / kQuantizationBlockSize) * kQuantizationBlockSize;
  TensorShape expected_b_shape({N, K_PADDED / kWeightsPerByte});
  ORT_ENFORCE(b->Shape() == expected_b_shape,
              "Unexpected input B shape: ", b->Shape().ToString(),
              ", expected: ", expected_b_shape.ToString());

  // Step 1: Quantize input A using BitLinearQuantizeProgram.
  // K_PADDED is a multiple of 20, so both per-row counts below divide exactly. The products
  // with M are formed in 64 bits and narrowed so that an oversized problem fails loudly
  // instead of silently wrapping and under-allocating the buffers.
  const uint32_t blocks_per_row = K_PADDED / kQuantizationBlockSize;
  const uint32_t values_per_row_without_5th = K_PADDED - (K_PADDED / kWeightsPerByte);
  // Skipping every 5th element, packed into u32 (4 values per u32).
  const uint32_t quantize_output_size =
      onnxruntime::narrow<uint32_t>(static_cast<uint64_t>(M) * values_per_row_without_5th / 4);
  // Every 5th element, packed into u32.
  const uint32_t quantize_5th_output_size =
      onnxruntime::narrow<uint32_t>(static_cast<uint64_t>(M) * blocks_per_row);

  TensorShape quantize_output_shape({quantize_output_size});
  TensorShape quantize_5th_output_shape({quantize_5th_output_size});
  TensorShape scales_output_shape({M});

  auto quantized_a = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), quantize_output_shape);
  auto quantized_a5 = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), quantize_5th_output_shape);
  auto scales_a = context.CreateGPUTensor(a->DataType(), scales_output_shape);

  {
    // TODO(review): for generation (M == 1) this dispatches a single workgroup, which is not
    // efficient; consider a separate quantize shader that processes a vector input.
    // TODO(review): consider binding quantized_a with kVec4Components here as well.
    // TODO(review): consider passing K as a uniform instead of a template parameter / cache
    // hint to reduce the number of shader variants.
    BitLinearQuantizeProgram quantize_program(K, K_PADDED);
    quantize_program
        .SetWorkgroupSize(128)
        .SetDispatchGroupSize(M)
        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
        .AddOutputs({{&quantized_a, ProgramTensorMetadataDependency::None},
                     {&quantized_a5, ProgramTensorMetadataDependency::None},
                     {&scales_a, ProgramTensorMetadataDependency::None}})
        .CacheHint(K);
    ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
  }

  // Step 2: Matrix multiplication using the appropriate program based on M.
  if (M == 1) {
    // Use small M optimized program for generation mode.
    const uint32_t tile_size_k = 32;
    const uint32_t tile_size = 4;

    BitLinearMultiplySingleMProgram multiply_program(tile_size_k, tile_size);
    const uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
    multiply_program
        .SetWorkgroupSize(128)
        .SetDispatchGroupSize(num_N_tile)
        .AddInputs({{&quantized_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                    {&quantized_a5, ProgramTensorMetadataDependency::TypeAndRank},
                    {&scales_a, ProgramTensorMetadataDependency::TypeAndRank},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kU32Components)}})
        .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank}})
        // Uniform order must match WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES: M, N, K, K20, scale_B, num_N_tile.
        .AddUniformVariables({{M}, {N}, {K_PADDED}, {blocks_per_row}, {scale_b_}, {num_N_tile}})
        .CacheHint(tile_size_k, tile_size);
    ORT_RETURN_IF_ERROR(context.RunProgram(multiply_program));
  } else {
    // Use original tiled program for larger batch sizes.
    BitLinearMultiplyProgram multiply_program;
    // input_a is vectorized as vec4<u32> which gives a packing of 16 elements per value.
    // Support for cases where (K_PADDED - (K_PADDED / 5)) is not divisible by 16, is not implemented.
    ORT_ENFORCE(values_per_row_without_5th % 16 == 0, "K_PADDED must be divisible by 16 after skipping every 5th element. K_PADDED: ", K_PADDED);
    const uint32_t input_a_stride = values_per_row_without_5th / 16;
    // The output is written as vec4 values with a row length of N / 4, so N has to be a
    // multiple of 4 for the reshaped view to cover every output element.
    ORT_ENFORCE(N % kVec4Components == 0,
                "N must be divisible by 4 because the output is written as vec4 values. N: ", N);
    constexpr uint32_t kTileSize = 64;
    TensorShape reshaped_y_shape{1, M, N / kVec4Components};
    const uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
    const uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
    multiply_program
        .SetWorkgroupSize(256)
        .SetDispatchGroupSize(num_M_tile * num_N_tile)
        .AddInputs({{&quantized_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
                    {&quantized_a5, ProgramTensorMetadataDependency::TypeAndRank},
                    {&scales_a, ProgramTensorMetadataDependency::TypeAndRank},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kU32Components)}})
        .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)}})
        // Uniform order must match WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES: M, N, K, InputAStride, scale_B, num_N_tile.
        .AddUniformVariables({{M}, {N}, {K_PADDED}, {input_a_stride}, {scale_b_}, {num_N_tile}});
    ORT_RETURN_IF_ERROR(context.RunProgram(multiply_program));
  }

  return Status::OK();
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
76 changes: 76 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/bitlinear.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

Check warning on line 13 in onnxruntime/contrib_ops/webgpu/quantization/bitlinear.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/bitlinear.h:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

// WebGPU program that quantizes the activation matrix A ahead of the BitLinear multiply.
//
// k and k_padded are stored and later baked into the generated shader as template
// parameters (see GenerateShaderCode), so each distinct pair produces its own shader text.
class BitLinearQuantizeProgram final : public Program<BitLinearQuantizeProgram> {
 public:
  // k: reduction dimension of A; k_padded: k rounded up by the caller to the weight padding.
  BitLinearQuantizeProgram(uint32_t k, uint32_t k_padded) : Program{"BitLinearQuantize"}, K_(k), K_PADDED_(k_padded) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

 private:
  uint32_t K_;
  uint32_t K_PADDED_;
};

// WebGPU program for the tiled BitLinear multiply; BitLinear::ComputeInternal selects it
// when M != 1.
//
// Uniforms, in the order they must be supplied to AddUniformVariables:
//   M, N, K, InputAStride (all Uint32), scale_B (Float32), num_N_tile (Uint32).
// The caller passes the padded K for "K" and the weight scale attribute for "scale_B".
class BitLinearMultiplyProgram final : public Program<BitLinearMultiplyProgram> {
public:
BitLinearMultiplyProgram() : Program{"BitLinearMultiply"} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"InputAStride", ProgramUniformVariableDataType::Uint32},
{"scale_B", ProgramUniformVariableDataType::Float32},
{"num_N_tile", ProgramUniformVariableDataType::Uint32});
};

// WebGPU program for the BitLinear multiply when A has a single row (M == 1).
//
// tile_size_k and tile_size are stored and forwarded to the shader template as the
// template parameters of the same names (see GenerateShaderCode).
//
// Uniforms, in the order they must be supplied to AddUniformVariables:
//   M, N, K, K20 (all Uint32), scale_B (Float32), num_N_tile (Uint32).
// The caller passes the padded K for "K" and padded K / 20 for "K20".
class BitLinearMultiplySingleMProgram final : public Program<BitLinearMultiplySingleMProgram> {
public:
BitLinearMultiplySingleMProgram(uint32_t tile_size_k, uint32_t tile_size) : Program{"BitLinearMultiplySingleM"},
tile_size_k_(tile_size_k),
tile_size_(tile_size) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K20", ProgramUniformVariableDataType::Uint32},
{"scale_B", ProgramUniformVariableDataType::Float32},
{"num_N_tile", ProgramUniformVariableDataType::Uint32});

private:
// Must evenly divide the workgroup size in X and be a multiple of 4 (enforced in GenerateShaderCode).
uint32_t tile_size_k_;
uint32_t tile_size_;
};

// WebGPU kernel for the BitLinear contrib operator.
//
// Reads three attributes at construction time: "K" and "N" (int64), which give the logical
// shape [N, K] of the weight matrix, and "scale" (float), which is forwarded to the multiply
// shaders as the scale_B uniform.
class BitLinear final : public WebGpuKernel {
 public:
  // explicit: an OpKernelInfo must never convert implicitly into a kernel instance.
  explicit BitLinear(const OpKernelInfo& info) : WebGpuKernel(info) {
    K_ = info.GetAttr<int64_t>("K");
    N_ = info.GetAttr<int64_t>("N");
    scale_b_ = info.GetAttr<float>("scale");
  }

  Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;

 private:
  int64_t K_ = 0;
  int64_t N_ = 0;
  float scale_b_ = 1.0f;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Number of entries in dequantize_LUT.
const lut_size : u32 = 86;

// Statically generated lookup table that unpacks a uint8 into a 4 x i8 vector held in one u32.
// Every byte of every entry is 0xFF, 0x00 or 0x01, i.e. -1, 0 or +1 when read as i8.
// Since every 3 consecutive values are the same, index with /3.
// Entries 0..80 enumerate all 3^4 = 81 combinations of four ternary values; entries 81..85
// repeat earlier patterns.
// NOTE(review): the five trailing entries presumably cover packed byte values 243..255 -
// confirm against the weight packing code.
const dequantize_LUT = array<u32, lut_size>(
0xFFFFFFFF,
0x00FFFFFF,
0x01FFFFFF,
0xFF00FFFF,
0x0000FFFF,
0x0100FFFF,
0xFF01FFFF,
0x0001FFFF,
0x0101FFFF,
0xFFFF00FF,
0x00FF00FF,
0x01FF00FF,
0xFF0000FF,
0x000000FF,
0x010000FF,
0xFF0100FF,
0x000100FF,
0x010100FF,
0xFFFF01FF,
0x00FF01FF,
0x01FF01FF,
0xFF0001FF,
0x000001FF,
0x010001FF,
0xFF0101FF,
0x000101FF,
0x010101FF,
0xFFFFFF00,
0x00FFFF00,
0x01FFFF00,
0xFF00FF00,
0x0000FF00,
0x0100FF00,
0xFF01FF00,
0x0001FF00,
0x0101FF00,
0xFFFF0000,
0x00FF0000,
0x01FF0000,
0xFF000000,
0x00000000,
0x01000000,
0xFF010000,
0x00010000,
0x01010000,
0xFFFF0100,
0x00FF0100,
0x01FF0100,
0xFF000100,
0x00000100,
0x01000100,
0xFF010100,
0x00010100,
0x01010100,
0xFFFFFF01,
0x00FFFF01,
0x01FFFF01,
0xFF00FF01,
0x0000FF01,
0x0100FF01,
0xFF01FF01,
0x0001FF01,
0x0101FF01,
0xFFFF0001,
0x00FF0001,
0x01FF0001,
0xFF000001,
0x00000001,
0x01000001,
0xFF010001,
0x00010001,
0x01010001,
0xFFFF0101,
0x00FF0101,
0x01FF0101,
0xFF000101,
0x00000101,
0x01000101,
0xFF010101,
0x00010101,
0x01010101,
0xFFFFFF00,
0xFFFFFF00,
0xFFFFFF00,
0x00FFFF00,
0x00FFFF00);
Loading
Loading