@@ -129,6 +129,7 @@ class VectorCombine {
129
129
bool foldShuffleOfIntrinsics (Instruction &I);
130
130
bool foldShuffleToIdentity (Instruction &I);
131
131
bool foldShuffleFromReductions (Instruction &I);
132
+ bool foldShuffleChainsToReduce (Instruction &I);
132
133
bool foldCastFromReductions (Instruction &I);
133
134
bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
134
135
bool foldInterleaveIntrinsics (Instruction &I);
@@ -2910,6 +2911,130 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2910
2911
return foldSelectShuffle (*Shuffle, true );
2911
2912
}
2912
2913
2914
+ bool VectorCombine::foldShuffleChainsToReduce (Instruction &I) {
2915
+ auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
2916
+ if (!SVI)
2917
+ return false ;
2918
+
2919
+ std::queue<Value *> Worklist;
2920
+ SmallVector<Instruction *> ToEraseFromParent;
2921
+
2922
+ SmallVector<int > ShuffleMask;
2923
+ bool IsShuffleOp = true ;
2924
+
2925
+ Worklist.push (SVI);
2926
+ SVI->getShuffleMask (ShuffleMask);
2927
+
2928
+ if (ShuffleMask.size () < 2 )
2929
+ return false ;
2930
+
2931
+ Instruction *Prev0 = nullptr , *Prev1 = nullptr ;
2932
+ Instruction *LastOp = nullptr ;
2933
+
2934
+ int MaskHalfPos = ShuffleMask.size () / 2 ;
2935
+ bool IsFirst = true ;
2936
+
2937
+ while (!Worklist.empty ()) {
2938
+ Value *V = Worklist.front ();
2939
+ Worklist.pop ();
2940
+
2941
+ auto *CI = dyn_cast<Instruction>(V);
2942
+ if (!CI)
2943
+ return false ;
2944
+
2945
+ if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
2946
+ if (!IsShuffleOp || MaskHalfPos < 1 || (!Prev1 && !IsFirst))
2947
+ return false ;
2948
+
2949
+ auto *Op0 = SV->getOperand (0 );
2950
+ auto *Op1 = SV->getOperand (1 );
2951
+ if (!Op0 || !Op1)
2952
+ return false ;
2953
+
2954
+ auto *FVT = dyn_cast<FixedVectorType>(Op1->getType ());
2955
+ if (!FVT || !isa<PoisonValue>(Op1))
2956
+ return false ;
2957
+
2958
+ SmallVector<int > CurrentMask;
2959
+ SV->getShuffleMask (CurrentMask);
2960
+
2961
+ int64_t MaskSize = CurrentMask.size ();
2962
+ for (int MaskPos = 0 ; MaskPos != MaskSize; ++MaskPos) {
2963
+ if (MaskPos < MaskHalfPos && CurrentMask[MaskPos] != MaskHalfPos + MaskPos)
2964
+ return false ;
2965
+ if (MaskPos >= MaskHalfPos && CurrentMask[MaskPos] != -1 )
2966
+ return false ;
2967
+ }
2968
+ MaskHalfPos /= 2 ;
2969
+ Prev0 = SV;
2970
+ } else if (auto *Call = dyn_cast<CallInst>(V)) {
2971
+ if (IsShuffleOp || !Prev0)
2972
+ return false ;
2973
+
2974
+ auto *II = dyn_cast<IntrinsicInst>(Call);
2975
+ if (!II)
2976
+ return false ;
2977
+
2978
+ switch (II->getIntrinsicID ()) {
2979
+ case Intrinsic::umin: {
2980
+ auto *Op0 = Call->getOperand (0 );
2981
+ auto *Op1 = Call->getOperand (1 );
2982
+ if (!(Op0 == Prev0 && Op1 == Prev1) && !(Op0 == Prev1 && Op1 == Prev0) && !IsFirst)
2983
+ return false ;
2984
+
2985
+ if (!IsFirst)
2986
+ Prev0 = Prev1;
2987
+ else
2988
+ IsFirst = false ;
2989
+ Prev1 = Call;
2990
+ break ;
2991
+ }
2992
+ default :
2993
+ return false ;
2994
+ }
2995
+ } else if (auto *ExtractElement = dyn_cast<ExtractElementInst>(CI)) {
2996
+ if (!IsShuffleOp || !Prev0 || !Prev1 || MaskHalfPos != 0 )
2997
+ return false ;
2998
+
2999
+ auto *Op0 = ExtractElement->getOperand (0 );
3000
+ auto *Op1 = ExtractElement->getOperand (1 );
3001
+ if (Op0 != Prev1)
3002
+ return false ;
3003
+
3004
+ if (auto *Op1Idx = dyn_cast<ConstantInt>(Op1)) {
3005
+ if (Op1Idx->getValue () != 0 )
3006
+ return false ;
3007
+ } else {
3008
+ return false ;
3009
+ }
3010
+ LastOp = ExtractElement;
3011
+ break ;
3012
+ }
3013
+ IsShuffleOp ^= 1 ;
3014
+ ToEraseFromParent.push_back (CI);
3015
+
3016
+ auto *NextI = CI->getNextNode ();
3017
+ if (!NextI)
3018
+ return false ;
3019
+ Worklist.push (NextI);
3020
+ }
3021
+
3022
+ if (!LastOp)
3023
+ return false ;
3024
+
3025
+ auto *ReducedResult = Builder.CreateIntrinsic (Intrinsic::vector_reduce_umin, {SVI->getType ()}, {SVI->getOperand (0 )});
3026
+ replaceValue (*LastOp, *ReducedResult);
3027
+
3028
+ ToEraseFromParent.push_back (LastOp);
3029
+
3030
+ std::reverse (ToEraseFromParent.begin (), ToEraseFromParent.end ());
3031
+ // for (auto &Instr : ToEraseFromParent)
3032
+ // eraseInstruction(*Instr);
3033
+ // Instr->eraseFromParent();
3034
+
3035
+ return true ;
3036
+ }
3037
+
2913
3038
// / Determine if it's more efficient to fold:
2914
3039
// / reduce(trunc(x)) -> trunc(reduce(x)).
2915
3040
// / reduce(sext(x)) -> sext(reduce(x)).
@@ -3607,6 +3732,7 @@ bool VectorCombine::run() {
3607
3732
MadeChange |= foldShuffleOfIntrinsics (I);
3608
3733
MadeChange |= foldSelectShuffle (I);
3609
3734
MadeChange |= foldShuffleToIdentity (I);
3735
+ MadeChange |= foldShuffleChainsToReduce (I);
3610
3736
break ;
3611
3737
case Instruction::BitCast:
3612
3738
MadeChange |= foldBitcastShuffle (I);
0 commit comments