Skip to content

[X86] combineGatherScatter - split non-constant (add v, (splat b)) indices patterns and add the splat into the (scalar) base address #135201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 11, 2025

Conversation

RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented Apr 10, 2025

We already did this for constant cases, this patch generalizes the existing fold to attempt to extract the splat from either operand of a ADD node for the gather/scatter index value

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future as well.

Noticed while reviewing #134979

CC @rohitaggarwal007

…dices patterns and add the splat into the (scalar) base address

We already did this for constant cases, this patch generalizes the existing fold to attempt to extract the splat from either operand of a ADD node for the gather/scatter index value

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future as well.

Noticed while reviewing llvm#134979
@llvmbot
Copy link
Member

llvmbot commented Apr 10, 2025

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes

We already did this for constant cases, this patch generalizes the existing fold to attempt to extract the splat from either operand of a ADD node for the gather/scatter index value

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future as well.

Noticed while reviewing #134979

CC @rohitaggarwal007


Full diff: https://github.com/llvm/llvm-project/pull/135201.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+40-29)
  • (modified) llvm/test/CodeGen/X86/masked_gather_scatter.ll (+2-6)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a3c423270f44a..77808608045f9 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56521,6 +56521,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   SDValue Base = GorS->getBasePtr();
   SDValue Scale = GorS->getScale();
   EVT IndexVT = Index.getValueType();
+  EVT IndexSVT = IndexVT.getVectorElementType();
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   if (DCI.isBeforeLegalize()) {
@@ -56557,41 +56558,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
   }
 
   EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
-  // Try to move splat constant adders from the index operand to the base
+
+  // Try to move splat adders from the index operand to the base
   // pointer operand. Taking care to multiply by the scale. We can only do
   // this when index element type is the same as the pointer type.
   // Otherwise we need to be sure the math doesn't wrap before the scale.
-  if (Index.getOpcode() == ISD::ADD &&
-      IndexVT.getVectorElementType() == PtrVT && isa<ConstantSDNode>(Scale)) {
+  if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT &&
+      isa<ConstantSDNode>(Scale)) {
     uint64_t ScaleAmt = Scale->getAsZExtVal();
-    if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(1))) {
-      BitVector UndefElts;
-      if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) {
-        // FIXME: Allow non-constant?
-        if (UndefElts.none()) {
-          // Apply the scale.
-          APInt Adder = C->getAPIntValue() * ScaleAmt;
-          // Add it to the existing base.
-          Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
-                             DAG.getConstant(Adder, DL, PtrVT));
-          Index = Index.getOperand(0);
-          return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
-        }
-      }
 
-      // It's also possible base is just a constant. In that case, just
-      // replace it with 0 and move the displacement into the index.
-      if (BV->isConstant() && isa<ConstantSDNode>(Base) &&
-          isOneConstant(Scale)) {
-        SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
-        // Combine the constant build_vector and the constant base.
-        Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat);
-        // Add to the LHS of the original Index add.
-        Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat);
-        Base = DAG.getConstant(0, DL, Base.getValueType());
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+    for (unsigned I = 0; I != 2; ++I)
+      if (auto *BV = dyn_cast<BuildVectorSDNode>(Index.getOperand(I))) {
+        BitVector UndefElts;
+        if (SDValue Splat = BV->getSplatValue(&UndefElts)) {
+          if (UndefElts.none()) {
+            // If the splat value is constant we can add the scaled splat value
+            // to the existing base.
+            if (auto *C = dyn_cast<ConstantSDNode>(Splat)) {
+              APInt Adder = C->getAPIntValue() * ScaleAmt;
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
+                                            DAG.getConstant(Adder, DL, PtrVT));
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+            // For non-constant cases, limit this to non-scaled cases.
+            if (ScaleAmt == 1) {
+              SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
+              SDValue NewIndex = Index.getOperand(1 - I);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+            }
+          }
+        }
+        // It's also possible base is just a constant. In that case, just
+        // replace it with 0 and move the displacement into the index.
+        if (ScaleAmt == 1 && BV->isConstant() && isa<ConstantSDNode>(Base)) {
+          SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base);
+          // Combine the constant build_vector and the constant base.
+          Splat =
+              DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat);
+          // Add to the other half of the original Index add.
+          SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
+                                         Index.getOperand(1 - I), Splat);
+          SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
+          return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+        }
       }
