Skip to content

Commit e87d791

Browse files
Apply loop-unrolling
1 parent 3165df5 commit e87d791

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

examples/matmul/run.cpp

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -628,7 +628,7 @@ inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
628628
{"{{TM}}", toString(TM)},
629629
{"{{TN}}", toString(TN)}
630630
});
631-
return {codeString, workgroupSize, precision};
631+
return {loopUnrolling(codeString), workgroupSize, precision};
632632
}
633633

634634
// ─────────────────────────────────────────────────────────────────────────────
@@ -663,25 +663,25 @@ fn main(@builtin(workgroup_id) wg: vec3<u32>) {
663663
664664
for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
665665
workgroupBarrier();
666-
for (var i: u32 = 0; i < {{TM}}; i++) {
667-
Ax[i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u*{{K}} + k, false, {{K}});
666+
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
667+
Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + idx_i * 8u*{{K}} + k, false, {{K}});
668668
}
669669
670-
for (var i: u32 = 0; i < {{TN}}; i++) {
671-
Bx[i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * i, false, {{N}});
670+
for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
671+
Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * idx_i, false, {{N}});
672672
}
673673
674-
for (var i: u32 = 0; i < {{TM}}; i++) {
675-
for (var j: u32 = 0; j < {{TN}}; j++) {
676-
accxx[i+j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i], Bx[j], accxx[i+j*{{TM}}]);
674+
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
675+
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
676+
accxx[idx_i+idx_j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_i+idx_j*{{TM}}]);
677677
}
678678
}
679679
}
680680
681681
workgroupBarrier();
682-
for (var i: u32 = 0; i < {{TM}}; i++) {
683-
for (var j: u32 = 0; j < {{TN}}; j++) {
684-
subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j, accxx[i+j*{{TM}}], false, {{N}});
682+
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
683+
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
684+
subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_i+idx_j*{{TM}}], false, {{N}});
685685
}
686686
}
687687
}

0 commit comments

Comments (0)