Skip to content

Commit ca004da

Browse files
authored
Add knowledge of cooperative matrices (KhronosGroup#5720)
* Add knowledge of cooperative matrices. Some optimizations are not aware of cooperative matrices, and either do nothing or assert. This commit fixes that up. * Add int tests, and handle a couple more cases. * Add float tests, and handle a couple more cases. * Add NV coop matrix as well.
1 parent 64d37e2 commit ca004da

7 files changed

+201
-5
lines changed

source/opt/aggressive_dead_code_elim_pass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,9 @@ void AggressiveDCEPass::InitExtensions() {
10041004
"SPV_NV_bindless_texture",
10051005
"SPV_EXT_shader_atomic_float_add",
10061006
"SPV_EXT_fragment_shader_interlock",
1007-
"SPV_NV_compute_shader_derivatives"
1007+
"SPV_NV_compute_shader_derivatives",
1008+
"SPV_NV_cooperative_matrix",
1009+
"SPV_KHR_cooperative_matrix"
10081010
});
10091011
// clang-format on
10101012
}

source/opt/folding_rules.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ bool IsValidResult(T val) {
112112
}
113113
}
114114

115+
// Returns true if `type` is a cooperative matrix type (KHR or NV variant).
116+
bool IsCooperativeMatrix(const analysis::Type* type) {
117+
return type->kind() == analysis::Type::kCooperativeMatrixKHR ||
118+
type->kind() == analysis::Type::kCooperativeMatrixNV;
119+
}
120+
115121
const analysis::Constant* ConstInput(
116122
const std::vector<const analysis::Constant*>& constants) {
117123
return constants[0] ? constants[0] : constants[1];
@@ -313,6 +319,11 @@ FoldingRule ReciprocalFDiv() {
313319
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
314320
const analysis::Type* type =
315321
context->get_type_mgr()->GetType(inst->type_id());
322+
323+
if (IsCooperativeMatrix(type)) {
324+
return false;
325+
}
326+
316327
if (!inst->IsFloatingPointFoldingAllowed()) return false;
317328

318329
uint32_t width = ElementWidth(type);
@@ -394,6 +405,11 @@ FoldingRule MergeNegateMulDivArithmetic() {
394405
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
395406
const analysis::Type* type =
396407
context->get_type_mgr()->GetType(inst->type_id());
408+
409+
if (IsCooperativeMatrix(type)) {
410+
return false;
411+
}
412+
397413
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
398414
return false;
399415

@@ -455,6 +471,11 @@ FoldingRule MergeNegateAddSubArithmetic() {
455471
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
456472
const analysis::Type* type =
457473
context->get_type_mgr()->GetType(inst->type_id());
474+
475+
if (IsCooperativeMatrix(type)) {
476+
return false;
477+
}
478+
458479
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
459480
return false;
460481

@@ -686,6 +707,11 @@ FoldingRule MergeMulMulArithmetic() {
686707
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
687708
const analysis::Type* type =
688709
context->get_type_mgr()->GetType(inst->type_id());
710+
711+
if (IsCooperativeMatrix(type)) {
712+
return false;
713+
}
714+
689715
if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
690716
return false;
691717

@@ -740,6 +766,11 @@ FoldingRule MergeMulDivArithmetic() {
740766

741767
const analysis::Type* type =
742768
context->get_type_mgr()->GetType(inst->type_id());
769+
770+
if (IsCooperativeMatrix(type)) {
771+
return false;
772+
}
773+
743774
if (!inst->IsFloatingPointFoldingAllowed()) return false;
744775

745776
uint32_t width = ElementWidth(type);
@@ -813,6 +844,11 @@ FoldingRule MergeMulNegateArithmetic() {
813844
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
814845
const analysis::Type* type =
815846
context->get_type_mgr()->GetType(inst->type_id());
847+
848+
if (IsCooperativeMatrix(type)) {
849+
return false;
850+
}
851+
816852
bool uses_float = HasFloatingPoint(type);
817853
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
818854

@@ -853,6 +889,11 @@ FoldingRule MergeDivDivArithmetic() {
853889
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
854890
const analysis::Type* type =
855891
context->get_type_mgr()->GetType(inst->type_id());
892+
893+
if (IsCooperativeMatrix(type)) {
894+
return false;
895+
}
896+
856897
if (!inst->IsFloatingPointFoldingAllowed()) return false;
857898

858899
uint32_t width = ElementWidth(type);
@@ -926,6 +967,11 @@ FoldingRule MergeDivMulArithmetic() {
926967

927968
const analysis::Type* type =
928969
context->get_type_mgr()->GetType(inst->type_id());
970+
971+
if (IsCooperativeMatrix(type)) {
972+
return false;
973+
}
974+
929975
if (!inst->IsFloatingPointFoldingAllowed()) return false;
930976

931977
uint32_t width = ElementWidth(type);
@@ -1068,6 +1114,11 @@ FoldingRule MergeSubNegateArithmetic() {
10681114
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
10691115
const analysis::Type* type =
10701116
context->get_type_mgr()->GetType(inst->type_id());
1117+
1118+
if (IsCooperativeMatrix(type)) {
1119+
return false;
1120+
}
1121+
10711122
bool uses_float = HasFloatingPoint(type);
10721123
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
10731124

@@ -1116,6 +1167,11 @@ FoldingRule MergeAddAddArithmetic() {
11161167
inst->opcode() == spv::Op::OpIAdd);
11171168
const analysis::Type* type =
11181169
context->get_type_mgr()->GetType(inst->type_id());
1170+
1171+
if (IsCooperativeMatrix(type)) {
1172+
return false;
1173+
}
1174+
11191175
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
11201176
bool uses_float = HasFloatingPoint(type);
11211177
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@@ -1164,6 +1220,11 @@ FoldingRule MergeAddSubArithmetic() {
11641220
inst->opcode() == spv::Op::OpIAdd);
11651221
const analysis::Type* type =
11661222
context->get_type_mgr()->GetType(inst->type_id());
1223+
1224+
if (IsCooperativeMatrix(type)) {
1225+
return false;
1226+
}
1227+
11671228
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
11681229
bool uses_float = HasFloatingPoint(type);
11691230
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@@ -1224,6 +1285,11 @@ FoldingRule MergeSubAddArithmetic() {
12241285
inst->opcode() == spv::Op::OpISub);
12251286
const analysis::Type* type =
12261287
context->get_type_mgr()->GetType(inst->type_id());
1288+
1289+
if (IsCooperativeMatrix(type)) {
1290+
return false;
1291+
}
1292+
12271293
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
12281294
bool uses_float = HasFloatingPoint(type);
12291295
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@@ -1290,6 +1356,11 @@ FoldingRule MergeSubSubArithmetic() {
12901356
inst->opcode() == spv::Op::OpISub);
12911357
const analysis::Type* type =
12921358
context->get_type_mgr()->GetType(inst->type_id());
1359+
1360+
if (IsCooperativeMatrix(type)) {
1361+
return false;
1362+
}
1363+
12931364
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
12941365
bool uses_float = HasFloatingPoint(type);
12951366
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
@@ -1383,6 +1454,11 @@ FoldingRule MergeGenericAddSubArithmetic() {
13831454
inst->opcode() == spv::Op::OpIAdd);
13841455
const analysis::Type* type =
13851456
context->get_type_mgr()->GetType(inst->type_id());
1457+
1458+
if (IsCooperativeMatrix(type)) {
1459+
return false;
1460+
}
1461+
13861462
bool uses_float = HasFloatingPoint(type);
13871463
if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
13881464

source/opt/local_access_chain_convert_pass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,8 @@ void LocalAccessChainConvertPass::InitExtensions() {
428428
"SPV_KHR_uniform_group_instructions",
429429
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
430430
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
431-
"SPV_EXT_fragment_shader_interlock",
432-
"SPV_NV_compute_shader_derivatives"});
431+
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
432+
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"});
433433
}
434434

435435
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(

source/opt/local_single_block_elim_pass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
291291
"SPV_NV_bindless_texture",
292292
"SPV_EXT_shader_atomic_float_add",
293293
"SPV_EXT_fragment_shader_interlock",
294-
"SPV_NV_compute_shader_derivatives"});
294+
"SPV_NV_compute_shader_derivatives",
295+
"SPV_NV_cooperative_matrix",
296+
"SPV_KHR_cooperative_matrix"});
295297
}
296298

297299
} // namespace opt

source/opt/local_single_store_elim_pass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
141141
"SPV_NV_bindless_texture",
142142
"SPV_EXT_shader_atomic_float_add",
143143
"SPV_EXT_fragment_shader_interlock",
144-
"SPV_NV_compute_shader_derivatives"});
144+
"SPV_NV_compute_shader_derivatives",
145+
"SPV_NV_cooperative_matrix",
146+
"SPV_KHR_cooperative_matrix"});
145147
}
146148
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
147149
std::vector<Instruction*> users;

