[X86][SelectionDAG] Fix the Gather's base and index by modifying the Scale value #134979
Conversation
@llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-llvm-selectiondag

Author: Rohit Aggarwal (rohitaggarwal007)

Changes: Fix the Gather's base and index for one use or multiple uses of the Index node. The approach is to fold the index's SHL into the Scale value, followed by a truncate of the index (a simplified sketch of the fold appears after the diff below).

Full diff: https://github.com/llvm/llvm-project/pull/134979.diff

3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 38376de5783ae..7c51ee8222512 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12131,8 +12131,8 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
if (IndexIsScaled)
return false;
- if (!isNullConstant(BasePtr) && !Index.hasOneUse())
- return false;
+ // if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+ // return false;
EVT VT = BasePtr.getValueType();
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47ac1ee571269..61e6d0734f402 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56512,6 +56512,120 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
Scatter->isTruncatingStore());
}
+// Targets override this function to decide whether they want to update the
+// base and index value of a non-uniform GEP.
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, SDValue &Scale,
+ const SDLoc &DL, const SDValue &Gep,
+ SelectionDAG &DAG) {
+ SDValue Nbase;
+ SDValue Nindex;
+ SDValue NScale;
+ bool Changed = false;
+  // This function checks the opcode of Index and updates the index accordingly.
+ auto checkAndUpdateIndex = [&](SDValue &Idx) {
+ if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV
+ SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+ SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+ std::optional<uint64_t> ShAmt = DAG.getValidMinimumShiftAmount(Idx);
+
+ unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+ if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+ Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+ IndexWidth > 32 &&
+ Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+ DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) && ShAmt) {
+
+ KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+ bool ExtIsNonNegative = ExtKnown.isNonNegative();
+ KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+ bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+ if (!ExtIsNonNegative || !ExtOpIsNonNegative)
+ return false;
+
+ SDValue NewOp10 =
+ Op10.getOperand(0); // Get the Operand zero from the ext
+ EVT VT = NewOp10.getValueType(); // Use the operand's type to determine
+ // the type of index
+
+ // auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+ // if (!ConstEltNo)
+ // return false;
+ uint64_t ScaleAmt = cast<ConstantSDNode>(Scale)->getZExtValue();
+ uint64_t NewScaleAmt = ScaleAmt * (1ULL << *ShAmt);
+ LLVM_DEBUG(dbgs() << NewScaleAmt << " NewScaleAmt"
+ << "\n");
+ if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) {
+ // Nindex = NewOp10.getOperand(0);
+ Nindex = Op10;
+ NScale = DAG.getTargetConstant(NewScaleAmt, DL, Scale.getValueType());
+ return true;
+ }
+ // SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+ // DAG.getConstant(ConstEltNo->getZExtValue(),
+ // DL, VT.getScalarType()));
+ // Nindex = DAG.getNode(ISD::SHL, DL, VT, NewOp10,
+ // DAG.getBuildVector(VT, DL, Ops));
+ }
+ }
+ return false;
+ };
+
+  // For the GEP instruction, we are trying to properly assign the base and
+  // index value. We walk through the lowered code and iterate backward.
+ if (isNullConstant(Base) && Gep.getOpcode() == ISD::ADD) {
+ SDValue Op0 = Gep.getOperand(0); // base or add
+ SDValue Op1 = Gep.getOperand(1); // build vector or SHL
+ Nbase = Op0;
+ SDValue Idx = Op1;
+ auto Flags = Gep->getFlags();
+
+ if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+ SDValue Op00 = Op0.getOperand(0); // Base
+ Nbase = Op00;
+ Idx = Op0.getOperand(1);
+ } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+ Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+ return false;
+ }
+ if (!checkAndUpdateIndex(Idx)) {
+ return false;
+ }
+ Base = Nbase.getOperand(0);
+
+ if (Op0 != Nbase) {
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+ if (!ConstEltNo)
+ return false;
+
+ // SmallVector<SDValue, 8> Ops(
+ // Nindex.getValueType().getVectorNumElements(),
+ // DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+ // Nindex.getValueType().getScalarType()));
+ Base = DAG.getNode(ISD::ADD, DL, Nbase.getOperand(0).getValueType(),
+ Nbase.getOperand(0), Op1.getOperand(0), Flags);
+ }
+ Index = Nindex;
+ Scale = NScale;
+ Changed = true;
+ } else if (Base.getOpcode() == ISD::CopyFromReg ||
+ (Base.getOpcode() == ISD::ADD &&
+ Base.getOperand(0).getOpcode() == ISD::CopyFromReg &&
+ isConstOrConstSplat(Base.getOperand(1)))) {
+ if (checkAndUpdateIndex(Index)) {
+ Index = Nindex;
+ Changed = true;
+ }
+ }
+ if (Changed) {
+ LLVM_DEBUG(dbgs() << "Successful in updating the non uniform gep "
+ "information\n";
+ dbgs() << "updated base "; Base.dump();
+ dbgs() << "updated Index "; Index.dump(););
+ return true;
+ }
+ return false;
+}
+
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
@@ -56523,6 +56637,29 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (DCI.isBeforeLegalize()) {
+ // if (updateBaseAndIndex(Base, Index, Scale, DL, Index, DAG))
+ // return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ //
+
+  // Attempt to move a shifted index into the address scale; this allows
+  // further index truncation below.
+ if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
+ uint64_t ScaleAmt = Scale->getAsZExtVal();
+ if (auto MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
+ if (*MinShAmt >= 1 && ScaleAmt < 8 &&
+ DAG.ComputeNumSignBits(Index.getOperand(0)) > 1) {
+ SDValue ShAmt = Index.getOperand(1);
+ SDValue NewShAmt =
+ DAG.getNode(ISD::SUB, DL, ShAmt.getValueType(), ShAmt,
+ DAG.getConstant(1, DL, ShAmt.getValueType()));
+ SDValue NewIndex = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+ Index.getOperand(0), NewShAmt);
+ SDValue NewScale =
+ DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
+ return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+ }
+ }
+ }
unsigned IndexWidth = Index.getScalarValueSizeInBits();
// Shrink indices if they are larger than 32-bits.
@@ -56552,6 +56689,12 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
}
+
+ // Shrink if we remove an illegal type.
+ if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
+ Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
+ return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ }
}
}
diff --git a/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
new file mode 100644
index 0000000000000..faa83b0a20290
--- /dev/null
+++ b/llvm/test/CodeGen/X86/gatherBaseIndexFix.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw,+avx512vl,+avx512dq -mcpu=znver5 < %s | FileCheck %s
+
+%struct.pt = type { float, float, float, i32 }
+%struct.res = type {<16 x float>, <16 x float>}
+
+define <16 x float> @test_gather_16f32_1(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_1:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ ret <16 x float> %res
+ }
+
+define <16 x float> @test_gather_16f32_2(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm2
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm2, %zmm0
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm0
+; CHECK-NEXT: vgatherdps 4(%rdi,%zmm0,8), %zmm1 {%k1}
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ ret <16 x float> %res
+ }
+
+define {<16 x float>, <16 x float>} @test_gather_16f32_3(ptr %x, ptr %arr, <16 x i1> %mask, <16 x float> %src0) {
+; CHECK-LABEL: test_gather_16f32_3:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vpsllw $7, %xmm0, %xmm0
+; CHECK-NEXT: vpmovb2m %xmm0, %k1
+; CHECK-NEXT: vmovdqu64 (%rsi), %zmm0
+; CHECK-NEXT: vpandd {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT: kmovq %k1, %k2
+; CHECK-NEXT: vpaddd %zmm0, %zmm0, %zmm2
+; CHECK-NEXT: vmovaps %zmm1, %zmm0
+; CHECK-NEXT: vgatherdps (%rdi,%zmm2,8), %zmm0 {%k2}
+; CHECK-NEXT: vgatherdps 4(%rdi,%zmm2,8), %zmm1 {%k1}
+; CHECK-NEXT: retq
+ %wide.load = load <16 x i32>, ptr %arr, align 4
+ %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
+ %zext = zext <16 x i32> %and to <16 x i64>
+ %ptrs1 = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext
+ %res1 = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs1, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ %ptrs = getelementptr inbounds %struct.pt, ptr %x, <16 x i64> %zext, i32 1
+ %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %ptrs, i32 4, <16 x i1> %mask, <16 x float> %src0)
+ %pair1 = insertvalue {<16 x float>, <16 x float>} undef, <16 x float> %res1, 0
+ %pair2 = insertvalue {<16 x float>, <16 x float>} %pair1, <16 x float> %res, 1
+ ret {<16 x float>, <16 x float>} %pair2
+ }
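A gather computes addr[i] = Base + extend(Index[i]) * Scale, so a uniform constant left-shift of the index can be absorbed into Scale instead. The sketch below is a simplified illustration of that fold, not the exact patch: the splat check and variable names are assumptions, and the real patch additionally uses ComputeNumSignBits/computeKnownBits to prove that no index bits are lost.

```cpp
// Illustrative sketch only (not the committed patch):
//   gather Base, (shl X, splat C), Scale  -->  gather Base, X, Scale << C
// whenever (Scale << C) is still a legal x86 scale (1, 2, 4 or 8). Once the
// shift is gone, X is often a plain zero-extension of a 32-bit value, so the
// existing "shrink indices wider than 32 bits" combine can truncate it.
APInt SplatAmt;
if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale) &&
    ISD::isConstantSplatVector(Index.getOperand(1).getNode(), SplatAmt) &&
    SplatAmt.ult(4)) {
  uint64_t NewScaleAmt = Scale->getAsZExtVal() << SplatAmt.getZExtValue();
  if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8)
    return rebuildGatherScatter(
        GorS, Index.getOperand(0), Base,
        DAG.getTargetConstant(NewScaleAmt, DL, Scale.getValueType()), DAG);
}
```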
@RKSimon I have created this PR with the Scale/truncate changes and an updated test case.
cheers @rohitaggarwal007 - I've pushed your test coverage to trunk inside masked_gather_scatter.ll - so you should be able to merge (and remove gatherBaseIndexFix.ll, it's now superfluous) - my test patch now shows the current diff for your tests
Thanks @RKSimon. Sorry, I misunderstood your point.
…dices patterns and add the splat into the (scalar) base address (#135201): we already did this for constant cases; this patch generalizes the existing fold to attempt to extract the splat from either operand of an ADD node for the gather/scatter index value. This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future. Noticed while reviewing #134979.
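A rough sketch of the generalized splat-extraction fold described in that commit message, under my own assumptions rather than the committed code (in particular, the zero-extension of the splat to the pointer type below must match the gather's actual index extension semantics):

```cpp
// Sketch (illustrative): hoist a splatted vector offset out of the index into
// the scalar base pointer, pre-multiplied by Scale:
//   gather Base, (add X, splat S), Scale --> gather (Base + S*Scale), X, Scale
if (Index.getOpcode() == ISD::ADD && isa<ConstantSDNode>(Scale)) {
  if (SDValue Splat = DAG.getSplatValue(Index.getOperand(1))) {
    EVT PtrVT = Base.getValueType();
    SDValue Offset = DAG.getNode(
        ISD::MUL, DL, PtrVT, DAG.getZExtOrTrunc(Splat, DL, PtrVT),
        DAG.getConstant(Scale->getAsZExtVal(), DL, PtrVT));
    SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Offset);
    return rebuildGatherScatter(GorS, Index.getOperand(0), NewBase, Scale, DAG);
  }
}
```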
✅ With the latest revision this PR passed the undef deprecator.
@RKSimon, I have updated the test cases in masked_gather_scatter.ll.
@@ -12163,9 +12163,6 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
   if (IndexIsScaled)
     return false;

-  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
-    return false;
Not sure this is required anymore
So the removal of these lines is not required anymore, right?
Let me check
fixed
; X86-SKX-NEXT:    vpslld $4, (%ecx), %zmm2
; X86-SKX-NEXT:    vmovdqu64 (%ecx), %zmm0
; X86-SKX-NEXT:    vpandd {{\.?LCPI[0-9]+_[0-9]+}}{1to16}, %zmm0, %zmm0
; X86-SKX-NEXT:    vpaddd %zmm0, %zmm0, %zmm2
I think we can handle this by adding an additional SimplifyDemandedBits call to combineGatherScatter - for cases where PtrVT == IndexSVT, we can use the Scale value (assuming it's a power of two, which it should be) to demand just the lower bits of the Index param - which should remove the VPANDD.
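A minimal sketch of that suggestion, assuming the usual combineGatherScatter locals (N, Index, Scale, PtrVT, TLI, DCI) are in scope; it illustrates the idea rather than reproducing any reviewed code:

```cpp
// Index bits at or above (PtrBits - log2(Scale)) cannot affect the final
// address (addr = Base + Index * Scale mod 2^PtrBits), so they are dead;
// marking them non-demanded lets SimplifyDemandedBits strip the VPANDD mask.
uint64_t ScaleAmt = Scale->getAsZExtVal();
unsigned IndexBits = Index.getScalarValueSizeInBits();
if (Index.getValueType().getScalarType() == PtrVT && isPowerOf2_64(ScaleAmt)) {
  APInt DemandedBits =
      APInt::getLowBitsSet(IndexBits, IndexBits - Log2_64(ScaleAmt));
  if (TLI.SimplifyDemandedBits(Index, DemandedBits, DCI))
    return SDValue(N, 0); // Index was simplified in place.
}
```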
Implemented logic to do this.
@@ -56508,6 +56508,120 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
                                     Scatter->isTruncatingStore());
 }

+// Targets override this function to decide whether they want to update the
+// base and index value of a non-uniform GEP.
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, SDValue &Scale,
I'm not convinced the updateBaseAndIndex method is going to be necessary - it's a very monolithic combine, which can already be handled with some of the focused folds inside combineGatherScatter - it looks like we just need to address a few corner cases.
This function is not called anywhere in our logic. I will remove it.
Removed the function.
@@ -4930,49 +4874,22 @@ define <16 x float> @test_gather_structpt_16f32_mask_index_offset(ptr %x, ptr %a
 ; X86-KNL-NEXT:    vptestmd %zmm0, %zmm0, %k1
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-KNL-NEXT:    vpslld $4, (%ecx), %zmm0
 ; X86-KNL-NEXT:    vgatherdps 4(%eax,%zmm0), %zmm1 {%k1}
 ; X86-KNL-NEXT:    vmovdqu64 (%ecx), %zmm0
@RKSimon, is our codegen good here? vpslld is replaced with mov and add instructions - we are increasing the instruction count!
IIRC that was the point of the "empty TODO" in the shift -> scale fold? My prototype transferred one shift bit at a time until we hit the Scale = 8 max. But really we should only attempt the fold if either (a) it will likely allow further simplification or (b) we can transfer the entire minimum shift amount into the scale (NewScale <= 8).
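A sketch of option (b), in the same style as the in-review fold above (again an illustration under assumed names, not the final code): transfer the shift into the scale only when the entire minimum shift amount fits, so legalization never has to expand a residual shl into mov/add.

```cpp
// Option (b): only fold when the whole minimum shift amount can be absorbed
// into a legal x86 scale (1, 2, 4 or 8); the sign-bit check guards against
// discarding index bits, mirroring the checks in the patch itself.
if (Index.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Scale)) {
  if (std::optional<uint64_t> MinShAmt = DAG.getValidMinimumShiftAmount(Index)) {
    uint64_t NewScaleAmt = Scale->getAsZExtVal() << *MinShAmt;
    if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8 &&
        DAG.ComputeNumSignBits(Index.getOperand(0)) > *MinShAmt) {
      SDValue ShAmt = Index.getOperand(1);
      // Subtract the transferred amount; if all lanes shift equally this
      // leaves a shift by zero that the DAG folds away.
      SDValue NewShAmt =
          DAG.getNode(ISD::SUB, DL, ShAmt.getValueType(), ShAmt,
                      DAG.getConstant(*MinShAmt, DL, ShAmt.getValueType()));
      SDValue NewIndex = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
                                     Index.getOperand(0), NewShAmt);
      return rebuildGatherScatter(
          GorS, NewIndex, Base,
          DAG.getConstant(NewScaleAmt, DL, Scale.getValueType()), DAG);
    }
  }
}
```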
No, the TODO was just my placeholder.
Yeah, that is sound. Currently I'm not sure how we can simplify further, but (b) should be there, as we are increasing the arithmetic complexity of the address calculation by changing the scale value: one mov and one add are introduced in place of vpslld $1 when the vector-legalized SelectionDAG performs this optimization.
I will check the instruction cycle costs in the KNL SWOG and think about how to handle this case.
@RKSimon, I have updated the patch to handle these use cases.
Please have a look.
Thanks