Skip to content

Conversation

pow2clk
Copy link

@pow2clk pow2clk commented Oct 7, 2024

HLSL needs matrix support. This allows matrices when the target language is HLSL and defines a matrix template alias that allows the short forms of matrix types to be defined in typedefs in a default header. Makes some changes to how matrices are printed in HLSL policy. This required a tweak to how existing matrices are printed in diagnostics that is more consistent with other types.

These matrix types will function exactly as the clang matrix extension dictates. Alterations to that behavior both specific to HLSL and also potentially expanding on the matrix extension will follow.

fixes #109839

HLSL needs matrix support. This allows matrices when the target
language is HLSL and defines a matrix template alias that allows
the short forms of matrix types to be defined in typedefs in a
default header. Makes some changes to how matrices are printed
in HLSL policy. This required a tweak to how existing matrices are
printed in diagnostics that is more consistent with other types.

These matrix types will function exactly as the clang matrix extension
dictates. Alterations to that behavior both specific to HLSL and
also potentially expanding on the matrix extension will follow.

fixes llvm#109839
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics HLSL HLSL Language Support labels Oct 7, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 7, 2024

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-clang

Author: Greg Roth (pow2clk)

Changes

HLSL needs matrix support. This allows matrices when the target language is HLSL and defines a matrix template alias that allows the short forms of matrix types to be defined in typedefs in a default header. Makes some changes to how matrices are printed in HLSL policy. This required a tweak to how existing matrices are printed in diagnostics that is more consistent with other types.

These matrix types will function exactly as the clang matrix extension dictates. Alterations to that behavior both specific to HLSL and also potentially expanding on the matrix extension will follow.

fixes #109839


Patch is 222.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111415.diff

25 Files Affected:

  • (modified) clang/include/clang/Sema/HLSLExternalSemaSource.h (+1)
  • (modified) clang/lib/AST/ASTContext.cpp (+1-1)
  • (modified) clang/lib/AST/TypePrinter.cpp (+28-12)
  • (modified) clang/lib/Headers/hlsl/hlsl_basic_types.h (+232)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+73)
  • (modified) clang/lib/Sema/SemaType.cpp (+1-1)
  • (added) clang/test/AST/HLSL/matrix-alias.hlsl (+49)
  • (modified) clang/test/AST/HLSL/vector-alias.hlsl (+1-1)
  • (modified) clang/test/CodeGenCXX/matrix-type-operators.cpp (+3-5)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-cast-template.hlsl (+351)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-cast.hlsl (+162)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-transpose-template.hlsl (+82)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-transpose.hlsl (+95)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-type-operators-template.hlsl (+447)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-type-operators.hlsl (+1515)
  • (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-type.hlsl (+217)
  • (modified) clang/test/CodeGenHLSL/basic_types.hlsl (+18)
  • (added) clang/test/CodeGenHLSL/matrix-types.hlsl (+348)
  • (modified) clang/test/Sema/matrix-type-operators.c (+8-8)
  • (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-cast.hlsl (+138)
  • (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-index-operator-type.hlsl (+27)
  • (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-transpose.hlsl (+56)
  • (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-type-operators.hlsl (+307)
  • (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-type.hlsl (+48)
  • (modified) clang/test/SemaTemplate/matrix-type.cpp (+1-1)
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index 3c7495e66055dc..6f4b72045a9464 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -28,6 +28,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
   llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions;
 
   void defineHLSLVectorAlias();
+  void defineHLSLMatrixAlias();
   void defineTrivialHLSLTypes();
   void defineHLSLTypesWithForwardDeclarations();
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index a81429ad6a2380..ed10c210ed170f 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -1381,7 +1381,7 @@ void ASTContext::InitBuiltinTypes(const TargetInfo &Target,
   if (LangOpts.OpenACC && !LangOpts.OpenMP) {
     InitBuiltinType(ArraySectionTy, BuiltinType::ArraySection);
   }
-  if (LangOpts.MatrixTypes)
+  if (LangOpts.MatrixTypes || LangOpts.HLSL)
     InitBuiltinType(IncompleteMatrixIdxTy, BuiltinType::IncompleteMatrixIdx);
 
   // Builtin types for 'id', 'Class', and 'SEL'.
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ca75bb97c158e1..142717201557f3 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -852,34 +852,50 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
 
 void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
                                             raw_ostream &OS) {
+  if (Policy.UseHLSLTypes)
+    OS << "matrix<";
   printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
-  OS << T->getNumRows() << ", " << T->getNumColumns();
-  OS << ")))";
 }
 
 void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
                                            raw_ostream &OS) {
   printAfter(T->getElementType(), OS);
+  if (Policy.UseHLSLTypes) {
+    OS << ", ";
+    OS << T->getNumRows() << ", " << T->getNumColumns();
+    OS << ">";
+  } else {
+    OS << " __attribute__((matrix_type(";
+    OS << T->getNumRows() << ", " << T->getNumColumns();
+    OS << ")))";
+  }
 }
 
 void TypePrinter::printDependentSizedMatrixBefore(
     const DependentSizedMatrixType *T, raw_ostream &OS) {
+  if (Policy.UseHLSLTypes)
+    OS << "matrix<";
   printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
-  if (T->getRowExpr()) {
-    T->getRowExpr()->printPretty(OS, nullptr, Policy);
-  }
-  OS << ", ";
-  if (T->getColumnExpr()) {
-    T->getColumnExpr()->printPretty(OS, nullptr, Policy);
-  }
-  OS << ")))";
 }
 
 void TypePrinter::printDependentSizedMatrixAfter(
     const DependentSizedMatrixType *T, raw_ostream &OS) {
   printAfter(T->getElementType(), OS);
+  if (Policy.UseHLSLTypes)
+    OS << ", ";
+  else
+    OS << " __attribute__((matrix_type(";
+
+  if (Expr *E = T->getRowExpr())
+    E->printPretty(OS, nullptr, Policy);
+  OS << ", ";
+  if (Expr *E = T->getColumnExpr())
+    E->printPretty(OS, nullptr, Policy);
+
+  if (Policy.UseHLSLTypes)
+    OS << ">";
+  else
+    OS << ")))";
 }
 
 void
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f9500..b6eeffa2f5e362 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,238 @@ typedef vector<float64_t, 2> float64_t2;
 typedef vector<float64_t, 3> float64_t3;
 typedef vector<float64_t, 4> float64_t4;
 
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
 } // namespace hlsl
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 2913d16fca4823..d1a53d2ad88864 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -472,8 +472,81 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
   HLSLNamespace->addDecl(Template);
 }
 
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+
+  llvm::SmallVector<NamedDecl *> TemplateParams;
+
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(
+               TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+  TemplateParams.emplace_back(TypeParam);
+
+  // these should be 64 bit to be consistent with other clang matrices.
+  auto *RowsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+                           /*IsDefaulted=*/true);
+  RowsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+                                                  SourceLocation(), RowsParam));
+  TemplateParams.emplace_back(RowsParam);
+
+  auto *ColsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+      &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+                           /*IsDefaulted=*/true);
+  ColsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+                                                  SourceLocation(), ColsParam));
+  TemplateParams.emplace_back(RowsParam);
+
+  auto *ParamList =
+      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+                                    TemplateParams, SourceLocation(), nullptr);
+
+  IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+  QualType AliasType = AST.getDependentSizedMatrixType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+          DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+          DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, Record);
+
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}
+
 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
   defineHLSLVectorAlias();
