@@ -253,6 +253,21 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
253253 VF * getNumElements(ScalarTy));
254254}
255255
/// Expands \p Mask in place so that every entry is replaced by
/// \p VecTyNumElements consecutive entries.
/// A non-poison entry M becomes M * VecTyNumElements + 0, ...,
/// M * VecTyNumElements + (VecTyNumElements - 1); a PoisonMaskElem entry
/// becomes VecTyNumElements PoisonMaskElem entries.
/// e.g., with VecTyNumElements == 2, {1, PoisonMaskElem, 0} becomes
/// {2, 3, PoisonMaskElem, PoisonMaskElem, 0, 1}.
/// \param VecTyNumElements number of output entries produced per input entry.
/// \param Mask mask to expand; on return its size is the old size multiplied
/// by \p VecTyNumElements.
256+ static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
257+ SmallVectorImpl<int> &Mask) {
258+ // The ShuffleBuilder implementation uses shufflevector to splat an "element".
259+ // But an "element" has a different meaning for SLP (scalar) and REVEC
260+ // (vector). We need to expand Mask into a mask which shufflevector can use
261+ // directly.
262+ SmallVector<int> NewMask(Mask.size() * VecTyNumElements);
// For each original entry I, fill its group of VecTyNumElements consecutive
// slots in NewMask: poison stays poison, otherwise slot J of the group gets
// Mask[I] * VecTyNumElements + J.
263+ for (unsigned I : seq<unsigned>(Mask.size()))
264+ for (auto [J, MaskV] : enumerate(MutableArrayRef(NewMask).slice(
265+ I * VecTyNumElements, VecTyNumElements)))
266+ MaskV = Mask[I] == PoisonMaskElem ? PoisonMaskElem
267+ : Mask[I] * VecTyNumElements + J;
// Hand the expanded mask back to the caller through the in/out parameter.
268+ Mask.swap(NewMask);
269+ }
270+
256271/// \returns True if the value is a constant (but not globals/constant
257272/// expressions).
258273static bool isConstant(Value *V) {
@@ -7772,6 +7787,31 @@ namespace {
77727787/// The base class for shuffle instruction emission and shuffle cost estimation.
77737788class BaseShuffleAnalysis {
77747789protected:
/// Type supplied to the constructor. getVF() divides a vector's element count
/// by getNumElements(ScalarTy) to obtain the vector factor.
7790+ Type *ScalarTy = nullptr;
7791+
/// Stores \p ScalarTy in the ScalarTy member, which getVF() reads. Declared in
/// the protected section, so it is only callable from derived classes.
7792+ BaseShuffleAnalysis(Type *ScalarTy) : ScalarTy(ScalarTy) {}
7793+
7794+ /// V is expected to be a vectorized value.
7795+ /// When REVEC is disabled, there is no difference between VF and
7796+ /// VNumElements.
7797+ /// When REVEC is enabled, VF is VNumElements / ScalarTyNumElements.
7798+ /// e.g., if ScalarTy is <4 x Ty> and V is <8 x Ty>, 2 is returned instead
7799+ /// of 8.
/// \param V value whose type must be a FixedVectorType; must not be null.
/// \returns the number of elements of \p V divided by
/// getNumElements(ScalarTy).
/// Asserts that ScalarTy is set and that the element count of \p V is both
/// strictly greater than, and a multiple of, getNumElements(ScalarTy).
7800+ unsigned getVF(Value *V) const {
7801+ assert(V && "V cannot be nullptr");
7802+ assert(isa<FixedVectorType>(V->getType()) &&
7803+ "V does not have FixedVectorType");
7804+ assert(ScalarTy && "ScalarTy cannot be nullptr");
7805+ unsigned ScalarTyNumElements = getNumElements(ScalarTy);
7806+ unsigned VNumElements =
7807+ cast<FixedVectorType>(V->getType())->getNumElements();
// Strictly greater: combined with the divisibility check below, this means
// the quotient returned is expected to be at least 2.
7808+ assert(VNumElements > ScalarTyNumElements &&
7809+ "the number of elements of V is not large enough");
7810+ assert(VNumElements % ScalarTyNumElements == 0 &&
7811+ "the number of elements of V is not a vectorized value");
7812+ return VNumElements / ScalarTyNumElements;
7813+ }
7814+
77757815 /// Checks if the mask is an identity mask.
77767816 /// \param IsStrict if is true the function returns false if mask size does
77777817 /// not match vector size.
@@ -8265,7 +8305,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
82658305 bool IsFinalized = false;
82668306 SmallVector<int> CommonMask;
82678307 SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors;
8268- Type *ScalarTy = nullptr;
82698308 const TargetTransformInfo &TTI;
82708309 InstructionCost Cost = 0;
82718310 SmallDenseSet<Value *> VectorizedVals;
@@ -8847,14 +8886,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88478886 } else if (V1 && P2.isNull()) {
88488887 // Shuffle single vector.
88498888 ExtraCost += GetValueMinBWAffectedCost(V1);
8850- CommonVF = cast<FixedVectorType> (V1->getType())->getNumElements( );
8889+ CommonVF = getVF (V1);
88518890 assert(
88528891 all_of(Mask,
88538892 [=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
88548893 "All elements in mask must be less than CommonVF.");
88558894 } else if (V1 && !V2) {
88568895 // Shuffle vector and tree node.
8857- unsigned VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
8896+ unsigned VF = getVF (V1);
88588897 const TreeEntry *E2 = P2.get<const TreeEntry *>();
88598898 CommonVF = std::max(VF, E2->getVectorFactor());
88608899 assert(all_of(Mask,
@@ -8880,7 +8919,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88808919 V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
88818920 } else if (!V1 && V2) {
88828921 // Shuffle vector and tree node.
8883- unsigned VF = cast<FixedVectorType> (V2->getType())->getNumElements( );
8922+ unsigned VF = getVF (V2);
88848923 const TreeEntry *E1 = P1.get<const TreeEntry *>();
88858924 CommonVF = std::max(VF, E1->getVectorFactor());
88868925 assert(all_of(Mask,
@@ -8908,9 +8947,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
89088947 V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
89098948 } else {
89108949 assert(V1 && V2 && "Expected both vectors.");
8911- unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8912- CommonVF =
8913- std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
8950+ unsigned VF = getVF(V1);
8951+ CommonVF = std::max(VF, getVF(V2));
89148952 assert(all_of(Mask,
89158953 [=](int Idx) {
89168954 return Idx < 2 * static_cast<int>(CommonVF);
@@ -8928,6 +8966,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
89288966 V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
89298967 }
89308968 }
8969+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
8970+ assert(SLPReVec && "FixedVectorType is not expected.");
8971+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
8972+ CommonMask);
8973+ }
89318974 InVectors.front() =
89328975 Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
89338976 if (InVectors.size() == 2)
@@ -8940,7 +8983,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
/// Builds a cost estimator. \p ScalarTy is forwarded to the
/// BaseShuffleAnalysis base (the removed line initialized a ScalarTy member
/// of this class instead); the values in \p VectorizedVals are copied into
/// the VectorizedVals member, and \p TTI, \p R and \p CheckedExtracts
/// initialize the members of the same names.
89408983 ShuffleCostEstimator(Type *ScalarTy, TargetTransformInfo &TTI,
89418984 ArrayRef<Value *> VectorizedVals, BoUpSLP &R,
89428985 SmallPtrSetImpl<Value *> &CheckedExtracts)
8943- : ScalarTy (ScalarTy), TTI(TTI),
8986+ : BaseShuffleAnalysis (ScalarTy), TTI(TTI),
89448987 VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), R(R),
89458988 CheckedExtracts(CheckedExtracts) {}
89468989 Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -9145,7 +9188,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
91459188 }
91469189 assert(!InVectors.empty() && !CommonMask.empty() &&
91479190 "Expected only tree entries from extracts/reused buildvectors.");
9148- unsigned VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
9191+ unsigned VF = getVF (V1);
91499192 if (InVectors.size() == 2) {
91509193 Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
91519194 transformMaskAfterShuffle(CommonMask, CommonMask);
@@ -9179,12 +9222,32 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
91799222 }
91809223 Vals.push_back(Constant::getNullValue(V->getType()));
91819224 }
9225+ if (auto *VecTy = dyn_cast<FixedVectorType>(Vals.front()->getType())) {
9226+ assert(SLPReVec && "FixedVectorType is not expected.");
9227+ // When REVEC is enabled, we need to expand vector types into scalar
9228+ // types.
9229+ unsigned VecTyNumElements = VecTy->getNumElements();
9230+ SmallVector<Constant *> NewVals(VF * VecTyNumElements, nullptr);
9231+ for (auto [I, V] : enumerate(Vals)) {
9232+ Type *ScalarTy = V->getType()->getScalarType();
9233+ Constant *NewVal;
9234+ if (isa<PoisonValue>(V))
9235+ NewVal = PoisonValue::get(ScalarTy);
9236+ else if (isa<UndefValue>(V))
9237+ NewVal = UndefValue::get(ScalarTy);
9238+ else
9239+ NewVal = Constant::getNullValue(ScalarTy);
9240+ std::fill_n(NewVals.begin() + I * VecTyNumElements, VecTyNumElements,
9241+ NewVal);
9242+ }
9243+ Vals.swap(NewVals);
9244+ }
91829245 return ConstantVector::get(Vals);
91839246 }
91849247 return ConstantVector::getSplat(
91859248 ElementCount::getFixed(
91869249 cast<FixedVectorType>(Root->getType())->getNumElements()),
9187- getAllOnesValue(*R.DL, ScalarTy));
9250+ getAllOnesValue(*R.DL, ScalarTy->getScalarType() ));
91889251 }
91899252 InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
91909253 /// Finalize emission of the shuffles.
@@ -11685,8 +11748,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
1168511748 Type *Ty) {
1168611749 Value *Scalar = V;
1168711750 if (Scalar->getType() != Ty) {
11688- assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy () &&
11689- "Expected integer types only.");
11751+ assert(Scalar->getType()->isIntOrIntVectorTy () &&
11752+ Ty->isIntOrIntVectorTy() && "Expected integer types only.");
1169011753 Value *V = Scalar;
1169111754 if (auto *CI = dyn_cast<CastInst>(Scalar);
1169211755 isa_and_nonnull<SExtInst, ZExtInst>(CI)) {
@@ -11699,10 +11762,21 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
1169911762 V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
1170011763 }
1170111764
11702- Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
11703- auto *InsElt = dyn_cast<InsertElementInst>(Vec);
11704- if (!InsElt)
11705- return Vec;
11765+ Instruction *InsElt;
11766+ if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
11767+ assert(SLPReVec && "FixedVectorType is not expected.");
11768+ Vec = InsElt = Builder.CreateInsertVector(
11769+ Vec->getType(), Vec, V,
11770+ Builder.getInt64(Pos * VecTy->getNumElements()));
11771+ auto *II = dyn_cast<IntrinsicInst>(InsElt);
11772+ if (!II || II->getIntrinsicID() != Intrinsic::vector_insert)
11773+ return Vec;
11774+ } else {
11775+ Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
11776+ InsElt = dyn_cast<InsertElementInst>(Vec);
11777+ if (!InsElt)
11778+ return Vec;
11779+ }
1170611780 GatherShuffleExtractSeq.insert(InsElt);
1170711781 CSEBlocks.insert(InsElt->getParent());
1170811782 // Add to our 'need-to-extract' list.
@@ -11803,7 +11877,6 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1180311877 /// resulting shuffle and the second operand sets to be the newly added
1180411878 /// operand. The \p CommonMask is transformed in the proper way after that.
1180511879 SmallVector<Value *, 2> InVectors;
11806- Type *ScalarTy = nullptr;
1180711880 IRBuilderBase &Builder;
1180811881 BoUpSLP &R;
1180911882
@@ -11929,7 +12002,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1192912002
1193012003public:
/// Builds a shuffle instruction builder. \p ScalarTy is forwarded to the
/// BaseShuffleAnalysis base (the removed line initialized a ScalarTy member
/// of this class instead); \p Builder and \p R initialize the members of the
/// same names.
1193112004 ShuffleInstructionBuilder(Type *ScalarTy, IRBuilderBase &Builder, BoUpSLP &R)
11932- : ScalarTy (ScalarTy), Builder(Builder), R(R) {}
12005+ : BaseShuffleAnalysis (ScalarTy), Builder(Builder), R(R) {}
1193312006
1193412007 /// Adjusts extractelements after reusing them.
1193512008 Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -12186,7 +12259,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1218612259 break;
1218712260 }
1218812261 }
12189- int VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
12262+ int VF = getVF (V1);
1219012263 for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
1219112264 if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
1219212265 CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
@@ -12209,6 +12282,15 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1220912282 finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
1221012283 function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
1221112284 IsFinalized = true;
12285+ SmallVector<int> NewExtMask(ExtMask);
12286+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
12287+ assert(SLPReVec && "FixedVectorType is not expected.");
12288+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
12289+ CommonMask);
12290+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
12291+ NewExtMask);
12292+ ExtMask = NewExtMask;
12293+ }
1221212294 if (Action) {
1221312295 Value *Vec = InVectors.front();
1221412296 if (InVectors.size() == 2) {
@@ -13992,6 +14074,17 @@ Value *BoUpSLP::vectorizeTree(
1399214074 if (GEP->hasName())
1399314075 CloneGEP->takeName(GEP);
1399414076 Ex = CloneGEP;
14077+ } else if (auto *VecTy =
14078+ dyn_cast<FixedVectorType>(Scalar->getType())) {
14079+ assert(SLPReVec && "FixedVectorType is not expected.");
14080+ unsigned VecTyNumElements = VecTy->getNumElements();
14081+ // When REVEC is enabled, we need to extract a vector.
14082+ // Note: The element size of Scalar may be different from the
14083+ // element size of Vec.
14084+ Ex = Builder.CreateExtractVector(
14085+ FixedVectorType::get(Vec->getType()->getScalarType(),
14086+ VecTyNumElements),
14087+ Vec, Builder.getInt64(ExternalUse.Lane * VecTyNumElements));
1399514088 } else {
1399614089 Ex = Builder.CreateExtractElement(Vec, Lane);
1399714090 }
0 commit comments