source/opt/mem_pass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
4343
case spv::Op::OpTypeSampler:
4444
case spv::Op::OpTypeSampledImage:
4545
case spv::Op::OpTypePointer:
46+
case spv::Op::OpTypeCooperativeMatrixNV:
47+
case spv::Op::OpTypeCooperativeMatrixKHR:
4648
return true;
4749
default:
4850
break;

test/opt/fold_test.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ OpCapability Float64
215215
OpCapability Int8
216216
OpCapability Int16
217217
OpCapability Int64
218+
OpCapability CooperativeMatrixKHR
219+
OpExtension "SPV_KHR_cooperative_matrix"
218220
%1 = OpExtInstImport "GLSL.std.450"
219221
OpMemoryModel Logical GLSL450
220222
OpEntryPoint Fragment %main "main"
@@ -434,6 +436,12 @@ OpName %main "main"
434436
%ushort_0xBC00 = OpConstant %ushort 0xBC00
435437
%short_0xBC00 = OpConstant %short 0xBC00
436438
%int_arr_2_undef = OpUndef %int_arr_2
439+
%int_coop_matrix = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_3 %uint_32 %uint_0
440+
%undef_int_coop_matrix = OpUndef %int_coop_matrix
441+
%uint_coop_matrix = OpTypeCooperativeMatrixKHR %uint %uint_3 %uint_3 %uint_32 %uint_0
442+
%undef_uint_coop_matrix = OpUndef %uint_coop_matrix
443+
%float_coop_matrix = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_3 %uint_32 %uint_0
444+
%undef_float_coop_matrix = OpUndef %float_coop_matrix
437445
)";
438446

