@@ -618,6 +618,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
 inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
                                  const size_t K, const size_t N,
                                  const size_t TM, const size_t TN,
+                                 const size_t LID,
                                  const Shape &workgroupSize = {256, 1, 1},
                                  NumType precision = kf32) {
   std::string codeString(shaderTemplate);
@@ -626,7 +627,8 @@ inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
       {"{{K}}", toString(K)},
       {"{{N}}", toString(N)},
       {"{{TM}}", toString(TM)},
-      {"{{TN}}", toString(TN)}
+      {"{{TN}}", toString(TN)},
+      {"{{LID}}", toString(LID)}
   });
   return {loopUnrolling(codeString), workgroupSize, precision};
 }
@@ -638,18 +640,18 @@ inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
 const char * kShaderSubgroupMatrixMultiply = R"(
 enable subgroups;
 enable chromium_experimental_subgroup_matrix;
+diagnostic(off, chromium.subgroup_matrix_uniformity);
 
 @group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
 @group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
 @group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
 
 @compute @workgroup_size({{workgroupSize}})
-fn main(@builtin(workgroup_id) wg: vec3<u32>) {
+fn main(@builtin(workgroup_id) wg: vec3<u32>,
+        @builtin(local_invocation_id) localID: vec3<u32>) {
 
     let rowStart: u32 = wg.x * 8u * {{TM}};
-    let colStart: u32 = wg.y * 8u * {{TN}};
-
-    if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }
+    let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};
 
     let baseA: u32 = rowStart * {{K}};
     let baseB: u32 = colStart;
@@ -661,27 +663,41 @@ fn main(@builtin(workgroup_id) wg: vec3<u32>) {
     // 4x4 accumulators (8x8 each)
     var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
 
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
+        Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
+    }
+
+    for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+            accxx[idx_i + idx_j * {{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
+        }
+    }
+
     for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
         workgroupBarrier();
         for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
-            Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + idx_i * 8u * {{K}} + k, false, {{K}});
+            Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
         }
 
         for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
-            Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k* {{N}} + 8u * idx_i, false, {{N}});
+            Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
         }
 
-        for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
-            for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
-                accxx[idx_i + idx_j * {{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_i + idx_j * {{TM}}]);
+        for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
+            for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
+                accxx[idx_j * {{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j * {{TM}} + idx_i]);
             }
         }
     }
 
     workgroupBarrier();
     for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
         for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
-            subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_i + idx_j * {{TM}}], false, {{N}});
+            subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j * {{TM}} + idx_i], false, {{N}});
         }
     }
 }
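(Note on the new layout: each workgroup now holds LID subgroups stacked along y, and each subgroup owns its own 8*TM x 8*TN tile of C, with the column origin derived from local_invocation_id.y. The early-out bounds check was dropped, so the dispatch is expected to tile M and N exactly. Below is a small host-side sketch, plain C++ with illustrative loop bounds only, that mirrors the tile-origin arithmetic so the mapping can be checked outside the shader.)

    // Sketch only: mirrors the WGSL tile-origin math for a few hypothetical
    // dispatch coordinates; not part of the kernel or of the gpu.cpp API.
    #include <cstdio>

    int main() {
      const unsigned TM = 4, TN = 8, LID = 2; // assumed tile parameters
      for (unsigned wgx = 0; wgx < 2; ++wgx) {
        for (unsigned wgy = 0; wgy < 2; ++wgy) {
          for (unsigned ly = 0; ly < LID; ++ly) {           // local_invocation_id.y
            unsigned rowStart = wgx * 8u * TM;               // as in the shader
            unsigned colStart = (wgy * LID + ly) * 8u * TN;  // as in the shader
            std::printf("wg=(%u,%u) ly=%u -> C tile origin (%u,%u), %ux%u\n",
                        wgx, wgy, ly, rowStart, colStart, 8 * TM, 8 * TN);
          }
        }
      }
      // Adjacent ly values land on adjacent 8*TN-wide column bands, so the
      // LID subgroups of one workgroup never overlap in C.
      return 0;
    }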
@@ -858,15 +874,16 @@ Kernel selectMatmul(Context &ctx, int version,
                            /* nWorkgroups */ nWorkgroups,
                            NoParam{}, &info);
   } else if (version == 12) {
-    // f32: Subgroup matrix multiply
-    static constexpr size_t TM = 2;
-    static constexpr size_t TN = 4;
-    Shape wgSize = {64, 1, 1}; // One subgroup per workgroup
-    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN), 1};
+    // f16: Subgroup matrix multiply
+    static constexpr size_t TM = 4;
+    static constexpr size_t TN = 8;
+    static constexpr size_t LID = 2;
+    Shape wgSize = {64, LID, 1}; // LID subgroups per workgroup
+    Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
     LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
     LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
     LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
-    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, wgSize, numtype);
+    KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
     kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
                           NoParam{}, &info);
   }
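(With the new constants, each workgroup covers 8*TM = 32 rows and 8*TN*LID = 128 columns of C, using 64*LID = 128 invocations. A quick standalone check of the dispatch arithmetic, assuming a square M = K = N = 4096 problem purely for illustration; the cdiv below is a local stand-in, not the gpu.cpp helper.)

    #include <cstddef>
    #include <cstdio>

    // Ceiling division, standing in for gpu.cpp's cdiv.
    static size_t cdiv(size_t n, size_t d) { return (n + d - 1) / d; }

    int main() {
      constexpr size_t M = 4096, N = 4096;       // assumed problem size
      constexpr size_t TM = 4, TN = 8, LID = 2;  // values from the change above
      size_t wgX = cdiv(M, 8 * TM);              // 4096 / 32  -> 128
      size_t wgY = cdiv(N, 8 * TN * LID);        // 4096 / 128 -> 32
      std::printf("nWorkgroups = {%zu, %zu, 1}\n", wgX, wgY);
      std::printf("invocations per workgroup = %zu\n", 64 * LID);
      return 0;
    }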
@@ -931,6 +948,10 @@ void runTest(int version, size_t M, size_t K, size_t N,
       LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
     }
   };
+
+  static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
+
+  devDesc.requiredLimits = &requiredLimits;
 
   Context ctx = createContext({}, {}, devDesc);
 
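(The device descriptor now carries an explicit requiredLimits. WGPU_LIMITS_INIT, from the current webgpu.h headers, fills the struct with the "undefined" sentinels, so the device is created with default limits. If a particular limit ever needed raising, the natural place would be this struct before createContext; the field and value in the sketch below are purely illustrative assumptions, not something this change does.)

    // Sketch only, assuming gpu.cpp's usual gpu.hpp header and the same
    // createContext call shape used above; the raised limit is hypothetical.
    #include "gpu.hpp"
    using namespace gpu;

    int main() {
      WGPUDeviceDescriptor devDesc = {};                       // zero-init for the sketch
      static WGPULimits requiredLimits = WGPU_LIMITS_INIT;     // all limits = "use default"
      requiredLimits.maxStorageBufferBindingSize = 1ull << 30; // e.g. ask for 1 GiB bindings
      devDesc.requiredLimits = &requiredLimits;
      Context ctx = createContext({}, {}, devDesc);
      return 0;
    }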