@@ -628,7 +628,7 @@ inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
628628 {" {{TM}}" , toString (TM)},
629629 {" {{TN}}" , toString (TN)}
630630 });
631- return {codeString, workgroupSize, precision};
631+ return {loopUnrolling ( codeString) , workgroupSize, precision};
632632}
633633
634634// ─────────────────────────────────────────────────────────────────────────────
@@ -663,25 +663,25 @@ fn main(@builtin(workgroup_id) wg: vec3<u32>) {
663663
664664 for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
665665 workgroupBarrier();
666- for (var i : u32 = 0; i < {{TM}}; i ++) {
667- Ax[i ] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u*{{K}} + k, false, {{K}});
666+ for (var idx_i : u32 = 0; idx_i < {{TM}}; idx_i ++) {
667+ Ax[idx_i ] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + idx_i * 8u*{{K}} + k, false, {{K}});
668668 }
669669
670- for (var i : u32 = 0; i < {{TN}}; i ++) {
671- Bx[i ] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * i , false, {{N}});
670+ for (var idx_i : u32 = 0; idx_i < {{TN}}; idx_i ++) {
671+ Bx[idx_i ] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * idx_i , false, {{N}});
672672 }
673673
674- for (var i : u32 = 0; i < {{TM}}; i ++) {
675- for (var j : u32 = 0; j < {{TN}}; j ++) {
676- accxx[i+j *{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i ], Bx[j ], accxx[i+j *{{TM}}]);
674+ for (var idx_i : u32 = 0; idx_i < {{TM}}; idx_i ++) {
675+ for (var idx_j : u32 = 0; idx_j < {{TN}}; idx_j ++) {
676+ accxx[idx_i+idx_j *{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[idx_i ], Bx[idx_j ], accxx[idx_i+idx_j *{{TM}}]);
677677 }
678678 }
679679 }
680680
681681 workgroupBarrier();
682- for (var i : u32 = 0; i < {{TM}}; i ++) {
683- for (var j : u32 = 0; j < {{TN}}; j ++) {
684- subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j , accxx[i+j *{{TM}}], false, {{N}});
682+ for (var idx_i : u32 = 0; idx_i < {{TM}}; idx_i ++) {
683+ for (var idx_j : u32 = 0; idx_j < {{TN}}; idx_j ++) {
684+ subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j , accxx[idx_i+idx_j *{{TM}}], false, {{N}});
685685 }
686686 }
687687}
0 commit comments