[HLSL][SPIRV] Allow large z value in numthreads #144934
The current validation checks for numthreads assume that the target is DXIL, so the shader-model version checks inadvertently issue an error when targeting SPIR-V.
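A minimal standalone sketch of the limit selection this PR changes, not the actual Sema code. The names `NumThreadsLimits`, `selectLimits`, and `fitsLimits` are hypothetical, and the final comparison (z against `ZMax`, x*y*z against `ThreadMax`) is an assumption about how the limits are consumed, since that part is not shown in the diff. The branch structure and constants come from the patched `handleNumThreadsAttr`. For a triple such as `spirv-pc-vulkan1.3-compute`, the OS version's major number is presumably the Vulkan major (1), which would explain why the unguarded `<= 4` branch applied the shader model 4 limits to SPIR-V.

```cpp
#include <cstdint>

// Hypothetical stand-in for the ZMax/ThreadMax locals in handleNumThreadsAttr.
struct NumThreadsLimits {
  uint32_t ZMax;
  uint32_t ThreadMax;
};

// Mirrors the patched branches: the shader-model limits apply only to DXIL.
// Any other target (e.g. SPIR-V) keeps the 1024/1024 defaults.
NumThreadsLimits selectLimits(bool IsDXIL, unsigned Major) {
  NumThreadsLimits L{1024, 1024};
  if (IsDXIL && Major <= 4) {
    L.ZMax = 1;
    L.ThreadMax = 768;
  } else if (IsDXIL && Major == 5) {
    L.ZMax = 64;
    L.ThreadMax = 1024;
  }
  return L;
}

// Assumed shape of the follow-up check (not part of the diff): z must not
// exceed ZMax and the total thread count must not exceed ThreadMax.
bool fitsLimits(const NumThreadsLimits &L, uint32_t X, uint32_t Y, uint32_t Z) {
  return Z <= L.ZMax && static_cast<uint64_t>(X) * Y * Z <= L.ThreadMax;
}
```

Under these assumptions, `[numthreads(4,2,128)]` from the new test passes for a non-DXIL target with major version 1, while the same version number treated as a DXIL shader model would cap z at 1 and reject it.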
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang
Author: Steven Perron (s-perron)
Full diff: https://github.com/llvm/llvm-project/pull/144934.diff
2 Files Affected:
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b55f4fd786b58..9f39c077cea7a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1033,12 +1033,15 @@ void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
llvm::VersionTuple SMVersion =
getASTContext().getTargetInfo().getTriple().getOSVersion();
+ bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
+ llvm::Triple::dxil;
+
uint32_t ZMax = 1024;
uint32_t ThreadMax = 1024;
- if (SMVersion.getMajor() <= 4) {
+ if (IsDXIL && SMVersion.getMajor() <= 4) {
ZMax = 1;
ThreadMax = 768;
- } else if (SMVersion.getMajor() == 5) {
+ } else if (IsDXIL && SMVersion.getMajor() == 5) {
ZMax = 64;
ThreadMax = 1024;
}
diff --git a/clang/test/SemaHLSL/num_threads.hlsl b/clang/test/SemaHLSL/num_threads.hlsl
index b5f9ad6c33cd6..96200312bbf69 100644
--- a/clang/test/SemaHLSL/num_threads.hlsl
+++ b/clang/test/SemaHLSL/num_threads.hlsl
@@ -10,6 +10,8 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel5.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
// RUN: %clang_cc1 -triple dxil-pc-shadermodel4.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
+// RUN: %clang_cc1 -triple spirv-pc-vulkan1.3-compute -x hlsl -ast-dump -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
#if __SHADER_TARGET_STAGE == __SHADER_STAGE_COMPUTE || __SHADER_TARGET_STAGE == __SHADER_STAGE_MESH || __SHADER_TARGET_STAGE == __SHADER_STAGE_AMPLIFICATION || __SHADER_TARGET_STAGE == __SHADER_STAGE_LIBRARY
#ifdef FAIL
@@ -88,24 +90,30 @@ int entry() {
// Because these two attributes match, they should both appear in the AST
[numthreads(2,2,1)]
-// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:90:2, col:18> 2 2 1
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> 2 2 1
int secondFn();
[numthreads(2,2,1)]
-// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:94:2, col:18> 2 2 1
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> 2 2 1
int secondFn() {
return 1;
}
[numthreads(4,2,1)]
-// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:100:2, col:18> 4 2 1
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> 4 2 1
int onlyOnForwardDecl();
-// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:100:2, col:18> Inherited 4 2 1
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> Inherited 4 2 1
int onlyOnForwardDecl() {
return 1;
}
+#ifdef __spirv__
+[numthreads(4,2,128)]
+// CHECK-SPIRV: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:20> 4 2 128
+int largeZ();
+#endif
+
#else // Vertex and Pixel only beyond here
// expected-error-re@+1 {{attribute 'numthreads' is unsupported in '{{[A-Za-z]+}}' shaders, requires one of the following: compute, amplification, mesh}}
[numthreads(1,1,1)]
@@ -1033,12 +1033,15 @@ void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
llvm::VersionTuple SMVersion =
getASTContext().getTargetInfo().getTriple().getOSVersion();
+ bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
+ llvm::Triple::dxil;
+
uint32_t ZMax = 1024;
Shouldn't we disable the Xmax/Ymax/Zmax checks entirely when targeting SPIR-V/Vulkan? Those limits seem to be set by the device, so they are not tied to the Vulkan version.
I don't think so. These are close to the limits on actual machines. For Android, the existing checks are way beyond the limit.
See maxComputeWorkGroupInvocations and maxComputeWorkGroupSize in https://vulkan.lunarg.com/doc/sdk/1.4.313.2/windows/profiles_definitions.html.