Skip to content

Commit 271c509

Browse files
DP4AMatMul perf refinements (microsoft#23539)
In this change: 1. Vectorization of k is updated to 4. 2. Tile_A and Tile_B are stored transposed in shared memory. This improves memory locality for our access pattern. 3. Lane output is switched to individual vectors and its loop is unrolled; this solves the problem where lane_output was not in registers before. Perf improvements are not very consistent with this change. On a Tigerlake GPU with 32.0.101.6460 (latest Intel drivers): ``` Baseline model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 7.36557e+06 <<<< avg (tokens/s): 135.903 p50 (us): 7.35498e+06 stddev (us): 27599 n: 5 * 1001 token(s) With Change model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 6.52302e+06 <<<< avg (tokens/s): 153.457 p50 (us): 6.52224e+06 stddev (us): 10407.3 n: 5 * 1001 token(s) ``` However, comparing the before and after profiles in Intel GPA, one can clearly see straight runs of ALU work without being interspersed with writebacks to the local memory that contained lane_output before. ![image](https://github.com/user-attachments/assets/e01d3474-8406-4a61-b352-2ecbf0855a7f)
1 parent cb69c59 commit 271c509

File tree

1 file changed

+107
-63
lines changed

1 file changed

+107
-63
lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

+107-63
Original file line numberDiff line numberDiff line change
@@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
613613
const tile_size_k = 32;
614614
const vec_factor = 4;
615615
const u32_factor = 4;
616-
const tile_size_k_vec = 4;
616+
const tile_size_k_vec = 2;
617617
const block_size = 32;
618618
619619
// Shared memory
620-
var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
621-
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
622-
var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
623-
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
624-
625-
// Private memory
626-
var<private> lane_output: array<output_element_t, 16>;
620+
var<workgroup> tile_A : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
621+
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
622+
var<workgroup> tile_B : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
623+
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
627624
628625
fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
629626
{
@@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
632629
{
633630
return;
634631
}
635-
tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
632+
tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
636633
if (col == 0)
637634
{
638-
// kidx_v - covers 8 values of k
639-
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
635+
// kidx_v - covers 16 values of k
636+
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
640637
}
641638
}
642639
@@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
648645
return;
649646
}
650647
651-
let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
652-
var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
653-
var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
654-
tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
655-
tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
648+
let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
649+
var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
650+
var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
651+
tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
652+
tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
653+
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
654+
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
655+
tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
656+
tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
656657
if (col == 0)
657658
{
658-
// kidx_v - each kidx_v covers 8 values of k
659-
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
659+
// kidx_v - each kidx_v covers 16 values of k
660+
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2];
660661
}
661662
}
662663
663-
fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
664+
// Scaled dot product of 8 packed unsigned integers.
665+
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
664666
{
665-
var local_sum = dot4I8Packed(a[0], b[0]);
666-
local_sum += dot4I8Packed(a[1], b[1]);
667-
local_sum += dot4I8Packed(a[2], b[2]);
668-
local_sum += dot4I8Packed(a[3], b[3]);
669-
return local_sum;
667+
var local_sum = dot4I8Packed(a1[0], b1[0]);
668+
local_sum += dot4I8Packed(a1[1], b1[1]);
669+
local_sum += dot4I8Packed(a1[2], b1[2]);
670+
local_sum += dot4I8Packed(a1[3], b1[3]);
671+
local_sum += dot4I8Packed(a2[0], b2[0]);
672+
local_sum += dot4I8Packed(a2[1], b2[1]);
673+
local_sum += dot4I8Packed(a2[2], b2[2]);
674+
local_sum += dot4I8Packed(a2[3], b2[3]);
675+
return output_element_t(local_sum) * scale;
670676
}
671-
672677
)ADDNL_FN";
673678