439447
return header;
@@ -4148,6 +4156,62 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
41484156
"%2 = OpSLessThan %bool %long_0 %long_2\n" +
41494157
"OpReturn\n" +
41504158
"OpFunctionEnd",
4159+
2, 0),
4160+
// Test case 41: Don't fold OpSNegate for cooperative matrices.
4161+
InstructionFoldingCase<uint32_t>(
4162+
Header() + "%main = OpFunction %void None %void_func\n" +
4163+
"%main_lab = OpLabel\n" +
4164+
"%2 = OpSNegate %int_coop_matrix %undef_int_coop_matrix\n" +
4165+
"OpReturn\n" +
4166+
"OpFunctionEnd",
4167+
2, 0),
4168+
// Test case 42: Don't fold OpIAdd for cooperative matrices.
4169+
InstructionFoldingCase<uint32_t>(
4170+
Header() + "%main = OpFunction %void None %void_func\n" +
4171+
"%main_lab = OpLabel\n" +
4172+
"%2 = OpIAdd %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
4173+
"OpReturn\n" +
4174+
"OpFunctionEnd",
4175+
2, 0),
4176+
// Test case 43: Don't fold OpISub for cooperative matrices.
4177+
InstructionFoldingCase<uint32_t>(
4178+
Header() + "%main = OpFunction %void None %void_func\n" +
4179+
"%main_lab = OpLabel\n" +
4180+
"%2 = OpISub %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
4181+
"OpReturn\n" +
4182+
"OpFunctionEnd",
4183+
2, 0),
4184+
// Test case 44: Don't fold OpIMul for cooperative matrices.
4185+
InstructionFoldingCase<uint32_t>(
4186+
Header() + "%main = OpFunction %void None %void_func\n" +
4187+
"%main_lab = OpLabel\n" +
4188+
"%2 = OpIMul %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
4189+
"OpReturn\n" +
4190+
"OpFunctionEnd",
4191+
2, 0),
4192+
// Test case 45: Don't fold OpSDiv for cooperative matrices.
4193+
InstructionFoldingCase<uint32_t>(
4194+
Header() + "%main = OpFunction %void None %void_func\n" +
4195+
"%main_lab = OpLabel\n" +
4196+
"%2 = OpSDiv %int_coop_matrix %undef_int_coop_matrix %undef_int_coop_matrix\n" +
4197+
"OpReturn\n" +
4198+
"OpFunctionEnd",
4199+
2, 0),
4200+
// Test case 46: Don't fold OpUDiv for cooperative matrices.
4201+
InstructionFoldingCase<uint32_t>(
4202+
Header() + "%main = OpFunction %void None %void_func\n" +
4203+
"%main_lab = OpLabel\n" +
4204+
"%2 = OpUDiv %uint_coop_matrix %undef_uint_coop_matrix %undef_uint_coop_matrix\n" +
4205+
"OpReturn\n" +
4206+
"OpFunctionEnd",
4207+
2, 0),
4208+
// Test case 47: Don't fold OpMatrixTimesScalar for cooperative matrices.
4209+
InstructionFoldingCase<uint32_t>(
4210+
Header() + "%main = OpFunction %void None %void_func\n" +
4211+
"%main_lab = OpLabel\n" +
4212+
"%2 = OpMatrixTimesScalar %uint_coop_matrix %undef_uint_coop_matrix %uint_3\n" +
4213+
"OpReturn\n" +
4214+
"OpFunctionEnd",
41514215
2, 0)
41524216
));
41534217

