[X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address #135201
Conversation
…dices patterns and add the splat into the (scalar) base address

We already did this for constant cases; this patch generalizes the existing fold to attempt to extract the splat from either operand of an ADD node used as the gather/scatter index value.

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 (64-bit pointer) targets in the future.

Noticed while reviewing llvm#134979
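As an editorial illustration, here is a minimal hypothetical IR sketch of the kind of pattern the fold targets (not taken from the patch's tests; the function and value names are invented). It assumes an index element type that matches the pointer width and a gather scale of 1, which the new non-constant case requires:

; Index is (add %v, (splat %b)) with a non-constant %b - the case this
; patch newly handles by folding %b into the scalar base address.
declare <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x float>)

define <8 x float> @gather_add_splat(ptr %base, <8 x i64> %v, i64 %b, <8 x i1> %mask) {
  %b.ins = insertelement <8 x i64> poison, i64 %b, i64 0
  %b.splat = shufflevector <8 x i64> %b.ins, <8 x i64> poison, <8 x i32> zeroinitializer
  %idx = add <8 x i64> %v, %b.splat
  ; A byte-addressed gep keeps the effective gather scale at 1.
  %ptrs = getelementptr i8, ptr %base, <8 x i64> %idx
  %g = call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> %mask, <8 x float> poison)
  ret <8 x float> %g
}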
@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes: as described in the commit message above. CC @rohitaggarwal007

Full diff: https://github.com/llvm/llvm-project/pull/135201.diff

2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a3c423270f44a..77808608045f9 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56521,6 +56521,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   SDValue Base = GorS->getBasePtr();
   SDValue Scale = GorS->getScale();
   EVT IndexVT = Index.getValueType();
+  EVT IndexSVT = IndexVT.getVectorElementType();
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();

   if (DCI.isBeforeLegalize()) {
@@ -56557,41 +56558,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   }

   EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
-  // Try to move splat constant adders from the index operand to the base
+
+  // Try to move splat adders from the index operand to the base
   // pointer operand. Taking care to multiply by the scale. We can only do
   // this when index element type is the same as the pointer type.
   // Otherwise we need to be sure the math doesn't wrap before the scale.
-  if (Index.getOpcode() == ISD::ADD &&
-      IndexVT.getVectorElementType() == PtrVT && isa<ConstantSDNode>(Scale)) {
+  if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT &&
+      isa<ConstantSDNode>(Scale)) {
     uint64_t ScaleAmt = Scale->getAsZExtVal();
-    if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
-      BitVector UndefElts;
-      if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
-        // FIXME: Allow non-constant?
-        if (UndefElts.none()) {
-          // Apply the scale.
-          APInt Adder = C->getAPIntValue() * ScaleAmt;
-          // Add it to the existing base.
-          Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
-                             DAG.getConstant(Adder, DL, PtrVT));
-          Index = Index.getOperand(0);
-          return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
-        }
-      }
-      // It's also possible base is just a constant. In that case, just
-      // replace it with 0 and move the displacement into the index.
-      if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
-          isOneConstant(Scale)) {
-        SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
-        // Combine the constant build_vector and the constant base.
-        Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat);
-        // Add to the LHS of the original Index add.
-        Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat);
-        Base = DAG.getConstant(0, DL, Base.getValueType());
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+    for (unsigned I = 0; I != 2; ++I)
+      if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(I))) {
+        BitVector UndefElts;
+        if (SDValue Splat = BV->getSplatValue(&UndefElts)) {
+          if (UndefElts.none()) {
+            // If the splat value is constant we can add the scaled splat value
+            // to the existing base.
+            if (auto *C = dyn_cast<ConstantSDNode>(Splat)) {
+              APInt Adder = C->getAPIntValue() * ScaleAmt;
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
+                                            DAG.getConstant(Adder, DL, PtrVT));
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+            // For non-constant cases, limit this to non-scaled cases.
+            if (ScaleAmt == 1) {
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+          }
+        }
+        // It's also possible base is just a constant. In that case, just
+        // replace it with 0 and move the displacement into the index.
+        if (ScaleAmt == 1 && BV->isConstant() && isa<ConstantSDNode>(Base)) {
+          SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
+          // Combine the constant build_vector and the constant base.
+          Splat =
+              DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat);
+          // Add to the other half of the original Index add.
+          SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
+                                         Index.getOperand(1 - I), Splat);
+          SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
+          return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+        }
       }
-    }
   }

   if (DCI.isBeforeLegalizeOps()) {
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 5effb18fb6aa6..46e589b7b1be9 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-KNL-NEXT:    vpslld $4, (%ecx), %zmm2
-; X86-KNL-NEXT:    vpbroadcastd %eax, %zmm0
-; X86-KNL-NEXT:    vpaddd %zmm2, %zmm0, %zmm3
 ; X86-KNL-NEXT:    kmovw %k1, %k2
 ; X86-KNL-NEXT:    vmovaps %zmm1, %zmm0
 ; X86-KNL-NEXT:    vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-KNL-NEXT:    vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-KNL-NEXT:    vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-KNL-NEXT:    retl
 ;
 ; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair:
@@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-SKX-NEXT:    vpslld $4, (%ecx), %zmm2
-; X86-SKX-NEXT:    vpbroadcastd %eax, %zmm0
-; X86-SKX-NEXT:    vpaddd %zmm2, %zmm0, %zmm3
 ; X86-SKX-NEXT:    kmovw %k1, %k2
 ; X86-SKX-NEXT:    vmovaps %zmm1, %zmm0
 ; X86-SKX-NEXT:    vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-SKX-NEXT:    vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-SKX-NEXT:    vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-SKX-NEXT:    retl
   %wide.load = load <16 x i32>, ptr %arr, align 4
   %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>
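Conceptually, the combine rewrites the hypothetical sketch from earlier into something equivalent to the following (again a sketch, not compiler output): the splat addend becomes a scalar adjustment of the base pointer, and the remaining ADD operand becomes the vector index.

declare <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x float>)

define <8 x float> @gather_add_splat_folded(ptr %base, <8 x i64> %v, i64 %b, <8 x i1> %mask) {
  ; The splat addend is now folded into the scalar base address...
  %newbase = getelementptr i8, ptr %base, i64 %b
  ; ...and the vector index is the other operand of the original ADD.
  %ptrs = getelementptr i8, ptr %newbase, <8 x i64> %v
  %g = call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> %mask, <8 x float> poison)
  ret <8 x float> %g
}

This is what the updated test checks above reflect: the second vgatherdps now reuses the (%eax,%zmm2) base/index addressing form instead of materializing a broadcast and a vector add.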
LGTM.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/175/builds/16680
…dices patterns and add the splat into the (scalar) base address (llvm#135201)