-
Notifications
You must be signed in to change notification settings - Fork 14.8k
Enable matrices in HLSL #111415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Enable matrices in HLSL #111415
Conversation
HLSL needs matrix support. This allows matrices when the target language is HLSL and defines a matrix template alias that allows the short forms of matrix types to be defined in typedefs in a default header. Makes some changes to how matrices are printed in HLSL policy. This required a tweak to how existing matrices are printed in diagnostics that is more consistent with other types. These matrix types will function exactly as the clang matrix extension dictates. Alterations to that behavior both specific to HLSL and also potentially expanding on the matrix extension will follow. fixes llvm#109839
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang Author: Greg Roth (pow2clk) ChangesHLSL needs matrix support. This allows matrices when the target language is HLSL and defines a matrix template alias that allows the short forms of matrix types to be defined in typedefs in a default header. Makes some changes to how matrices are printed in HLSL policy. This required a tweak to how existing matrices are printed in diagnostics that is more consistent with other types. These matrix types will function exactly as the clang matrix extension dictates. Alterations to that behavior both specific to HLSL and also potentially expanding on the matrix extension will follow. fixes #109839 Patch is 222.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111415.diff 25 Files Affected:
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index 3c7495e66055dc..6f4b72045a9464 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -28,6 +28,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions;
void defineHLSLVectorAlias();
+ void defineHLSLMatrixAlias();
void defineTrivialHLSLTypes();
void defineHLSLTypesWithForwardDeclarations();
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index a81429ad6a2380..ed10c210ed170f 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -1381,7 +1381,7 @@ void ASTContext::InitBuiltinTypes(const TargetInfo &Target,
if (LangOpts.OpenACC && !LangOpts.OpenMP) {
InitBuiltinType(ArraySectionTy, BuiltinType::ArraySection);
}
- if (LangOpts.MatrixTypes)
+ if (LangOpts.MatrixTypes || LangOpts.HLSL)
InitBuiltinType(IncompleteMatrixIdxTy, BuiltinType::IncompleteMatrixIdx);
// Builtin types for 'id', 'Class', and 'SEL'.
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ca75bb97c158e1..142717201557f3 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -852,34 +852,50 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
raw_ostream &OS) {
+ if (Policy.UseHLSLTypes)
+ OS << "matrix<";
printBefore(T->getElementType(), OS);
- OS << " __attribute__((matrix_type(";
- OS << T->getNumRows() << ", " << T->getNumColumns();
- OS << ")))";
}
void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
raw_ostream &OS) {
printAfter(T->getElementType(), OS);
+ if (Policy.UseHLSLTypes) {
+ OS << ", ";
+ OS << T->getNumRows() << ", " << T->getNumColumns();
+ OS << ">";
+ } else {
+ OS << " __attribute__((matrix_type(";
+ OS << T->getNumRows() << ", " << T->getNumColumns();
+ OS << ")))";
+ }
}
void TypePrinter::printDependentSizedMatrixBefore(
const DependentSizedMatrixType *T, raw_ostream &OS) {
+ if (Policy.UseHLSLTypes)
+ OS << "matrix<";
printBefore(T->getElementType(), OS);
- OS << " __attribute__((matrix_type(";
- if (T->getRowExpr()) {
- T->getRowExpr()->printPretty(OS, nullptr, Policy);
- }
- OS << ", ";
- if (T->getColumnExpr()) {
- T->getColumnExpr()->printPretty(OS, nullptr, Policy);
- }
- OS << ")))";
}
void TypePrinter::printDependentSizedMatrixAfter(
const DependentSizedMatrixType *T, raw_ostream &OS) {
printAfter(T->getElementType(), OS);
+ if (Policy.UseHLSLTypes)
+ OS << ", ";
+ else
+ OS << " __attribute__((matrix_type(";
+
+ if (Expr *E = T->getRowExpr())
+ E->printPretty(OS, nullptr, Policy);
+ OS << ", ";
+ if (Expr *E = T->getColumnExpr())
+ E->printPretty(OS, nullptr, Policy);
+
+ if (Policy.UseHLSLTypes)
+ OS << ">";
+ else
+ OS << ")))";
}
void
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f9500..b6eeffa2f5e362 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,238 @@ typedef vector<float64_t, 2> float64_t2;
typedef vector<float64_t, 3> float64_t3;
typedef vector<float64_t, 4> float64_t4;
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
} // namespace hlsl
#endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 2913d16fca4823..d1a53d2ad88864 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -472,8 +472,81 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
HLSLNamespace->addDecl(Template);
}
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+ ASTContext &AST = SemaPtr->getASTContext();
+
+ llvm::SmallVector<NamedDecl *> TemplateParams;
+
+ auto *TypeParam = TemplateTypeParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+ &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+ TypeParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(
+ TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+ TemplateParams.emplace_back(TypeParam);
+
+ // these should be 64 bit to be consistent with other clang matrices.
+ auto *RowsParam = NonTypeTemplateParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+ &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+ false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+ llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+ TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+ /*IsDefaulted=*/true);
+ RowsParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+ SourceLocation(), RowsParam));
+ TemplateParams.emplace_back(RowsParam);
+
+ auto *ColsParam = NonTypeTemplateParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+ &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+ false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+ llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+ TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+ /*IsDefaulted=*/true);
+ ColsParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+ SourceLocation(), ColsParam));
+ TemplateParams.emplace_back(RowsParam);
+
+ auto *ParamList =
+ TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+ TemplateParams, SourceLocation(), nullptr);
+
+ IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+ QualType AliasType = AST.getDependentSizedMatrixType(
+ AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+ DeclRefExpr::Create(
+ AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+ DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+ AST.IntTy, VK_LValue),
+ DeclRefExpr::Create(
+ AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+ DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+ AST.IntTy, VK_LValue),
+ SourceLocation());
+
+ auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+ SourceLocation(), &II,
+ AST.getTrivialTypeSourceInfo(AliasType));
+ Record->setImplicit(true);
+
+ auto *Template =
+ TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+ Record->getIdentifier(), ParamList, Record);
+
+ Record->setDescribedAliasTemplate(Template);
+ Template->setImplicit(true);
+ Template->setLexicalDeclContext(Record->getDeclContext());
+ HLSLNamespace->addDecl(Template);
+}
+
void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
defineHLSLVectorAlias();
+ defineHLSLMatrixAlias();
}
/// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index c44fc9c4194ca4..9213b4d95a70d9 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2447,7 +2447,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
SourceLocation AttrLoc) {
- assert(Context.getLangOpts().MatrixTypes &&
+ assert((getLangOpts().MatrixTypes || getLangOpts().HLSL) &&
"Should never build a matrix type when it is disabled");
// Check element type, if it is not dependent.
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 00000000000000..afac2cfed7604b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+ // Verify that the alias is generated inside the hlsl namespace.
+ hlsl::matrix<float, 2, 2> Mat2x2;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:35>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2 'hlsl::matrix<float, 2, 2>'
+
+ // Verify that you don't need to specify the namespace.
+ matrix<int, 2, 2> Vec2x2a;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Vec2x2a 'matrix<int, 2, 2>'
+
+ // Build a bigger matrix.
+ matrix<double, 4, 4> Mat4x4;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:30>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4 'matrix<double, 4, 4>'
+
+ // Verify that the implicit arguments generate the correct type.
+ matrix<> ImpM...
[truncated]
|
if (Policy.UseHLSLTypes) | ||
OS << ">"; | ||
else | ||
OS << ")))"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to this file do alter how pointer and reference matrices are printed. I'd like to get @fhahn's opinion on this even if on nothing else.
printBefore(T->getElementType(), OS); | ||
OS << " __attribute__((matrix_type("; | ||
OS << T->getNumRows() << ", " << T->getNumColumns(); | ||
OS << ")))"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change seems wrong to me. IIUC this changes something like float __attribute__((matrix_type(4, 4)))*
to float * __attribute__((matrix_type(4, 4)))
, which would mean the element type is float*
rather than float
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my response to Florian here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may be simpler to read to duplicate some code but have the HSL and C++ matrixes types printed completely separately?
clang/lib/AST/ASTContext.cpp
Outdated
@@ -1381,7 +1381,7 @@ void ASTContext::InitBuiltinTypes(const TargetInfo &Target, | |||
if (LangOpts.OpenACC && !LangOpts.OpenMP) { | |||
InitBuiltinType(ArraySectionTy, BuiltinType::ArraySection); | |||
} | |||
if (LangOpts.MatrixTypes) | |||
if (LangOpts.MatrixTypes || LangOpts.HLSL) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is the only line this will be needed then feel free to ignore this comment.
My question is could there be places where LangOpts.MatrixTypes
could diverge from HLSL?
If so instead is it possible for LangOpts.HLSL
to turn on LangOpts.MatrixTypes
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can add MatrixTypes
to the definition of HLSL's variants in LangStandards.def, which is probably a safer approach to this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll look into that. I didn't want to enable matrices by default because I didn't really want to allow the matrix attribute syntax in HLSL.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've done this two ways now:
ae1a58b just enables matrices when HLSL is enabled
eeb3165 attempts to use LangStandards and make it dependent on HLSL versions
I confess I don't know why the second is preferable to the first. If I've misunderstood @llvm-beanz's suggestion, please let me know.
I'm actually unclear on how any of this prevents the problem @farzonl brought up. If what MatrixTypes represents changes so that it no longer matches what HLSL wants, then those changes will seep in without our notice. Given that the design discussion we had about this rejected the notion that we'd forge our own matrix type that inherited from the clang type, I don't see any way around that problem unless we pay attention to its evolution and weigh in when needed. Not that I'm saying that's a bad idea regardless.
All that the above changes alter is that now matrices can be declared using the clang extension which will sidestep at least the size restrictions we have not yet applied, but intend to.
I'm unconvinced this is an improvement, but it was easy enough to do that I thought I'd give us more concrete options to discuss.
@@ -17,12 +17,12 @@ void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) { | |||
// expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} | |||
|
|||
a = b + &c; | |||
// expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} | |||
// expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} | |||
// expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float * __attribute__((matrix_type(10, 5)))'))}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change intentional? IIUC the patch should be NFC for the existing matrix type support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @fhahn!
Thanks for taking a look at this!
It was intentional, but tentative. This is the consequence of the change I commented on here hoping you'd weigh in on. I made the change this way in TypePrinter because it was more convenient to insert the HLSL printing style that way, but I can maintain the previous behavior. The reason I thought that perhaps the alternate position of *
might be correct or preferable is because it is the way that vectors are printed with this example.
At any rate, we can defer the discussion of whether the vector printing should conform with the matrix printing or vice versa for later. For now, I'll implement the printing so that it maintains previous behavior.
edit: came up with a less contrived example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another oddity in how pointers to matrices were printed before was a double space between the type and the attribute. It's minor, but it suggested something was missing there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new printing is illegal. You've changed it from a matrix of float
to a matrix of float*
. The type printer shouldn't be printing types in a way that changes their meaning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, then perhaps we should look into how vectors are printed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created an issue for this: #114868 It's not really related to HLSL as we don't have pointers or references (yet). I have a tentative fix, but I'd like to see if anyone has thoughts on the further impact it will have on existing type printing.
There is an inconsistency in how matrix and vector pointers and references are printed at present. Also back out a change that's probably better independent than as a rider
Instead of adding an HLSL component to the checks for matrix types, this just enables them whenever an HLSL version that has that explicitly set is chosen
Future uncertain. Better to test with native HLSL types anyway
This reverts commit eeb3165. Using LangOpts->MatrixTypes for both the default value and the keypath meant that when the lang opts were regenerated, it found them the same and determined that replicating the enable-matrix flag wasn't necessary. Using the HLSL keypath differentiates them so that it will know that they need to be regenerated.
@@ -1,7 +1,5 @@ | |||
// RUN: %clang_cc1 -O0 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck --check-prefixes=CHECK,NOOPT %s | |||
// RUN: %clang_cc1 -O1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck --check-prefixes=CHECK,OPT %s | |||
typedef double dx5x5_t __attribute__((matrix_type(5, 5))); | |||
using fx2x3_t = float __attribute__((matrix_type(2, 3))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the changes to this file required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really, but I noted that they weren't being used. I think the other changes to this file are useful as well, but I'm happy to remove them to expedite things.
if (Policy.UseHLSLTypes) | ||
OS << ", "; | ||
else | ||
OS << " __attribute__((matrix_type("; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You still have this in the "after" function. That will print this after pointer annotations which changes its meaning incorrectly.
I'm guessing there isn't an existing test case for this, otherwise this change would have broken something.
cc: @fhahn
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah. I guess I missed this one. AFAICT, DependentSizedMatrices are never printed.
To make this perfectly limited to HLSL, revert matrix cpp test Revise matrix printing so that should they have pointers or references in HLSL, they will be printed properly as well
Now that DXIL uses the itanium C++ ABI, it can use some of the tests that were SPIRV only before
@@ -2447,7 +2447,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize, | |||
|
|||
QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols, | |||
SourceLocation AttrLoc) { | |||
assert(Context.getLangOpts().MatrixTypes && | |||
assert(getLangOpts().MatrixTypes && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unrelated?
@@ -459,8 +459,81 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() { | |||
HLSLNamespace->addDecl(Template); | |||
} | |||
|
|||
void HLSLExternalSemaSource::defineHLSLMatrixAlias() { | |||
ASTContext &AST = SemaPtr->getASTContext(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment here explaining what alias is added?
printBefore(T->getElementType(), OS); | ||
OS << " __attribute__((matrix_type("; | ||
OS << T->getNumRows() << ", " << T->getNumColumns(); | ||
OS << ")))"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
may be simpler to read to duplicate some code but have the HSL and C++ matrixes types printed completely separately?
HLSL needs matrix support. This allows matrices when the target language is HLSL and defines a matrix template alias that allows the short forms of matrix types to be defined in typedefs in a default header. Makes some changes to how matrices are printed in HLSL policy. This required a tweak to how existing matrices are printed in diagnostics that is more consistent with other types.
These matrix types will function exactly as the clang matrix extension dictates. Alterations to that behavior both specific to HLSL and also potentially expanding on the matrix extension will follow.
fixes #109839