@@ -4689,6 +4753,54 @@ INSTANTIATE_TEST_SUITE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTes
46894753
"%2 = OpFDiv %half %half_1 %half_2\n" +
46904754
"OpReturn\n" +
46914755
"OpFunctionEnd",
4756+
2, 0),
4757+
// Test case 24: Don't fold OpFNegate for cooperative matrices.
4758+
InstructionFoldingCase<uint32_t>(
4759+
Header() + "%main = OpFunction %void None %void_func\n" +
4760+
"%main_lab = OpLabel\n" +
4761+
"%2 = OpFNegate %float_coop_matrix %undef_float_coop_matrix\n" +
4762+
"OpReturn\n" +
4763+
"OpFunctionEnd",
4764+
2, 0),
4765+
// Test case 25: Don't fold OpFAdd for cooperative matrices.
4766+
InstructionFoldingCase<uint32_t>(
4767+
Header() + "%main = OpFunction %void None %void_func\n" +
4768+
"%main_lab = OpLabel\n" +
4769+
"%2 = OpFAdd %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
4770+
"OpReturn\n" +
4771+
"OpFunctionEnd",
4772+
2, 0),
4773+
// Test case 26: Don't fold OpFSub for cooperative matrices.
4774+
InstructionFoldingCase<uint32_t>(
4775+
Header() + "%main = OpFunction %void None %void_func\n" +
4776+
"%main_lab = OpLabel\n" +
4777+
"%2 = OpFSub %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
4778+
"OpReturn\n" +
4779+
"OpFunctionEnd",
4780+
2, 0),
4781+
// Test case 27: Don't fold OpFMul for cooperative matrices.
4782+
InstructionFoldingCase<uint32_t>(
4783+
Header() + "%main = OpFunction %void None %void_func\n" +
4784+
"%main_lab = OpLabel\n" +
4785+
"%2 = OpFMul %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
4786+
"OpReturn\n" +
4787+
"OpFunctionEnd",
4788+
2, 0),
4789+
// Test case 28: Don't fold OpFDiv for cooperative matrices.
4790+
InstructionFoldingCase<uint32_t>(
4791+
Header() + "%main = OpFunction %void None %void_func\n" +
4792+
"%main_lab = OpLabel\n" +
4793+
"%2 = OpFDiv %float_coop_matrix %undef_float_coop_matrix %undef_float_coop_matrix\n" +
4794+
"OpReturn\n" +
4795+
"OpFunctionEnd",
4796+
2, 0),
4797+
// Test case 29: Don't fold OpMatrixTimesScalar for cooperative matrices.
4798+
InstructionFoldingCase<uint32_t>(
4799+
Header() + "%main = OpFunction %void None %void_func\n" +
4800+
"%main_lab = OpLabel\n" +
4801+
"%2 = OpMatrixTimesScalar %float_coop_matrix %undef_float_coop_matrix %float_3\n" +
4802+
"OpReturn\n" +
4803+
"OpFunctionEnd",
46924804
2, 0)
46934805
));
46944806

0 commit comments

Comments
 (0)