+  defineHLSLMatrixAlias();
 }
 
 /// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index c44fc9c4194ca4..9213b4d95a70d9 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2447,7 +2447,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
 
 QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
                                SourceLocation AttrLoc) {
-  assert(Context.getLangOpts().MatrixTypes &&
+  assert((getLangOpts().MatrixTypes || getLangOpts().HLSL) &&
          "Should never build a matrix type when it is disabled");
 
   // Check element type, if it is not dependent.
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 00000000000000..afac2cfed7604b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::matrix<float, 2, 2> Mat2x2;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:35>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2 'hlsl::matrix<float, 2, 2>'
+
+  // Verify that you don't need to specify the namespace.
+  matrix<int, 2, 2> Vec2x2a;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Vec2x2a 'matrix<int, 2, 2>'
+
+  // Build a bigger matrix.
+  matrix<double, 4, 4> Mat4x4;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:30>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4 'matrix<double, 4, 4>'
+
+  // Verify that the implicit arguments generate the correct type.
+  matrix<> ImpM...
[truncated]

if (Policy.UseHLSLTypes)
OS << ">";
else
OS << ")))";
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to this file do alter how pointer and reference matrices are printed. I'd like to get @fhahn's opinion on this even if on nothing else.

printBefore(T->getElementType(), OS);
OS << " __attribute__((matrix_type(";
OS << T->getNumRows() << ", " << T->getNumColumns();
OS << ")))";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems wrong to me. IIUC this changes something like float __attribute__((matrix_type(4, 4)))* to float * __attribute__((matrix_type(4, 4))), which would mean the element type is float* rather than float.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my response to Florian here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may be simpler to read to duplicate some code but have the HSL and C++ matrixes types printed completely separately?

@@ -1381,7 +1381,7 @@ void ASTContext::InitBuiltinTypes(const TargetInfo &Target,
if (LangOpts.OpenACC && !LangOpts.OpenMP) {
InitBuiltinType(ArraySectionTy, BuiltinType::ArraySection);
}
if (LangOpts.MatrixTypes)
if (LangOpts.MatrixTypes || LangOpts.HLSL)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the only line this will be needed then feel free to ignore this comment.

My question is could there be places where LangOpts.MatrixTypes could diverge from HLSL?

If so instead is it possible for LangOpts.HLSL to turn on LangOpts.MatrixTypes?

Copy link
Collaborator

@llvm-beanz llvm-beanz Oct 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add MatrixTypes to the definition of HLSL's variants in LangStandards.def, which is probably a safer approach to this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into that. I didn't want to enable matrices by default because I didn't really want to allow the matrix attribute syntax in HLSL.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done this two ways now:
ae1a58b just enables matrices when HLSL is enabled
eeb3165 attempts to use LangStandards and make it dependent on HLSL versions

I confess I don't know why the second is preferable to the first. If I've misunderstood @llvm-beanz's suggestion, please let me know.

I'm actually unclear on how any of this prevents the problem @farzonl brought up. If what MatrixTypes represents changes so that it no longer matches what HLSL wants, then those changes will seep in without our notice. Given that the design discussion we had about this rejected the notion that we'd forge our own matrix type that inherited from the clang type, I don't see any way around that problem unless we pay attention to its evolution and weigh in when needed. Not that I'm saying that's a bad idea regardless.

All that the above changes alter is that now matrices can be declared using the clang extension which will sidestep at least the size restrictions we have not yet applied, but intend to.

I'm unconvinced this is an improvement, but it was easy enough to do that I thought I'd give us more concrete options to discuss.

@@ -17,12 +17,12 @@ void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
// expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}}

a = b + &c;
// expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}}
// expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
// expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float * __attribute__((matrix_type(10, 5)))'))}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change intentional? IIUC the patch should be NFC for the existing matrix type support?

Copy link
Author

@pow2clk pow2clk Oct 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @fhahn!

Thanks for taking a look at this!

It was intentional, but tentative. This is the consequence of the change I commented on here hoping you'd weigh in on. I made the change this way in TypePrinter because it was more convenient to insert the HLSL printing style that way, but I can maintain the previous behavior. The reason I thought that perhaps the alternate position of * might be correct or preferable is because it is the way that vectors are printed with this example.

At any rate, we can defer the discussion of whether the vector printing should conform with the matrix printing or vice versa for later. For now, I'll implement the printing so that it maintains previous behavior.

edit: came up with a less contrived example.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another oddity in how pointers to matrices were printed before was a double space between the type and the attribute. It's minor, but it suggested something was missing there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new printing is illegal. You've changed it from a matrix of float to a matrix of float*. The type printer shouldn't be printing types in a way that changes their meaning.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, then perhaps we should look into how vectors are printed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created an issue for this: #114868 It's not really related to HLSL as we don't have pointers or references (yet). I have a tentative fix, but I'd like to see if anyone has thoughts on the further impact it will have on existing type printing.

Greg Roth added 9 commits October 10, 2024 15:50
There is an inconsistency in how matrix and vector pointers and references are printed at present.

Also back out a change that's probably better independent than as a rider
Instead of adding an HLSL component to the checks for matrix types,
this just enables them whenever an HLSL version that has that explicitly set is chosen
Future uncertain. Better to test with native HLSL types anyway
This reverts commit eeb3165.

Using LangOpts->MatrixTypes for both the default value and the keypath meant that when the lang opts were regenerated, it found them the same and determined that replicating the enable-matrix flag wasn't necessary. Using the HLSL keypath differentiates them so that it will know that they need to be regenerated.
@@ -1,7 +1,5 @@
// RUN: %clang_cc1 -O0 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck --check-prefixes=CHECK,NOOPT %s
// RUN: %clang_cc1 -O1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck --check-prefixes=CHECK,OPT %s
typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
using fx2x3_t = float __attribute__((matrix_type(2, 3)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes to this file required?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, but I noted that they weren't being used. I think the other changes to this file are useful as well, but I'm happy to remove them to expedite things.

if (Policy.UseHLSLTypes)
OS << ", ";
else
OS << " __attribute__((matrix_type(";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You still have this in the "after" function. That will print this after pointer annotations which changes its meaning incorrectly.

I'm guessing there isn't an existing test case for this, otherwise this change would have broken something.

cc: @fhahn

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I guess I missed this one. AFAICT, DependentSizedMatrices are never printed.

Greg Roth added 5 commits November 5, 2024 13:20
To make this perfectly limited to HLSL, revert matrix cpp test

Revise matrix printing so that should they have pointers or references in HLSL, they will be printed properly as well
Now that DXIL uses the itanium C++ ABI, it can use some of the tests that were SPIRV only before
@damyanp damyanp requested a review from llvm-beanz November 25, 2024 23:30
@damyanp damyanp requested a review from farzonl November 25, 2024 23:30
@@ -2447,7 +2447,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,

QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
SourceLocation AttrLoc) {
assert(Context.getLangOpts().MatrixTypes &&
assert(getLangOpts().MatrixTypes &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated?

@@ -459,8 +459,81 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
HLSLNamespace->addDecl(Template);
}

void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
ASTContext &AST = SemaPtr->getASTContext();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment here explaining what alias is added?

printBefore(T->getElementType(), OS);
OS << " __attribute__((matrix_type(";
OS << T->getNumRows() << ", " << T->getNumColumns();
OS << ")))";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may be simpler to read to duplicate some code but have the HSL and C++ matrixes types printed completely separately?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[HLSL] Enable clang extension matrices in HLSL
5 participants