Skip to content

Commit f0c9acc

Browse files
authored
[SPIRV][HLSL] Add FixedVector GEP legalization (#171682)
fixes #170241 PR #169090 updated vector swizzle elements individually for HLSL This is because rawbuffer writes via rawBufferStore can cause datarace in concurrent writing the way vectors are written all at once. This means we needed individual writes per element. So that means we need to be able to GEP into an element of a vector. The SPIRV backend did not support this pattern. SPIRV assumes Composite types (ie Vectors, Structs, and Arrays) only for Ptr legalization in it's store transformations via transformStore. Fixing things at the point of ptr legalziation for the store would be too late because we would have lost critical ptr type information still available in LLVM IR. Instead what we needed to do is teach the walkLogicalAccessChain used by buildLogicalAccessChainFromGEP to converts a byte-offset GEP on an i8-pointer into a logical SPIR-V composite access chain. In short: The fundamental issue is that walkLogicalAccessChain currently refuses to handle vector-based logical GEPs, but Clang’s new HLSL path produces exactly that pattern: - you have a logical i8* GEP that represents indexing into the elements of a <4 x i32>. We need to teach walkLogicalAccessChain to treat a FixedVectorType similarly to an array of its element type. So in walkLogicalAccessChain replace the “give up on vector” part with logic that: - Interprets the byte Offset as indexing into the vector elements. - Uses the vector’s element type and number of elements.
1 parent 84d1de2 commit f0c9acc

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,9 +665,20 @@ bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
665665
Offset -= STL->getElementOffset(Element);
666666
CurType = ST->getElementType(Element);
667667
OnLiteralIndexing(CurType, Element);
668+
} else if (auto *VT = dyn_cast<FixedVectorType>(CurType)) {
669+
Type *EltTy = VT->getElementType();
670+
TypeSize EltSizeBits = DL.getTypeSizeInBits(EltTy);
671+
assert(EltSizeBits % 8 == 0 &&
672+
"Element type size in bits must be a multiple of 8.");
673+
uint32_t EltTypeSize = EltSizeBits / 8;
674+
assert(Offset < VT->getNumElements() * EltTypeSize);
675+
uint64_t Index = Offset / EltTypeSize;
676+
Offset -= Index * EltTypeSize;
677+
CurType = EltTy;
678+
OnLiteralIndexing(CurType, Index);
679+
668680
} else {
669-
// Vector type indexing should not use GEP.
670-
// So if we have an index left, something is wrong. Giving up.
681+
// Unknown composite kind; give up.
671682
return true;
672683
}
673684
} while (Offset > 0);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
4+
5+
; StructuredBuffer<int4> In : register(t0);
6+
; RWStructuredBuffer<int4> Out : register(u1);
7+
;
8+
; [numthreads(1,1,1)]
9+
; void main() {
10+
; Out[0].y = In[0].y;
11+
; }
12+
13+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
14+
@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
15+
16+
define void @main() local_unnamed_addr #0 {
17+
; CHECK-LABEL: main
18+
; CHECK: %33 = OpFunction %2 None %3 ; -- Begin function main
19+
; CHECK-NEXT: %1 = OpLabel
20+
; CHECK-NEXT: %34 = OpVariable %20 Function %29
21+
; CHECK-NEXT: %35 = OpVariable %19 Function %30
22+
; CHECK-NEXT: %36 = OpCopyObject %12 %31
23+
; CHECK-NEXT: %37 = OpCopyObject %10 %32
24+
; CHECK-NEXT: %38 = OpAccessChain %7 %36 %21 %21
25+
; CHECK-NEXT: %39 = OpLoad %6 %38 Aligned 16
26+
; CHECK-NEXT: %40 = OpCompositeExtract %4 %39 1
27+
; CHECK-NEXT: %41 = OpAccessChain %7 %37 %21 %21
28+
; CHECK-NEXT: %42 = OpInBoundsAccessChain %5 %41 %22
29+
; CHECK-NEXT: OpStore %42 %40 Aligned 4
30+
; CHECK-NEXT: OpReturn
31+
; CHECK-NEXT: OpFunctionEnd
32+
entry:
33+
%0 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
34+
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 1, i32 1, i32 0, ptr nonnull @.str.2)
35+
%2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4i32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 0) %0, i32 0)
36+
%3 = load <4 x i32>, ptr addrspace(11) %2, align 16
37+
%4 = extractelement <4 x i32> %3, i64 1
38+
%5 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4i32_12_1t(target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) %1, i32 0)
39+
%6 = getelementptr inbounds nuw i8, ptr addrspace(11) %5, i64 4
40+
store i32 %4, ptr addrspace(11) %6, align 4
41+
ret void
42+
}
43+
44+
!0 = !{i32 1, !"wchar_size", i32 4}

0 commit comments

Comments
 (0)