
Commit 1d821b0

[AArch64] use isTRNMask to calculate shuffle costs (#171524)
This builds on #169858 to fix the divergence in codegen (https://godbolt.org/z/a9az3h6oq) between two very similar functions, initially observed in #137447 and represented in the diff by the test cases `@transpose_splat_constants` and `@transpose_constants_splat`:

```
int8x16_t f(int8_t x)
{
  return (int8x16_t) { x, 0, x, 1, x, 2, x, 3, x, 4, x, 5, x, 6, x, 7 };
}

int8x16_t g(int8_t x)
{
  return (int8x16_t) { 0, x, 1, x, 2, x, 3, x, 4, x, 5, x, 6, x, 7, x };
}
```

The PR adds an `isTRNMask` call to `AArch64TTIImpl::getShuffleCost` so that shuffle masks are treated as transpose masks even when `isTransposeMask` fails to recognise them (meaning that `Kind == TTI::SK_Transpose` cannot be relied upon).

Follow-up work could consider modifying `isTransposeMask`, but that would also affect backends other than AArch64.
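To make the notion of a transpose (TRN) mask concrete, here is a minimal sketch of a mask check that also accepts swapped source operands. The function `looksLikeTrnMask` and its parameter names are hypothetical and chosen for illustration only; this is not LLVM's `isTRNMask`, whose exact behaviour is not shown on this page. The sketch assumes the usual TRN1/TRN2 lane layout over two sources of `NumElts` elements each, with negative mask entries standing for undefined lanes.

```cpp
#include <vector>

// Hypothetical helper for illustration; NOT LLVM's isTRNMask.
// Returns true if Mask follows a TRN1/TRN2 pattern over two sources of
// NumElts elements each (lanes of the second source are numbered from NumElts).
//   TRN1, operands in order: 0, N,   2, N+2, ...
//   TRN2, operands in order: 1, N+1, 3, N+3, ...
// With the operands swapped, the even/odd lane roles are exchanged.
// Negative entries are treated as undefined lanes and match anything.
bool looksLikeTrnMask(const std::vector<int> &Mask, unsigned NumElts,
                      unsigned &WhichResult, unsigned &OperandOrder) {
  if (NumElts == 0 || NumElts % 2 != 0 || Mask.size() != NumElts)
    return false;

  for (unsigned Order = 0; Order < 2; ++Order) {
    for (unsigned Which = 0; Which < 2; ++Which) {
      bool Match = true;
      for (unsigned I = 0; I < NumElts && Match; I += 2) {
        // Expected source lane for the even and odd result lanes.
        int Even = static_cast<int>(I + Which + (Order ? NumElts : 0));
        int Odd = static_cast<int>(I + Which + (Order ? 0 : NumElts));
        if (Mask[I] >= 0 && Mask[I] != Even)
          Match = false;
        if (Mask[I + 1] >= 0 && Mask[I + 1] != Odd)
          Match = false;
      }
      if (Match) {
        WhichResult = Which;   // 0 for a TRN1-style mask, 1 for TRN2-style
        OperandOrder = Order;  // 1 if the two sources appear swapped
        return true;
      }
    }
  }
  return false;
}
```

Under these assumptions, the two-element mask `{2, 0}` is accepted as a TRN1-style pattern with swapped operands, while an identity mask such as `{0, 1, 2, 3}` is rejected.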
1 parent 8f51da3 commit 1d821b0


4 files changed: +310 -7 lines changed


llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 6 additions & 1 deletion
@@ -6134,8 +6134,13 @@ AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
   unsigned Unused;
   if (LT.second.isFixedLengthVector() &&
       LT.second.getVectorNumElements() == Mask.size() &&
-      (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) &&
+      (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc ||
+       // Discrepancies between isTRNMask and ShuffleVectorInst::isTransposeMask
+       // mean that we can end up with shuffles that satisfy isTRNMask, but end
+       // up labelled as TTI::SK_InsertSubvector. (e.g. {2, 0}).
+       Kind == TTI::SK_InsertSubvector) &&
       (isZIPMask(Mask, LT.second.getVectorNumElements(), Unused, Unused) ||
+       isTRNMask(Mask, LT.second.getVectorNumElements(), Unused, Unused) ||
        isUZPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
        isREVMask(Mask, LT.second.getScalarSizeInBits(),
                  LT.second.getVectorNumElements(), 16) ||
