Skip to content

Commit 5ca3a48

Browse files
committed
* Refactor to share isDUPQMask
* Support SME2p1 * Remove hardcoded magic number * Return the same result for other cost kinds
1 parent bc8af08 commit 5ca3a48

File tree

3 files changed

+39
-37
lines changed

3 files changed

+39
-37
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13391,30 +13391,6 @@ static bool isUZP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
1339113391
return true;
1339213392
}
1339313393

13394-
/// isDUPQMask - matches a splat of equivalent lanes within 128b segments in
13395-
/// the first vector operand.
13396-
static std::optional<unsigned> isDUPQMask(ArrayRef<int> M, EVT VT) {
13397-
assert(VT.getFixedSizeInBits() % 128 == 0 && "Unsupported SVE vector size");
13398-
unsigned Lane = (unsigned)M[0];
13399-
unsigned Segments = VT.getFixedSizeInBits() / 128;
13400-
unsigned SegmentElts = VT.getVectorNumElements() / Segments;
13401-
13402-
// Make sure there's no size changes.
13403-
if (SegmentElts * Segments != M.size())
13404-
return std::nullopt;
13405-
13406-
// Check the first index corresponds to one of the lanes in the first segment.
13407-
if (Lane >= SegmentElts)
13408-
return std::nullopt;
13409-
13410-
// Check that all lanes match the first, adjusted for segment.
13411-
for (unsigned I = 0; I < M.size(); ++I)
13412-
if ((unsigned)M[I] != (Lane + ((I / SegmentElts) * SegmentElts)))
13413-
return std::nullopt;
13414-
13415-
return Lane;
13416-
}
13417-
1341813394
/// isTRN_v_undef_Mask - Special case of isTRNMask for canonical form of
1341913395
/// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
1342013396
/// Mask is e.g., <0, 0, 2, 2> instead of <0, 4, 2, 6>.
@@ -30005,8 +29981,14 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
3000529981
DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op1));
3000629982
}
3000729983

30008-
if (Subtarget->hasSVE2p1()) {
30009-
if (std::optional<unsigned> Lane = isDUPQMask(ShuffleMask, VT)) {
29984+
if (Subtarget->hasSVE2p1() || Subtarget->hasSME2p1()) {
29985+
assert(VT.getFixedSizeInBits() % AArch64::SVEBitsPerBlock == 0 &&
29986+
"Unsupported SVE vector size");
29987+
29988+
unsigned Segments = VT.getFixedSizeInBits() / AArch64::SVEBitsPerBlock;
29989+
unsigned SegmentElts = VT.getVectorNumElements() / Segments;
29990+
if (std::optional<unsigned> Lane =
29991+
isDUPQMask(ShuffleMask, Segments, SegmentElts)) {
3001029992
SDValue IID =
3001129993
DAG.getConstant(Intrinsic::aarch64_sve_dup_laneq, DL, MVT::i64);
3001229994
return convertFromScalableVector(

llvm/lib/Target/AArch64/AArch64PerfectShuffle.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define LLVM_LIB_TARGET_AARCH64_AARCH64PERFECTSHUFFLE_H
1616

1717
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/STLExtras.h"
1819

1920
namespace llvm {
2021

@@ -6723,6 +6724,29 @@ inline bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
67236724
return true;
67246725
}
67256726

6727+
/// isDUPQMask - matches a splat of equivalent lanes within segments of a given
6728+
/// number of elements.
6729+
inline std::optional<unsigned> isDUPQMask(ArrayRef<int> M, unsigned Segments,
6730+
unsigned NumElts) {
6731+
unsigned Lane = (unsigned)M[0];
6732+
6733+
// Make sure there's no size changes.
6734+
if (NumElts * Segments != M.size())
6735+
return std::nullopt;
6736+
6737+
// Check the first index corresponds to one of the lanes in the first segment.
6738+
if (Lane >= NumElts)
6739+
return std::nullopt;
6740+
6741+
// Check that all lanes match the first, adjusted for segment.
6742+
if (all_of(enumerate(M), [&](auto P) {
6743+
return (unsigned)P.value() == Lane + (P.index() / NumElts) * NumElts;
6744+
}))
6745+
return Lane;
6746+
6747+
return std::nullopt;
6748+
}
6749+
67266750
} // namespace llvm
67276751

67286752
#endif

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5584,23 +5584,19 @@ InstructionCost AArch64TTIImpl::getShuffleCost(
55845584
}
55855585

55865586
// Segmented shuffle matching.
5587-
if (ST->hasSVE2p1() && CostKind == TTI::TCK_RecipThroughput &&
5587+
if ((ST->hasSVE2p1() || ST->hasSME2p1()) &&
55885588
Kind == TTI::SK_PermuteSingleSrc && isa<FixedVectorType>(Tp) &&
5589-
Tp->getPrimitiveSizeInBits().isKnownMultipleOf(128)) {
5589+
Tp->getPrimitiveSizeInBits().isKnownMultipleOf(
5590+
AArch64::SVEBitsPerBlock)) {
55905591

55915592
FixedVectorType *VTy = cast<FixedVectorType>(Tp);
5592-
unsigned Segments = VTy->getPrimitiveSizeInBits() / 128;
5593+
unsigned Segments =
5594+
VTy->getPrimitiveSizeInBits() / AArch64::SVEBitsPerBlock;
55935595
unsigned SegmentElts = VTy->getNumElements() / Segments;
55945596

55955597
// dupq zd.t, zn.t[idx]
5596-
unsigned Lane = (unsigned)Mask[0];
5597-
if (SegmentElts * Segments == Mask.size() && Lane < SegmentElts) {
5598-
bool IsDupQ = true;
5599-
for (unsigned I = 1; I < Mask.size(); ++I)
5600-
IsDupQ &= (unsigned)Mask[I] == Lane + ((I / SegmentElts) * SegmentElts);
5601-
if (IsDupQ)
5602-
return LT.first;
5603-
}
5598+
if (isDUPQMask(Mask, Segments, SegmentElts))
5599+
return LT.first;
56045600
}
56055601

56065602
// Check for broadcast loads, which are supported by the LD1R instruction.

0 commit comments

Comments
 (0)