@@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
const tile_size_k = 32;
const vec_factor = 4;
const u32_factor = 4;
- const tile_size_k_vec = 4;
+ const tile_size_k_vec = 2;
const block_size = 32;

// Shared memory
- var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
- var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
- var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
- var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
-
- // Private memory
- var<private> lane_output: array<output_element_t, 16>;
+ var<workgroup> tile_A : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
+ var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
+ var<workgroup> tile_B : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
+ var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1

fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
@@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
{
return;
}
- tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
+ tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
if (col == 0)
{
- // kidx_v - covers 8 values of k
- scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
+ // kidx_v - covers 16 values of k
+ scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
}
}

@@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return;
}

- let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
- var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
- var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
- tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
- tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+ let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
+ var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+ var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+ tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+ tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+ b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+ b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+ tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+ tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
if (col == 0)
{
- // kidx_v - each kidx_v covers 8 values of k
- scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
+ // kidx_v - each kidx_v covers 16 values of k
+ scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2];
}
}

- fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
+ // Scaled dot product of 8 packed unsigned integers.
+ fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
- var local_sum = dot4I8Packed(a[0], b[0]);
- local_sum += dot4I8Packed(a[1], b[1]);
- local_sum += dot4I8Packed(a[2], b[2]);
- local_sum += dot4I8Packed(a[3], b[3]);
- return local_sum;
+ var local_sum = dot4I8Packed(a1[0], b1[0]);
+ local_sum += dot4I8Packed(a1[1], b1[1]);
+ local_sum += dot4I8Packed(a1[2], b1[2]);
+ local_sum += dot4I8Packed(a1[3], b1[3]);
+ local_sum += dot4I8Packed(a2[0], b2[0]);
+ local_sum += dot4I8Packed(a2[1], b2[1]);
+ local_sum += dot4I8Packed(a2[2], b2[2]);
+ local_sum += dot4I8Packed(a2[3], b2[3]);
+ return output_element_t(local_sum) * scale;
}
-
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
// During the load phase we use all 256 threads to load 64 rows of A/B.
- // For each row we load 4 vectorized elements, which are 32 elements of K.
+ // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
let a_global_base = workgroup_id.x * tile_size;
let b_global_base = workgroup_id.y * tile_size;
- let load_row = u32(local_idx/4);
- let load_col = u32(local_idx%4);
+ let load_AorB = u32(local_idx/128);
+ let load_row = u32((local_idx%128)/2);
+ let load_col = u32(local_idx%2);

// During the compute phase, we have the 64x64 tile split into
// subtiles of 16x16. We have a grid of 4x4 subtiles.
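For intuition on the new load-phase indexing: with 256 threads per workgroup, threads 0-127 fill tile_A and threads 128-255 fill tile_B, two vectorized columns per row. A minimal WGSL sketch of the decomposition (the helper function is hypothetical; only the arithmetic mirrors the generated shader):

// Hypothetical helper, only to show how local_idx decomposes during the load phase.
fn load_mapping(local_idx : u32) -> vec3<u32> {
  let load_AorB = local_idx / 128u;        // 0 loads tile_A, 1 loads tile_B
  let load_row  = (local_idx % 128u) / 2u; // one of 64 tile rows
  let load_col  = local_idx % 2u;          // one of tile_size_k_vec (2) vector columns
  return vec3<u32>(load_AorB, load_row, load_col);
}
// e.g. local_idx = 5   -> (0, 2, 1):  loads tile_A row 2, vector column 1
//      local_idx = 200 -> (1, 36, 0): loads tile_B row 36, vector column 0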
@@ -689,42 +695,81 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// For each subtile we have 16 threads assigned.
let a_idx = u32(local_idx % subtile_size);

- // K's vectrorization is 8 items per index. See input_a/input_b.
- // tile_size_k_vec - is the k tile size in vectorized k units/space (1/8).
- for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
+ var lane_output1: vec4<output_element_t>;
+ var lane_output2: vec4<output_element_t>;
+ var lane_output3: vec4<output_element_t>;
+ var lane_output4: vec4<output_element_t>;
+ // K's vectorization is 16 items per index. See input_a/input_b.
+ // tile_size_k_vec is the k tile size in vectorized space (1/16). That is,
+ // the k tile size is 32; in vectorized space that is 32/16 = 2.
+ for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec)
{
- // Populate shared memory for the workgroup
- loadSHMA(a_global_base, kidx_v, load_row, load_col);
- loadSHMB(b_global_base, kidx_v, load_row, load_col);
+ // Load Phase: Populate shared memory for the workgroup.
+ if (load_AorB == 0)
+ {
+ loadSHMA(a_global_base, kidx_v, load_row, load_col);
+ }
+ else
+ {
+ loadSHMB(b_global_base, kidx_v, load_row, load_col);
+ }
workgroupBarrier();

- var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
- var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
- var own_scale_a = scale_A[base_A + a_idx];
+ // Compute phase: Perform matmul for this subtile 16 x 32 x 16.
+ // Step 1: Load from shared memory into registers across entire subgroup.
+ var own_a0: vec4<u32> = tile_A[0][base_A + a_idx];
+ var own_a1: vec4<u32> = tile_A[1][base_A + a_idx];
+ var own_scale_a: output_element_t = scale_A[base_A + a_idx];
if (sg_size == 16)
{
- var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
- var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
- var own_scale_b = scale_B[base_B + sg_id];
- for (var col:u32 = 0; col < 16; col++)
- {
- var local_scale_b = subgroupShuffle(own_scale_b, col);
- local_scale_b = local_scale_b * own_scale_a;
- var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
- local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
- lane_output[col] += (output_element_t(local_sum) * local_scale_b);
- }
+ var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
+ var own_b1: vec4<u32> = tile_B[1][base_B + sg_id];
+ var own_scale_b: output_element_t = scale_B[base_B + sg_id];
+ // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
+ lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a);
+ lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a);
+ lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a);
+ lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a);
+
+ lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a);
+ lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a);
+ lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a);
+ lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a);
+
+ lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a);
+ lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a);
+ lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a);
+ lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a);
+
+ lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a);
+ lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a);
+ lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a);
+ lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a);
}
else
{
- for (var col:u32 = 0; col < 16; col++)
- {
- var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
- var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
- var local_sum = DP4AI(own_a0, b0);
- local_sum += DP4AI(own_a1, b1);
- lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
- }
+ // Code for other subgroup sizes simply doesn't use subgroups at all.
+ // It relies on the hardware optimizing the reads all threads make from the
+ // single location tile_B[][base_B + col].
+ lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]);
+ lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]);
+ lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]);
+ lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]);
+
+ lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]);
+ lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]);
+ lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]);
+ lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]);
+
+ lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]);
+ lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]);
+ lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]);
+ lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]);
+
+ lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]);
+ lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]);
+ lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
+ lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
}
workgroupBarrier();
}
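For reference, a self-contained WGSL sketch of the dot4I8Packed accumulation that SDP8AI performs; the f32 element type, binding slot, and test values are illustrative assumptions, not part of the change:

// Assumes f32 output elements; binding 0 is a placeholder.
@group(0) @binding(0) var<storage, read_write> result : f32;

// Same pattern as SDP8AI: each vec4<u32> packs 16 int8 values and
// dot4I8Packed multiply-accumulates 4 pairs per component.
fn sdp8ai_sketch(a1 : vec4<u32>, b1 : vec4<u32>, a2 : vec4<u32>, b2 : vec4<u32>, scale : f32) -> f32 {
  var local_sum = dot4I8Packed(a1[0], b1[0]);
  local_sum += dot4I8Packed(a1[1], b1[1]);
  local_sum += dot4I8Packed(a1[2], b1[2]);
  local_sum += dot4I8Packed(a1[3], b1[3]);
  local_sum += dot4I8Packed(a2[0], b2[0]);
  local_sum += dot4I8Packed(a2[1], b2[1]);
  local_sum += dot4I8Packed(a2[2], b2[2]);
  local_sum += dot4I8Packed(a2[3], b2[3]);
  return f32(local_sum) * scale;
}

@compute @workgroup_size(1)
fn main() {
  // a1 packs the int8 values 1..16, b1 packs all ones; the second pair contributes zero.
  let a1 = vec4<u32>(pack4xI8(vec4<i32>(1, 2, 3, 4)),
                     pack4xI8(vec4<i32>(5, 6, 7, 8)),
                     pack4xI8(vec4<i32>(9, 10, 11, 12)),
                     pack4xI8(vec4<i32>(13, 14, 15, 16)));
  let b1 = vec4<u32>(pack4xI8(vec4<i32>(1, 1, 1, 1)));
  let zero = vec4<u32>(0u);
  // Expected: 0.5 * (1 + 2 + ... + 16) = 68.0
  result = sdp8ai_sketch(a1, b1, zero, zero, 0.5);
}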
@@ -735,11 +780,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// This creates a shader requirement that uniforms.N % 16 == 0
if (a_global < uniforms.M && b_global < uniforms.N)
{
- for (var i:u32 = 0; i < 4; i++)
- {
- let lidx = i * 4;
- output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1], lane_output[lidx+2], lane_output[lidx+3]);
- }
+ output[output_idx] = lane_output1;
+ output[output_idx+1] = lane_output2;
+ output[output_idx+2] = lane_output3;
+ output[output_idx+3] = lane_output4;
}
)MAIN_FN";

@@ -812,9 +856,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
mul_program.SetDispatchGroupSize(
(M + kTileSize - 1) / kTileSize,
(N + kTileSize - 1) / kTileSize, 1);
- mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
+ mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
- {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+ {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
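The component counts above determine how the generated WGSL addresses each buffer. A rough sketch of the resulting views (the binding slots and declarations are produced by ShaderHelper; the ones below are placeholders for illustration):

// Sketch only: with kVec4Components, each input_a element is a vec4<u32>,
// i.e. 16 packed int8 A values per index (hence the uniforms.K16 indexing).
// With kVec2Components * kU32Components, each input_b element is a vec2<u32>,
// i.e. 16 4-bit weights per index.
@group(0) @binding(0) var<storage, read> input_a : array<vec4<u32>>;
@group(0) @binding(2) var<storage, read> input_b : array<vec2<u32>>;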