Skip to content

Commit c8d0e24

Browse files
authored
[VPlan] Preserve trunc nuw/nsw in VPRecipeWithIRFlags (#144700)
This preserves the nuw/nsw flags on widened truncs by checking for TruncInst in the VPIRFlags constructor The motivation for this is to be able to fold away some redundant truncs feeding into uitofps (or potentially narrow the inductions feeding them)
1 parent b0769aa commit c8d0e24

File tree

5 files changed

+49
-11
lines changed

5 files changed

+49
-11
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ class VPIRFlags {
597597
enum class OperationType : unsigned char {
598598
Cmp,
599599
OverflowingBinOp,
600+
Trunc,
600601
DisjointOp,
601602
PossiblyExactOp,
602603
GEPOp,
@@ -613,6 +614,13 @@ class VPIRFlags {
613614
WrapFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {}
614615
};
615616

617+
struct TruncFlagsTy {
618+
char HasNUW : 1;
619+
char HasNSW : 1;
620+
621+
TruncFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {}
622+
};
623+
616624
struct DisjointFlagsTy {
617625
char IsDisjoint : 1;
618626
DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
@@ -644,6 +652,7 @@ class VPIRFlags {
644652
union {
645653
CmpInst::Predicate CmpPredicate;
646654
WrapFlagsTy WrapFlags;
655+
TruncFlagsTy TruncFlags;
647656
DisjointFlagsTy DisjointFlags;
648657
ExactFlagsTy ExactFlags;
649658
GEPNoWrapFlags GEPFlags;
@@ -665,6 +674,9 @@ class VPIRFlags {
665674
} else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
666675
OpType = OperationType::OverflowingBinOp;
667676
WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
677+
} else if (auto *Op = dyn_cast<TruncInst>(&I)) {
678+
OpType = OperationType::Trunc;
679+
TruncFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
668680
} else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
669681
OpType = OperationType::PossiblyExactOp;
670682
ExactFlags.IsExact = Op->isExact();
@@ -715,6 +727,10 @@ class VPIRFlags {
715727
WrapFlags.HasNUW = false;
716728
WrapFlags.HasNSW = false;
717729
break;
730+
case OperationType::Trunc:
731+
TruncFlags.HasNUW = false;
732+
TruncFlags.HasNSW = false;
733+
break;
718734
case OperationType::DisjointOp:
719735
DisjointFlags.IsDisjoint = false;
720736
break;
@@ -744,6 +760,10 @@ class VPIRFlags {
744760
I.setHasNoUnsignedWrap(WrapFlags.HasNUW);
745761
I.setHasNoSignedWrap(WrapFlags.HasNSW);
746762
break;
763+
case OperationType::Trunc:
764+
I.setHasNoUnsignedWrap(TruncFlags.HasNUW);
765+
I.setHasNoSignedWrap(TruncFlags.HasNSW);
766+
break;
747767
case OperationType::DisjointOp:
748768
cast<PossiblyDisjointInst>(&I)->setIsDisjoint(DisjointFlags.IsDisjoint);
749769
break;
@@ -800,15 +820,25 @@ class VPIRFlags {
800820
}
801821

802822
bool hasNoUnsignedWrap() const {
803-
assert(OpType == OperationType::OverflowingBinOp &&
804-
"recipe doesn't have a NUW flag");
805-
return WrapFlags.HasNUW;
823+
switch (OpType) {
824+
case OperationType::OverflowingBinOp:
825+
return WrapFlags.HasNUW;
826+
case OperationType::Trunc:
827+
return TruncFlags.HasNUW;
828+
default:
829+
llvm_unreachable("recipe doesn't have a NUW flag");
830+
}
806831
}
807832

808833
bool hasNoSignedWrap() const {
809-
assert(OpType == OperationType::OverflowingBinOp &&
810-
"recipe doesn't have a NSW flag");
811-
return WrapFlags.HasNSW;
834+
switch (OpType) {
835+
case OperationType::OverflowingBinOp:
836+
return WrapFlags.HasNSW;
837+
case OperationType::Trunc:
838+
return TruncFlags.HasNSW;
839+
default:
840+
llvm_unreachable("recipe doesn't have a NSW flag");
841+
}
812842
}
813843

814844
bool isDisjoint() const {

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,6 +1763,8 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
17631763
return Opcode == Instruction::Add || Opcode == Instruction::Sub ||
17641764
Opcode == Instruction::Mul ||
17651765
Opcode == VPInstruction::VPInstruction::CanonicalIVIncrementForPart;
1766+
case OperationType::Trunc:
1767+
return Opcode == Instruction::Trunc;
17661768
case OperationType::DisjointOp:
17671769
return Opcode == Instruction::Or;
17681770
case OperationType::PossiblyExactOp:
@@ -1810,6 +1812,12 @@ void VPIRFlags::printFlags(raw_ostream &O) const {
18101812
if (WrapFlags.HasNSW)
18111813
O << " nsw";
18121814
break;
1815+
case OperationType::Trunc:
1816+
if (TruncFlags.HasNUW)
1817+
O << " nuw";
1818+
if (TruncFlags.HasNSW)
1819+
O << " nsw";
1820+
break;
18131821
case OperationType::FPMathOp:
18141822
getFastMathFlags().print(O);
18151823
break;

llvm/test/Transforms/LoopVectorize/AArch64/conditional-branches-cost.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,7 +1484,7 @@ define void @redundant_branch_and_tail_folding(ptr %dst, i1 %c) {
14841484
; DEFAULT-NEXT: [[VEC_IND:%.*]] = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ]
14851485
; DEFAULT-NEXT: [[STEP_ADD:%.*]] = add <4 x i64> [[VEC_IND]], splat (i64 4)
14861486
; DEFAULT-NEXT: [[TMP0:%.*]] = add nuw nsw <4 x i64> [[STEP_ADD]], splat (i64 1)
1487-
; DEFAULT-NEXT: [[TMP1:%.*]] = trunc <4 x i64> [[TMP0]] to <4 x i32>
1487+
; DEFAULT-NEXT: [[TMP1:%.*]] = trunc nuw nsw <4 x i64> [[TMP0]] to <4 x i32>
14881488
; DEFAULT-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[TMP1]], i32 3
14891489
; DEFAULT-NEXT: store i32 [[TMP2]], ptr [[DST]], align 4
14901490
; DEFAULT-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
@@ -1521,7 +1521,7 @@ define void @redundant_branch_and_tail_folding(ptr %dst, i1 %c) {
15211521
; PRED-NEXT: [[VEC_IND:%.*]] = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[PRED_STORE_CONTINUE6]] ]
15221522
; PRED-NEXT: [[TMP0:%.*]] = icmp ule <4 x i64> [[VEC_IND]], splat (i64 20)
15231523
; PRED-NEXT: [[TMP1:%.*]] = add nuw nsw <4 x i64> [[VEC_IND]], splat (i64 1)
1524-
; PRED-NEXT: [[TMP2:%.*]] = trunc <4 x i64> [[TMP1]] to <4 x i32>
1524+
; PRED-NEXT: [[TMP2:%.*]] = trunc nuw nsw <4 x i64> [[TMP1]] to <4 x i32>
15251525
; PRED-NEXT: [[TMP3:%.*]] = extractelement <4 x i1> [[TMP0]], i32 0
15261526
; PRED-NEXT: br i1 [[TMP3]], label %[[PRED_STORE_IF:.*]], label %[[PRED_STORE_CONTINUE:.*]]
15271527
; PRED: [[PRED_STORE_IF]]:

llvm/test/Transforms/LoopVectorize/ARM/mve-reg-pressure-vmla.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --filter-out-after "^scalar.ph:" --version 5
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --filter-out-after "^scalar.ph:" --version 5
22
; RUN: opt -mattr=+mve -passes=loop-vectorize < %s -S -o - | FileCheck %s
33

44
target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
@@ -49,7 +49,7 @@ define void @fn(i32 noundef %n, ptr %in, ptr %out) #0 {
4949
; CHECK-NEXT: [[TMP10:%.*]] = add nuw nsw <4 x i32> [[TMP9]], [[TMP6]]
5050
; CHECK-NEXT: [[TMP11:%.*]] = add nuw nsw <4 x i32> [[TMP10]], [[TMP8]]
5151
; CHECK-NEXT: [[TMP12:%.*]] = lshr <4 x i32> [[TMP11]], splat (i32 16)
52-
; CHECK-NEXT: [[TMP13:%.*]] = trunc <4 x i32> [[TMP12]] to <4 x i8>
52+
; CHECK-NEXT: [[TMP13:%.*]] = trunc nuw <4 x i32> [[TMP12]] to <4 x i8>
5353
; CHECK-NEXT: [[TMP14:%.*]] = mul nuw nsw <4 x i32> [[TMP3]], splat (i32 32767)
5454
; CHECK-NEXT: [[TMP15:%.*]] = mul nuw <4 x i32> [[TMP5]], splat (i32 16762097)
5555
; CHECK-NEXT: [[TMP16:%.*]] = mul nuw <4 x i32> [[TMP7]], splat (i32 16759568)

llvm/test/Transforms/PhaseOrdering/ARM/arm_mult_q15.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ define void @arm_mult_q15(ptr %pSrcA, ptr %pSrcB, ptr noalias %pDst, i32 %blockS
4141
; CHECK-NEXT: [[TMP5:%.*]] = mul nsw <8 x i32> [[TMP4]], [[TMP3]]
4242
; CHECK-NEXT: [[TMP6:%.*]] = ashr <8 x i32> [[TMP5]], splat (i32 15)
4343
; CHECK-NEXT: [[TMP7:%.*]] = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> [[TMP6]], <8 x i32> splat (i32 32767))
44-
; CHECK-NEXT: [[TMP8:%.*]] = trunc <8 x i32> [[TMP7]] to <8 x i16>
44+
; CHECK-NEXT: [[TMP8:%.*]] = trunc nsw <8 x i32> [[TMP7]] to <8 x i16>
4545
; CHECK-NEXT: store <8 x i16> [[TMP8]], ptr [[NEXT_GEP14]], align 2
4646
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
4747
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]

0 commit comments

Comments
 (0)