-    }
   }
 
   if (DCI.isBeforeLegalizeOps()) {
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 5effb18fb6aa6..46e589b7b1be9 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-KNL-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-KNL-NEXT:    vpslld $4, (%ecx), %zmm2
-; X86-KNL-NEXT:    vpbroadcastd %eax, %zmm0
-; X86-KNL-NEXT:    vpaddd %zmm2, %zmm0, %zmm3
 ; X86-KNL-NEXT:    kmovw %k1, %k2
 ; X86-KNL-NEXT:    vmovaps %zmm1, %zmm0
 ; X86-KNL-NEXT:    vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-KNL-NEXT:    vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-KNL-NEXT:    vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-KNL-NEXT:    retl
 ;
 ; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair:
@@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; X86-SKX-NEXT:    vpslld $4, (%ecx), %zmm2
-; X86-SKX-NEXT:    vpbroadcastd %eax, %zmm0
-; X86-SKX-NEXT:    vpaddd %zmm2, %zmm0, %zmm3
 ; X86-SKX-NEXT:    kmovw %k1, %k2
 ; X86-SKX-NEXT:    vmovaps %zmm1, %zmm0
 ; X86-SKX-NEXT:    vgatherdps (%eax,%zmm2), %zmm0 {%k2}
-; X86-SKX-NEXT:    vgatherdps 4(,%zmm3), %zmm1 {%k1}
+; X86-SKX-NEXT:    vgatherdps 4(%eax,%zmm2), %zmm1 {%k1}
 ; X86-SKX-NEXT:    retl
   %wide.load = load <16 x i32>, ptr %arr, align 4
   %and = and <16 x i32> %wide.load, <i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911, i32 536870911>

Copy link
Contributor

@phoebewang phoebewang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@RKSimon RKSimon merged commit 937cfdc into llvm:main Apr 11, 2025
7 of 11 checks passed
@RKSimon RKSimon deleted the x86-gather-base-splat branch April 11, 2025 09:53
@llvm-ci
Copy link
Collaborator

llvm-ci commented Apr 11, 2025

LLVM Buildbot has detected a new failure on builder ml-opt-devrel-x86-64 running on ml-opt-devrel-x86-64-b1 while building llvm at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/175/builds/16680

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM :: Transforms/SLPVectorizer/X86/revec-SplitVectorize.ll' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
/b/ml-opt-devrel-x86-64-b1/build/bin/opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -passes=slp-vectorizer -S -slp-revec < /b/ml-opt-devrel-x86-64-b1/llvm-project/llvm/test/Transforms/SLPVectorizer/X86/revec-SplitVectorize.ll | /b/ml-opt-devrel-x86-64-b1/build/bin/FileCheck /b/ml-opt-devrel-x86-64-b1/llvm-project/llvm/test/Transforms/SLPVectorizer/X86/revec-SplitVectorize.ll # RUN: at line 2
+ /b/ml-opt-devrel-x86-64-b1/build/bin/opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -passes=slp-vectorizer -S -slp-revec
+ /b/ml-opt-devrel-x86-64-b1/build/bin/FileCheck /b/ml-opt-devrel-x86-64-b1/llvm-project/llvm/test/Transforms/SLPVectorizer/X86/revec-SplitVectorize.ll
/b/ml-opt-devrel-x86-64-b1/llvm-project/llvm/test/Transforms/SLPVectorizer/X86/revec-SplitVectorize.ll:16:15: error: CHECK-NEXT: expected string not found in input
; CHECK-NEXT: [[TMP9:%.*]] = call <16 x i1> @llvm.vector.insert.v16i1.v4i1(<16 x i1> poison, <4 x i1> zeroinitializer, i64 0)
              ^
<stdin>:16:39: note: scanning from here
 %8 = trunc <32 x i32> %7 to <32 x i1>
                                      ^
<stdin>:29:2: note: possible intended match here
 %14 = call <32 x i1> @llvm.vector.insert.v32i1.v4i1(<32 x i1> %13, <4 x i1> zeroinitializer, i64 16)
 ^

Input file: <stdin>
Check file: /b/ml-opt-devrel-x86-64-b1/llvm-project/llvm/test/Transforms/SLPVectorizer/X86/revec-SplitVectorize.ll

-dump-input=help explains the following input dump.

Input was:
<<<<<<
           .
           .
           .
          11:  %3 = call <32 x i32> @llvm.vector.insert.v32i32.v4i32(<32 x i32> %2, <4 x i32> zeroinitializer, i64 12) 
          12:  %4 = call <32 x i32> @llvm.vector.insert.v32i32.v4i32(<32 x i32> %3, <4 x i32> zeroinitializer, i64 16) 
          13:  %5 = call <32 x i32> @llvm.vector.insert.v32i32.v4i32(<32 x i32> %4, <4 x i32> zeroinitializer, i64 20) 
          14:  %6 = call <32 x i32> @llvm.vector.insert.v32i32.v4i32(<32 x i32> %5, <4 x i32> zeroinitializer, i64 24) 
          15:  %7 = call <32 x i32> @llvm.vector.insert.v32i32.v4i32(<32 x i32> %6, <4 x i32> zeroinitializer, i64 28) 
          16:  %8 = trunc <32 x i32> %7 to <32 x i1> 
next:16'0                                           X error: no match found
          17:  br label %vector.body 
next:16'0     ~~~~~~~~~~~~~~~~~~~~~~~
          18:  
next:16'0     ~
          19: vector.body: ; preds = %vector.body, %entry 
next:16'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          20:  %9 = phi <32 x i1> [ %8, %entry ], [ %18, %vector.body ] 
next:16'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
          21:  %narrow = select <4 x i1> zeroinitializer, <4 x i1> zeroinitializer, <4 x i1> zeroinitializer 
next:16'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           .
           .
           .
          24:  %narrow68 = select <4 x i1> zeroinitializer, <4 x i1> zeroinitializer, <4 x i1> zeroinitializer 
next:16'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...

var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…dices patterns and add the splat into the (scalar) base address (llvm#135201)

We already did this for constant cases, this patch generalizes the existing fold to attempt to extract the splat from either operand of a ADD node for the gather/scatter index value

This cleanup should also make it easier to add support for splitting vXi32 indices on x86_64 64-bit pointer targets in the future as well.

Noticed while reviewing llvm#134979
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants