Skip to content

Commit 542faf0

Browse files
committed
* Refactor to share isDUPQMask
* Support SME2p1
* Remove hardcoded magic number
* Return the same result for other cost kinds
1 parent e160c70 commit 542faf0

File tree

3 files changed

+39
-37
lines changed

3 files changed

+39
-37
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13402,30 +13402,6 @@ static bool isUZP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
1340213402
return true;
1340313403
}
1340413404

13405-
/// isDUPQMask - matches a splat of equivalent lanes within 128b segments in
13406-
/// the first vector operand.
///
/// \p VT must be a fixed-size vector type whose size in bits is a multiple of
/// 128 (asserted below). The segment count and the elements per segment are
/// both derived from \p VT.
///
/// \returns the splatted lane index, relative to the start of the first
/// segment, when \p M matches; std::nullopt otherwise.
13407-
static std::optional<unsigned> isDUPQMask(ArrayRef<int> M, EVT VT) {
13408-
assert(VT.getFixedSizeInBits() % 128 == 0 && "Unsupported SVE vector size");
13409-
// NOTE(review): M[0] is read before the size check below, so this presumably
// relies on callers never passing an empty mask - confirm.
unsigned Lane = (unsigned)M[0];
13410-
unsigned Segments = VT.getFixedSizeInBits() / 128;
13411-
unsigned SegmentElts = VT.getVectorNumElements() / Segments;
13412-
13413-
// Make sure there's no size changes.
13414-
if (SegmentElts * Segments != M.size())
13415-
return std::nullopt;
13416-
13417-
// Check the first index corresponds to one of the lanes in the first segment.
// A negative M[0] converts to a large unsigned value and is rejected here too.
13418-
if (Lane >= SegmentElts)
13419-
return std::nullopt;
13420-
13421-
// Check that all lanes match the first, adjusted for segment.
// (I / SegmentElts) * SegmentElts is the index of the first element of the
// segment that position I belongs to.
13422-
for (unsigned I = 0; I < M.size(); ++I)
13423-
if ((unsigned)M[I] != (Lane + ((I / SegmentElts) * SegmentElts)))
13424-
return std::nullopt;
13425-
13426-
return Lane;
13427-
}
13428-
1342913405
/// isTRN_v_undef_Mask - Special case of isTRNMask for canonical form of
1343013406
/// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
1343113407
/// Mask is e.g., <0, 0, 2, 2> instead of <0, 4, 2, 6>.
@@ -30026,8 +30002,14 @@ SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
3002630002
DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op1));
3002730003
}
3002830004

30029-
if (Subtarget->hasSVE2p1()) {
30030-
if (std::optional<unsigned> Lane = isDUPQMask(ShuffleMask, VT)) {
30005+
if (Subtarget->hasSVE2p1() || Subtarget->hasSME2p1()) {
30006+
assert(VT.getFixedSizeInBits() % AArch64::SVEBitsPerBlock == 0 &&
30007+
"Unsupported SVE vector size");
30008+
30009+
unsigned Segments = VT.getFixedSizeInBits() / AArch64::SVEBitsPerBlock;
30010+
unsigned SegmentElts = VT.getVectorNumElements() / Segments;
30011+
if (std::optional<unsigned> Lane =
30012+
isDUPQMask(ShuffleMask, Segments, SegmentElts)) {
3003130013
SDValue IID =
3003230014
DAG.getConstant(Intrinsic::aarch64_sve_dup_laneq, DL, MVT::i64);
3003330015
return convertFromScalableVector(

llvm/lib/Target/AArch64/AArch64PerfectShuffle.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define LLVM_LIB_TARGET_AARCH64_AARCH64PERFECTSHUFFLE_H
1616

1717
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/STLExtras.h"
1819

1920
namespace llvm {
2021

@@ -6723,6 +6724,29 @@ inline bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
67236724
return true;
67246725
}
67256726

6727+
/// isDUPQMask - matches a splat of equivalent lanes within segments of a given
6728+
/// number of elements.
6729+
inline std::optional<unsigned> isDUPQMask(ArrayRef<int> M, unsigned Segments,
6730+
unsigned NumElts) {
6731+
unsigned Lane = (unsigned)M[0];
6732+
6733+
// Make sure there's no size changes.
6734+
if (NumElts * Segments != M.size())
6735+
return std::nullopt;
6736+
6737+
// Check the first index corresponds to one of the lanes in the first segment.
6738+
if (Lane >= NumElts)
6739+
return std::nullopt;
6740+
6741+
// Check that all lanes match the first, adjusted for segment.
6742+
if (all_of(enumerate(M), [&](auto P) {
6743+
return (unsigned)P.value() == Lane + (P.index() / NumElts) * NumElts;
6744+
}))
6745+
return Lane;
6746+
6747+
return std::nullopt;
6748+
}
6749+
67266750
} // namespace llvm
67276751

67286752
#endif

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5600,23 +5600,19 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
56005600
}
56015601

56025602
// Segmented shuffle matching.
5603-
if (ST->hasSVE2p1() && CostKind == TTI::TCK_RecipThroughput &&
5603+
if ((ST->hasSVE2p1() || ST->hasSME2p1()) &&
56045604
Kind == TTI::SK_PermuteSingleSrc && isa<FixedVectorType>(Tp) &&
5605-
Tp->getPrimitiveSizeInBits().isKnownMultipleOf(128)) {
5605+
Tp->getPrimitiveSizeInBits().isKnownMultipleOf(
5606+
AArch64::SVEBitsPerBlock)) {
56065607

56075608
FixedVectorType *VTy = cast<FixedVectorType>(Tp);
5608-
unsigned Segments = VTy->getPrimitiveSizeInBits() / 128;
5609+
unsigned Segments =
5610+
VTy->getPrimitiveSizeInBits() / AArch64::SVEBitsPerBlock;
56095611
unsigned SegmentElts = VTy->getNumElements() / Segments;
56105612

56115613
// dupq zd.t, zn.t[idx]
5612-
unsigned Lane = (unsigned)Mask[0];
5613-
if (SegmentElts * Segments == Mask.size() && Lane < SegmentElts) {
5614-
bool IsDupQ = true;
5615-
for (unsigned I = 1; I < Mask.size(); ++I)
5616-
IsDupQ &= (unsigned)Mask[I] == Lane + ((I / SegmentElts) * SegmentElts);
5617-
if (IsDupQ)
5618-
return LT.first;
5619-
}
5614+
if (isDUPQMask(Mask, Segments, SegmentElts))
5615+
return LT.first;
56205616
}
56215617

56225618
// Check for broadcast loads, which are supported by the LD1R instruction.

0 commit comments

Comments
 (0)