674679
shader.MainFunctionBody() << R"MAIN_FN(
675680
// During the load phase we use all 256 threads to load 64 rows of A/B.
676-
// For each row we load 4 vectorized elements, which are 32 elements of K.
681+
// For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
677682
let a_global_base = workgroup_id.x * tile_size;
678683
let b_global_base = workgroup_id.y * tile_size;
679-
let load_row = u32(local_idx/4);
680-
let load_col = u32(local_idx%4);
684+
let load_AorB = u32(local_idx/128);
685+
let load_row = u32((local_idx%128)/2);
686+
let load_col = u32(local_idx%2);
681687
682688
// During the compute phase, we have the 64x64 tile split into
683689
// subtiles of 16x16. We have a grid of 4x4 subtiles.
@@ -689,42 +695,81 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
689695
// For each subtile we have 16 threads assigned.
690696
let a_idx = u32(local_idx % subtile_size);
691697
692-
// K's vectrorization is 8 items per index. See input_a/input_b.
693-
// tile_size_k_vec - is the k tile size in vectorized k units/space (1/8).
694-
for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
698+
var lane_output1: vec4<output_element_t>;
699+
var lane_output2: vec4<output_element_t>;
700+
var lane_output3: vec4<output_element_t>;
701+
var lane_output4: vec4<output_element_t>;
702+
// K's vectrorization is 16 items per index. See input_a/input_b.
703+
// tile_size_k_vec - is the k tile size in vectorized space (1/16). That is
704+
// k tile size is 32. In vectorized space that is 32/16 = 2.
705+
for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec)
695706
{
696-
// Populate shared memory for the workgroup
697-
loadSHMA(a_global_base, kidx_v, load_row, load_col);
698-
loadSHMB(b_global_base, kidx_v, load_row, load_col);
707+
// Load Phase: Populate shared memory for the workgroup.
708+
if (load_AorB == 0)
709+
{
710+
loadSHMA(a_global_base, kidx_v, load_row, load_col);
711+
}
712+
else
713+
{
714+
loadSHMB(b_global_base, kidx_v, load_row, load_col);
715+
}
699716
workgroupBarrier();
700717
701-
var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
702-
var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
703-
var own_scale_a = scale_A[base_A + a_idx];
718+
// Compute phase: Perform matmul for this subtile 16 x 32 x 16.
719+
// Step 1: Load from shared memory into registers across entire subgroup.
720+
var own_a0: vec4<u32> = tile_A[0][base_A + a_idx];
721+
var own_a1: vec4<u32> = tile_A[1][base_A + a_idx];
722+
var own_scale_a: output_element_t = scale_A[base_A + a_idx];
704723
if (sg_size == 16)
705724
{
706-
var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
707-
var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
708-
var own_scale_b = scale_B[base_B + sg_id];
709-
for (var col:u32 = 0; col < 16; col++)
710-
{
711-
var local_scale_b = subgroupShuffle(own_scale_b, col);
712-
local_scale_b = local_scale_b * own_scale_a;
713-
var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
714-
local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
715-
lane_output[col] += (output_element_t(local_sum) * local_scale_b);
716-
}
725+
var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
726+
var own_b1: vec4<u32> = tile_B[1][base_B + sg_id];
727+
var own_scale_b: output_element_t = scale_B[base_B + sg_id];
728+
// Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
729+
lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a);
730+
lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a);
731+
lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a);
732+
lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a);
733+
734+
lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a);
735+
lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a);
736+
lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a);
737+
lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a);
738+
739+
lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a);
740+
lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a);
741+
lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a);
742+
lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a);
743+
744+
lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a);
745+
lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a);
746+
lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a);
747+
lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a);
717748
}
718749
else
719750
{
720-
for (var col:u32 = 0; col < 16; col++)
721-
{
722-
var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
723-
var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
724-
var local_sum = DP4AI(own_a0, b0);
725-
local_sum += DP4AI(own_a1, b1);
726-
lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
727-
}
751+
// Code for other subgroup sizes, simply doesnt use subgroups at all.
752+
// Relies on reads from single location tile_B[][base_B + col] by all
753+
// being optimized by the hardware.
754+
lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]);
755+
lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]);
756+
lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]);
757+
lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]);
758+
759+
lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]);
760+
lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]);
761+
lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]);
762+
lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]);
763+
764+
lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]);
765+
lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]);
766+
lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]);
767+
lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]);
768+
769+
lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]);
770+
lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]);
771+
lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
772+
lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
728773
}
729774
workgroupBarrier();
730775
}
@@ -735,11 +780,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
735780
// This creates a shader requirement that uniforms.N % 16 == 0
736781
if (a_global < uniforms.M && b_global < uniforms.N)
737782
{
738-
for (var i:u32 = 0; i < 4; i++)
739-
{
740-
let lidx = i * 4;
741-
output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1] , lane_output[lidx+2], lane_output[lidx+3]);
742-
}
783+
output[output_idx] = lane_output1;
784+
output[output_idx+1] = lane_output2;
785+
output[output_idx+2] = lane_output3;
786+
output[output_idx+3] = lane_output4;
743787
}
744788
)MAIN_FN";
745789

@@ -812,9 +856,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
812856
mul_program.SetDispatchGroupSize(
813857
(M + kTileSize - 1) / kTileSize,
814858
(N + kTileSize - 1) / kTileSize, 1);
815-
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
859+
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
816860
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
817-
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
861+
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
818862
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
819863
.AddUniformVariables({{static_cast<uint32_t>(M)},
820864
{static_cast<uint32_t>(N)},

0 commit comments

Comments
 (0)