Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions llvm/lib/Target/AMDGPU/VOP3PInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -1814,21 +1814,42 @@ def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f
// Each profile below passes VOP3PWMMA_Profile a list of four value types
// followed by nine positional flags.
// NOTE(review): the last three flags appear to be HasMatrixScale, Scale16 and
// HasMatrixReuse (WMMA_F8F6F4_Profiles forwards its parameters of those names
// in the same positions), i.e. 0, 0, 1 here -- confirm against the
// VOP3PWMMA_Profile declaration, which is not visible in this hunk.
def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 0, 0, 0, 1>;
def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 0, 0, 0, 1>;

// Defines the nine `_<A>_<B>_w32` profiles, one per combination of the
// A-operand and B-operand formats f8 / f6 / f4. The destination and the
// accumulator are always v8f32; an operand is v16i32 for f8, v12i32 for f6 and
// v8i32 for f4. HasMatrixScale, Scale16 and HasMatrixReuse are forwarded
// unchanged to every profile.
multiclass WMMA_F8F6F4_Profiles<bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
  foreach AFmt = ["f8", "f6", "f4"] in {
    foreach BFmt = ["f8", "f6", "f4"] in {
      def _#AFmt#_#BFmt#_w32 : VOP3PWMMA_Profile<
          [v8f32,
           !cond(!eq(AFmt, "f8") : v16i32,
                 !eq(AFmt, "f6") : v12i32,
                 !eq(AFmt, "f4") : v8i32),
           !cond(!eq(BFmt, "f8") : v16i32,
                 !eq(BFmt, "f6") : v12i32,
                 !eq(BFmt, "f4") : v8i32),
           v8f32],
          0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
    }
  }
}

// Instantiate the F8F6F4 profile set three times: with no matrix scale, with
// matrix scale, and with matrix scale plus Scale16. The two scaled sets also
// enable HasMatrixReuse.
defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles</*HasMatrixScale=*/0, /*Scale16=*/0, /*HasMatrixReuse=*/0>;
defm F32_16X16X128_F8F6F4_SCALE : WMMA_F8F6F4_Profiles</*HasMatrixScale=*/1, /*Scale16=*/0, /*HasMatrixReuse=*/1>;
defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles</*HasMatrixScale=*/1, /*Scale16=*/1, /*HasMatrixReuse=*/1>;
// Selects the destination vector type of a WMMA_F8F6F4 instruction from the
// destination element type and the M x N result dimensions; the result is
// exposed as `ret`.
class getWMMAF8F6F4DstVTy<ValueType DstEltTy, int M, int N> {
  // (M * N) / 32 -- presumably the number of result elements per lane of a
  // 32-lane wave; confirm.
  defvar EltCount = !div(!mul(M, N), 32);
  // Total width in bits of that many elements of DstEltTy.
  defvar DstBits = !mul(EltCount, DstEltTy.Size);
  // NOTE(review): the choice depends on DstBits only, so each arm fixes the
  // element type of the result regardless of DstEltTy. There is no fallback
  // arm for any other width.
  ValueType ret = !cond(!eq(DstBits, 256)  : v8f32,
                        !eq(DstBits, 1024) : v64f16);
}

// Selects the integer vector type used for the A or B matrix operand of a
// WMMA_F8F6F4 instruction, given the operand format ("f8", "f6" or "f4") and
// the two dimensions of that matrix; the result is exposed as `ret`.
class getWMMAF8F6F4ABVTy<string Fmt, int D1, int D2> {
  // Bits per element for each recognised format string; there is no fallback
  // arm for other strings.
  defvar BitsPerElt = !cond(!eq(Fmt, "f4") : 4,
                            !eq(Fmt, "f6") : 6,
                            !eq(Fmt, "f8") : 8);
  // (D1 * D2) / 32 elements, each BitsPerElt wide.
  defvar EltCount = !div(!mul(D1, D2), 32);
  defvar OperandBits = !mul(EltCount, BitsPerElt);
  // Map the total width onto a vector of 32-bit integers of the same size.
  ValueType ret = !cond(!eq(OperandBits, 256)  : v8i32,
                        !eq(OperandBits, 384)  : v12i32,
                        !eq(OperandBits, 512)  : v16i32,
                        !eq(OperandBits, 1024) : v32i32);
}

// Defines one `_<A>_<B>_w32` profile for every pairing of the A and B operand
// formats f8 / f6 / f4 (nine in total).
//   DstEltTy, M, N  - select the destination/accumulator vector type via
//                     getWMMAF8F6F4DstVTy.
//   M, K and K, N   - together with the format, select the A and B operand
//                     types via getWMMAF8F6F4ABVTy.
//   HasMatrixScale, Scale16, HasMatrixReuse - forwarded unchanged to each
//                     VOP3PWMMA_Profile.
multiclass WMMA_F8F6F4_Profiles<ValueType DstEltTy, int M, int N, int K,
                                bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
  defvar AccVT = getWMMAF8F6F4DstVTy<DstEltTy, M, N>.ret;
  defvar Formats = ["f8", "f6", "f4"];
  foreach AFmt = Formats in {
    // The A operand is M x K.
    defvar AVT = getWMMAF8F6F4ABVTy<AFmt, M, K>.ret;
    foreach BFmt = Formats in {
      // The B operand is K x N.
      defvar BVT = getWMMAF8F6F4ABVTy<BFmt, K, N>.ret;
      def _#AFmt#_#BFmt#_w32 : VOP3PWMMA_Profile<
          [AccVT, AVT, BVT, AccVT],
          0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
    }
  }
}

// f32-destination 16x16x128 F8F6F4 profile sets: without matrix scale, with
// matrix scale, and with matrix scale plus Scale16. The two scaled sets also
// enable HasMatrixReuse.
defm F32_16X16X128_F8F6F4
    : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128,
                           /*HasMatrixScale=*/0, /*Scale16=*/0,
                           /*HasMatrixReuse=*/0>;
defm F32_16X16X128_F8F6F4_SCALE
    : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128,
                           /*HasMatrixScale=*/1, /*Scale16=*/0,
                           /*HasMatrixReuse=*/1>;
defm F32_16X16X128_F8F6F4_SCALE16
    : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128,
                           /*HasMatrixScale=*/1, /*Scale16=*/1,
                           /*HasMatrixReuse=*/1>;

class VOP_WMMA_LD_SCALE<ValueType vt, RegisterOperand RC> : VOP3P_Profile<VOPProfile<[untyped, vt, vt, untyped]>> {
let HasMatrixScale = 1;
Expand Down