
Conversation

@shiltian (Contributor)

No description provided.

@shiltian (Contributor, Author)

This stack of pull requests is managed by Graphite. Learn more about stacking.

@llvmbot (Member) commented Dec 15, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Shilei Tian (shiltian)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/172245.diff

1 file affected:

  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+36-15)
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 2dfa905848a34..410e56d83331b 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -1814,21 +1814,42 @@ def F32_FP8BF8X128_SWMMAC_w32    : VOP3PWMMA_Profile<[v8f32, v8i32,  v16i32, v8f
 def F16_FP8BF8X128_SWMMAC_w32    : VOP3PWMMA_Profile<[v8f16, v8i32,  v16i32, v8f16], 1, 32, 0, 1, 1, 0, 0, 0, 1>;
 def I32_IU8X128_SWMMAC_w32       : VOP3PWMMA_Profile<[v8i32, v8i32,  v16i32, v8i32], 1, 32, 1, 0, 1, 0, 0, 0, 1>;
 
-multiclass WMMA_F8F6F4_Profiles<bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
-  def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32,  v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32,  v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32,  v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32,  v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32,  v8i32,  v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-}
-
-defm F32_16X16X128_F8F6F4         : WMMA_F8F6F4_Profiles<0, 0, 0>;
-defm F32_16X16X128_F8F6F4_SCALE   : WMMA_F8F6F4_Profiles<1, 0, 1>;
-defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<1, 1, 1>;
+// Helper class to compute the destination vector type of WMMA_F8F6F4 instructions based on element type and dimensions.
+class getWMMAF8F6F4DstVTy<ValueType DstEltTy, int M, int N> {
+  // Size in bits = (M * N / 32) * element_size_in_bits
+  defvar Size = !mul(!div(!mul(M, N), 32), DstEltTy.Size);
+  ValueType ret = !cond(!eq(Size, 256)  : v8f32,
+                        !eq(Size, 1024) : v64f16);
+}
+
+// Helper class to compute the type of matrix A and B of WMMA_F8F6F4 instructions based on format and dimensions.
+class getWMMAF8F6F4ABVTy<string Fmt, int D1, int D2> {
+  defvar FmtBits = !cond(!eq(Fmt, "f8") : 8,
+                         !eq(Fmt, "f6") : 6,
+                         !eq(Fmt, "f4") : 4);
+  // TypeSize in bits = (D1 * D2 / 32) * format_bits
+  defvar TypeSize = !mul(!div(!mul(D1, D2), 32), FmtBits);
+  ValueType ret = !cond(!eq(TypeSize, 256)  : v8i32,
+                        !eq(TypeSize, 384)  : v12i32,
+                        !eq(TypeSize, 512)  : v16i32,
+                        !eq(TypeSize, 1024) : v32i32);
+}
+
+multiclass WMMA_F8F6F4_Profiles<ValueType DstEltTy, int M, int N, int K,
+                                bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
+  defvar DstTy = getWMMAF8F6F4DstVTy<DstEltTy, M, N>.ret;
+  foreach ATy = ["f8", "f6", "f4"] in {
+    foreach BTy = ["f8", "f6", "f4"] in {
+      def _#ATy#_#BTy#_w32 : VOP3PWMMA_Profile<
+        [DstTy, getWMMAF8F6F4ABVTy<ATy, M, K>.ret, getWMMAF8F6F4ABVTy<BTy, K, N>.ret, DstTy],
+        0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
+    }
+  }
+}
+
+defm F32_16X16X128_F8F6F4         : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/0, /*Scale16=*/0, /*HasMatrixReuse=*/0>;
+defm F32_16X16X128_F8F6F4_SCALE   : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/1, /*Scale16=*/0, /*HasMatrixReuse=*/1>;
+defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/1, /*Scale16=*/1, /*HasMatrixReuse=*/1>;
 
 class VOP_WMMA_LD_SCALE<ValueType vt, RegisterOperand RC> : VOP3P_Profile<VOPProfile<[untyped, vt, vt, untyped]>> {
   let HasMatrixScale = 1;
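
As a quick sanity check on the arithmetic in the new helper classes, here is a minimal standalone TableGen sketch, not taken from the PR diff; the SizeOf class and the def names below are hypothetical, and the shape is the M=16, N=16, K=128 one used above. It can be fed to llvm-tblgen to inspect the computed bit sizes.

// Same formula as in getWMMAF8F6F4ABVTy: bits = (D1 * D2 / 32) * FmtBits.
class SizeOf<int D1, int D2, int FmtBits> {
  int ret = !mul(!div(!mul(D1, D2), 32), FmtBits);
}

// A-matrix with M=16, K=128, per format:
def A_f8 { int bits = SizeOf<16, 128, 8>.ret; } // (16*128/32)*8 = 512 -> v16i32
def A_f6 { int bits = SizeOf<16, 128, 6>.ret; } // (16*128/32)*6 = 384 -> v12i32
def A_f4 { int bits = SizeOf<16, 128, 4>.ret; } // (16*128/32)*4 = 256 -> v8i32
// Destination with f32 elements (32 bits), M=16, N=16:
def D_f32 { int bits = SizeOf<16, 16, 32>.ret; } // (16*16/32)*32 = 256 -> v8f32

With the element types derived this way, the nine previously hand-written _fA_fB_w32 profiles fall out of the two nested foreach loops over ["f8", "f6", "f4"] in the multiclass.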

@shiltian requested a review from @changpeng on December 15, 2025 at 03:57.
@shiltian merged commit df14096 into main on December 15, 2025; 12 checks passed.
@shiltian deleted the users/shiltian/refactor-wmma-f8f6f4-multiclass branch on December 15, 2025 at 14:16.