Skip to content

Commit ffbb983

Browse files
Disable chromium.subgroup_matrix_uniformity
1 parent e87d791 commit ffbb983

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

examples/matmul/run.cpp

Lines changed: 38 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -618,6 +618,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
618618
inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
619619
const size_t K, const size_t N,
620620
const size_t TM, const size_t TN,
621+
const size_t LID,
621622
const Shape &workgroupSize = {256, 1, 1},
622623
NumType precision = kf32) {
623624
std::string codeString(shaderTemplate);
@@ -626,7 +627,8 @@ inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
626627
{"{{K}}", toString(K)},
627628
{"{{N}}", toString(N)},
628629
{"{{TM}}", toString(TM)},
629-
{"{{TN}}", toString(TN)}
630+
{"{{TN}}", toString(TN)},
631+
{"{{LID}}", toString(LID)}
630632
});
631633
return {loopUnrolling(codeString), workgroupSize, precision};
632634
}
@@ -638,18 +640,18 @@ inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
638640
const char* kShaderSubgroupMatrixMultiply = R"(
639641
enable subgroups;
640642
enable chromium_experimental_subgroup_matrix;
643+
diagnostic (off, chromium.subgroup_matrix_uniformity);
641644
642645
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
643646
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
644647
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
645648
646649
@compute @workgroup_size({{workgroupSize}})
647-
fn main(@builtin(workgroup_id) wg: vec3<u32>) {
650+
fn main(@builtin(workgroup_id) wg: vec3<u32>,
651+
@builtin(local_invocation_id) localID : vec3<u32>) {
648652
649653
let rowStart: u32 = wg.x * 8u * {{TM}};
650-
let colStart: u32 = wg.y * 8u * {{TN}};
651-
652-
if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }
654+
let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};
653655
654656
let baseA: u32 = rowStart * {{K}};
655657
let baseB: u32 = colStart;
@@ -661,27 +663,41 @@ fn main(@builtin(workgroup_id) wg: vec3<u32>) {
661663
// 4x4 accumulators (8x8 each)
662664
var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
663665
666+
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
667+
Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
668+
}
669+
670+
for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
671+
Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
672+
}
673+
674+
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
675+
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
676+
accxx[idx_i+idx_j*{{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
677+
}
678+
}
679+
664680
for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
665681
workgroupBarrier();
666682
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
667-
Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + idx_i * 8u*{{K}} + k, false, {{K}});
683+
Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
668684
}
669685
670686
for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
671-
Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * idx_i, false, {{N}});
687+
Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
672688
}
673689
674-
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
675-
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
676-
accxx[idx_i+idx_j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_i+idx_j*{{TM}}]);
690+
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
691+
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
692+
accxx[idx_j*{{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j*{{TM}} + idx_i]);
677693
}
678694
}
679695
}
680696
681697
workgroupBarrier();
682698
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
683699
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
684-
subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_i+idx_j*{{TM}}], false, {{N}});
700+
subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j*{{TM}} + idx_i], false, {{N}});
685701
}
686702
}
687703
}
@@ -858,15 +874,16 @@ Kernel selectMatmul(Context &ctx, int version,
858874
/*nWorkgroups*/ nWorkgroups,
859875
NoParam{}, &info);
860876
} else if (version == 12) {
861-
// f32: Subgroup matrix multiply
862-
static constexpr size_t TM = 2;
863-
static constexpr size_t TN = 4;
864-
Shape wgSize = {64, 1, 1}; // One subgroup per workgroup
865-
Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN), 1};
877+
// f16: Subgroup matrix multiply
878+
static constexpr size_t TM = 4;
879+
static constexpr size_t TN = 8;
880+
static constexpr size_t LID = 2;
881+
Shape wgSize = {64, LID, 1}; // One subgroup per workgroup
882+
Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
866883
LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
867884
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
868885
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
869-
KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, wgSize, numtype);
886+
KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
870887
kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
871888
NoParam{}, &info);
872889
}
@@ -931,6 +948,10 @@ void runTest(int version, size_t M, size_t K, size_t N,
931948
LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
932949
}
933950
};
951+
952+
static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
953+
954+
devDesc.requiredLimits = &requiredLimits;
934955

935956
Context ctx = createContext({}, {}, devDesc);
936957

0 commit comments

Comments
 (0)