-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[CGP]: Optimize mul.overflow. #148343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[CGP]: Optimize mul.overflow. #148343
Conversation
hassnaaHamdi
commented
Jul 12, 2025
- Detect cases where LHS & RHS values will not cause overflow (when the Hi parts are zero).
- Detect cases where either of LHS or RHS values could not cause overflow (when one of the Hi parts is zero).
- Detect cases where LHS & RHS values will not cause overflow (when the Hi parts are zero). - Detect cases where either of LHS or RHS values could not cause overflow (when one of the Hi parts is zero).
@llvm/pr-subscribers-backend-powerpc @llvm/pr-subscribers-backend-arm Author: Hassnaa Hamdi (hassnaaHamdi) Changes
Patch is 676.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148343.diff 20 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9bbb89e37865d..d9859ed246604 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeUMulWithOverflow(Instruction *I);
+ bool optimizeSMulWithOverflow(Instruction *I);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2769,6 +2771,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
return optimizeGatherScatterInst(II, II->getArgOperand(0));
case Intrinsic::masked_scatter:
return optimizeGatherScatterInst(II, II->getArgOperand(1));
+ case Intrinsic::umul_with_overflow:
+ return optimizeUMulWithOverflow(II);
+ case Intrinsic::smul_with_overflow:
+ return optimizeSMulWithOverflow(II);
}
SmallVector<Value *, 2> PtrOps;
@@ -6386,6 +6392,573 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = BuilderEntryBB.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = BuilderEntryBB.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateZExt(LoLHS, Ty, "lo.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, ShrHiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ auto *ResLoEx =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(P0Lo, Ty, "rhs.res_lo.zext");
+ auto *ResMid =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(AddOResMid, Ty, "rhs.res_mid.zext");
+ auto *ResMidShl = BuilderNoOverflowRHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "rhs.res_mid.shl");
+ auto *FinalRes = BuilderNoOverflowRHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "rhs.res_lo.or.mid");
+ auto *IsOverflow = BuilderNoOverflowRHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "rhs.check.overflow");
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowRHS = PoisonValue::get(STy);
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, FinalRes, {0});
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, IsOverflow, {1});
+ BuilderNoOverflowRHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow_no_lhs_only:
+
+ LoRHS = BuilderNoOverflowLHSonlyBB.CreateZExt(LoRHS, Ty, "lo.rhs");
+
+ // P0 = (LHS * LoRHS)
+ P0 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, LoRHS,
+ "mul.no.overflow.lhs.lorhs");
+ P0Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.lhs");
+ P0Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.lsr.lhs");
+ P0Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.lhs");
+
+ // P1 = (LHS * HiRHS)
+ P1 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, ShrHiRHS,
+ "mul.no.overflow.lhs.hirhs");
+ P1Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.lhs");
+ P1Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.lhs.lsr");
+ P1Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.lhs");
+
+ AddOverflow = BuilderNoOverflowLHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ AddOResMid = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 0,
+ "lhs.p0.p1.res");
+ Carry = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 1,
+ "lhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(Carry, LegalTy, "lhs.carry.zext");
+ ResHi = BuilderNoOverflowLHSonlyBB.CreateAdd(P1Hi, Carry, "lhs.p1.carry");
+
+ ResLoEx = BuilderNoOverflowLHSonlyBB.CreateZExt(P0Lo, Ty, "lhs.res_lo.zext");
+ ResMid =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(AddOResMid, Ty, "lhs.res_mid.zext");
+ ResMidShl = BuilderNoOverflowLHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "lhs.res_mid.shl");
+ FinalRes = BuilderNoOverflowLHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "lhs.res_lo.or.mid");
+ IsOverflow = BuilderNoOverflowLHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "lhs.check.overflow");
+
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowLHS = PoisonValue::get(STy);
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, FinalRes, {0});
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, IsOverflow, {1});
+
+ BuilderNoOverflowLHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow.no:
+ auto *Mul = BuilderNoOverflowBB.CreateMul(LHS, RHS, "mul.no.overflow");
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflow = PoisonValue::get(STy);
+ StructValNoOverflow =
+ BuilderNoOverflowBB.CreateInsertValue(StructValNoOverflow, Mul, {0});
+ StructValNoOverflow = BuilderNoOverflowBB.CreateInsertValue(
+ StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+ BuilderNoOverflowBB.CreateBr(OverflowResBB);
+
+ // BB overflow.res:
+ auto *PHINode = BuilderOverflowResBB.CreatePHI(STy, 2);
+ PHINode->addIncoming(StructValNoOverflow, NoOverflowBB);
+ PHINode->addIncoming(StructValNoOverflowLHS, NoOverflowLHSonlyBB);
+ PHINode->addIncoming(StructValNoOverflowRHS, NoOverflowRHSonlyBB);
+
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by PHINode.
+ I->replaceAllUsesWith(PHINode);
+
+ // BB overflow:
+ PHINode->addIncoming(I, OverflowBB);
+ I->removeFromParent();
+ I->insertInto(OverflowBB, OverflowBB->end());
+ IRBuilder<>(OverflowBB, OverflowBB->end()).CreateBr(OverflowResBB);
+
+ // return false to stop reprocessing the function.
+ return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ auto *SignLoRHS =
+ BuilderEntryBB.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ auto *HiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = BuilderEntryBB.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ auto *SignLoLHS =
+ BuilderEntryBB.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ auto *HiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = BuilderEntryBB.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is within 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ // check sign of RHS:
+ auto *IsNegRHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(RHS, "rhs.isneg");
+ auto *AbsRHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, RHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.rhs");
+ auto *AbsRHS = BuilderNoOverflowRHSonlyBB.CreateSelect(
+ IsNegRHS, AbsRHSIntr, RHS, "lo.abs.rhs.select");
+
+ // check sign of LHS:
+ auto *IsNegLHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(LHS, "lhs.isneg");
+ auto *AbsLHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, LHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.lhs");
+ auto *AbsLHS = BuilderNoOverflowRHSonlyBB.CreateSelect(IsNegLHS, AbsLHSIntr,
+ LHS, "abs.lhs.select");
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateAnd(
+ AbsLHS,
+ ConstantInt::get(Ty, APInt::getLowBitsSet(VTBitWidth, VTHalfBitWidth)),
+ "lo.abs.lhs");
+ HiLHS = BuilderNoOverflowRHSonlyBB.CreateLShr(AbsLHS, VTHalfBitWidth,
+ "hi.abs.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, HiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ // sign handling:
+ auto *IsNeg = BuilderNoOverflowRHSonlyBB.CreateXor(IsNegRHS, IsNegLHS); // i1
+ auto *Mask =
+ BuilderNoOverflowRHSonlyBB.CreateSExt(IsNeg, LegalTy, "rhs.sign.mask");
+ auto *Add_1 =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(IsNeg, LegalTy, "rhs.add.1");
+ auto *ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateXor(P0Lo, Mask, "rhs.res_lo.xor.mask");
+ ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(ResLo, Add_1, "rhs.res_lo.add.1");
+
+ Carry = BuilderNoOverflowRHSonlyBB.CreateCmp(ICmpInst::ICMP_ULT, ResLo, Add_1,
+ "rhs.check.res_lo.c...
[truncated]
|
@llvm/pr-subscribers-backend-aarch64 Author: Hassnaa Hamdi (hassnaaHamdi) Changes
Patch is 676.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148343.diff 20 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9bbb89e37865d..d9859ed246604 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeUMulWithOverflow(Instruction *I);
+ bool optimizeSMulWithOverflow(Instruction *I);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2769,6 +2771,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
return optimizeGatherScatterInst(II, II->getArgOperand(0));
case Intrinsic::masked_scatter:
return optimizeGatherScatterInst(II, II->getArgOperand(1));
+ case Intrinsic::umul_with_overflow:
+ return optimizeUMulWithOverflow(II);
+ case Intrinsic::smul_with_overflow:
+ return optimizeSMulWithOverflow(II);
}
SmallVector<Value *, 2> PtrOps;
@@ -6386,6 +6392,573 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = BuilderEntryBB.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = BuilderEntryBB.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateZExt(LoLHS, Ty, "lo.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, ShrHiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ auto *ResLoEx =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(P0Lo, Ty, "rhs.res_lo.zext");
+ auto *ResMid =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(AddOResMid, Ty, "rhs.res_mid.zext");
+ auto *ResMidShl = BuilderNoOverflowRHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "rhs.res_mid.shl");
+ auto *FinalRes = BuilderNoOverflowRHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "rhs.res_lo.or.mid");
+ auto *IsOverflow = BuilderNoOverflowRHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "rhs.check.overflow");
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowRHS = PoisonValue::get(STy);
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, FinalRes, {0});
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, IsOverflow, {1});
+ BuilderNoOverflowRHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow_no_lhs_only:
+
+ LoRHS = BuilderNoOverflowLHSonlyBB.CreateZExt(LoRHS, Ty, "lo.rhs");
+
+ // P0 = (LHS * LoRHS)
+ P0 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, LoRHS,
+ "mul.no.overflow.lhs.lorhs");
+ P0Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.lhs");
+ P0Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.lsr.lhs");
+ P0Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.lhs");
+
+ // P1 = (LHS * HiRHS)
+ P1 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, ShrHiRHS,
+ "mul.no.overflow.lhs.hirhs");
+ P1Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.lhs");
+ P1Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.lhs.lsr");
+ P1Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.lhs");
+
+ AddOverflow = BuilderNoOverflowLHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ AddOResMid = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 0,
+ "lhs.p0.p1.res");
+ Carry = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 1,
+ "lhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(Carry, LegalTy, "lhs.carry.zext");
+ ResHi = BuilderNoOverflowLHSonlyBB.CreateAdd(P1Hi, Carry, "lhs.p1.carry");
+
+ ResLoEx = BuilderNoOverflowLHSonlyBB.CreateZExt(P0Lo, Ty, "lhs.res_lo.zext");
+ ResMid =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(AddOResMid, Ty, "lhs.res_mid.zext");
+ ResMidShl = BuilderNoOverflowLHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "lhs.res_mid.shl");
+ FinalRes = BuilderNoOverflowLHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "lhs.res_lo.or.mid");
+ IsOverflow = BuilderNoOverflowLHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "lhs.check.overflow");
+
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowLHS = PoisonValue::get(STy);
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, FinalRes, {0});
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, IsOverflow, {1});
+
+ BuilderNoOverflowLHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow.no:
+ auto *Mul = BuilderNoOverflowBB.CreateMul(LHS, RHS, "mul.no.overflow");
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflow = PoisonValue::get(STy);
+ StructValNoOverflow =
+ BuilderNoOverflowBB.CreateInsertValue(StructValNoOverflow, Mul, {0});
+ StructValNoOverflow = BuilderNoOverflowBB.CreateInsertValue(
+ StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+ BuilderNoOverflowBB.CreateBr(OverflowResBB);
+
+ // BB overflow.res:
+ auto *PHINode = BuilderOverflowResBB.CreatePHI(STy, 2);
+ PHINode->addIncoming(StructValNoOverflow, NoOverflowBB);
+ PHINode->addIncoming(StructValNoOverflowLHS, NoOverflowLHSonlyBB);
+ PHINode->addIncoming(StructValNoOverflowRHS, NoOverflowRHSonlyBB);
+
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by PHINode.
+ I->replaceAllUsesWith(PHINode);
+
+ // BB overflow:
+ PHINode->addIncoming(I, OverflowBB);
+ I->removeFromParent();
+ I->insertInto(OverflowBB, OverflowBB->end());
+ IRBuilder<>(OverflowBB, OverflowBB->end()).CreateBr(OverflowResBB);
+
+ // return false to stop reprocessing the function.
+ return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ auto *SignLoRHS =
+ BuilderEntryBB.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ auto *HiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = BuilderEntryBB.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ auto *SignLoLHS =
+ BuilderEntryBB.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ auto *HiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = BuilderEntryBB.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is within 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ // check sign of RHS:
+ auto *IsNegRHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(RHS, "rhs.isneg");
+ auto *AbsRHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, RHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.rhs");
+ auto *AbsRHS = BuilderNoOverflowRHSonlyBB.CreateSelect(
+ IsNegRHS, AbsRHSIntr, RHS, "lo.abs.rhs.select");
+
+ // check sign of LHS:
+ auto *IsNegLHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(LHS, "lhs.isneg");
+ auto *AbsLHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, LHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.lhs");
+ auto *AbsLHS = BuilderNoOverflowRHSonlyBB.CreateSelect(IsNegLHS, AbsLHSIntr,
+ LHS, "abs.lhs.select");
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateAnd(
+ AbsLHS,
+ ConstantInt::get(Ty, APInt::getLowBitsSet(VTBitWidth, VTHalfBitWidth)),
+ "lo.abs.lhs");
+ HiLHS = BuilderNoOverflowRHSonlyBB.CreateLShr(AbsLHS, VTHalfBitWidth,
+ "hi.abs.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, HiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ // sign handling:
+ auto *IsNeg = BuilderNoOverflowRHSonlyBB.CreateXor(IsNegRHS, IsNegLHS); // i1
+ auto *Mask =
+ BuilderNoOverflowRHSonlyBB.CreateSExt(IsNeg, LegalTy, "rhs.sign.mask");
+ auto *Add_1 =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(IsNeg, LegalTy, "rhs.add.1");
+ auto *ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateXor(P0Lo, Mask, "rhs.res_lo.xor.mask");
+ ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(ResLo, Add_1, "rhs.res_lo.add.1");
+
+ Carry = BuilderNoOverflowRHSonlyBB.CreateCmp(ICmpInst::ICMP_ULT, ResLo, Add_1,
+ "rhs.check.res_lo.c...
[truncated]
|
@llvm/pr-subscribers-backend-risc-v Author: Hassnaa Hamdi (hassnaaHamdi) Changes
Patch is 676.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148343.diff 20 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9bbb89e37865d..d9859ed246604 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeUMulWithOverflow(Instruction *I);
+ bool optimizeSMulWithOverflow(Instruction *I);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2769,6 +2771,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
return optimizeGatherScatterInst(II, II->getArgOperand(0));
case Intrinsic::masked_scatter:
return optimizeGatherScatterInst(II, II->getArgOperand(1));
+ case Intrinsic::umul_with_overflow:
+ return optimizeUMulWithOverflow(II);
+ case Intrinsic::smul_with_overflow:
+ return optimizeSMulWithOverflow(II);
}
SmallVector<Value *, 2> PtrOps;
@@ -6386,6 +6392,573 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = BuilderEntryBB.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = BuilderEntryBB.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateZExt(LoLHS, Ty, "lo.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, ShrHiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ auto *ResLoEx =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(P0Lo, Ty, "rhs.res_lo.zext");
+ auto *ResMid =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(AddOResMid, Ty, "rhs.res_mid.zext");
+ auto *ResMidShl = BuilderNoOverflowRHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "rhs.res_mid.shl");
+ auto *FinalRes = BuilderNoOverflowRHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "rhs.res_lo.or.mid");
+ auto *IsOverflow = BuilderNoOverflowRHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "rhs.check.overflow");
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowRHS = PoisonValue::get(STy);
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, FinalRes, {0});
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, IsOverflow, {1});
+ BuilderNoOverflowRHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow_no_lhs_only:
+
+ LoRHS = BuilderNoOverflowLHSonlyBB.CreateZExt(LoRHS, Ty, "lo.rhs");
+
+ // P0 = (LHS * LoRHS)
+ P0 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, LoRHS,
+ "mul.no.overflow.lhs.lorhs");
+ P0Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.lhs");
+ P0Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.lsr.lhs");
+ P0Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.lhs");
+
+ // P1 = (LHS * HiRHS)
+ P1 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, ShrHiRHS,
+ "mul.no.overflow.lhs.hirhs");
+ P1Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.lhs");
+ P1Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.lhs.lsr");
+ P1Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.lhs");
+
+ AddOverflow = BuilderNoOverflowLHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ AddOResMid = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 0,
+ "lhs.p0.p1.res");
+ Carry = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 1,
+ "lhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(Carry, LegalTy, "lhs.carry.zext");
+ ResHi = BuilderNoOverflowLHSonlyBB.CreateAdd(P1Hi, Carry, "lhs.p1.carry");
+
+ ResLoEx = BuilderNoOverflowLHSonlyBB.CreateZExt(P0Lo, Ty, "lhs.res_lo.zext");
+ ResMid =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(AddOResMid, Ty, "lhs.res_mid.zext");
+ ResMidShl = BuilderNoOverflowLHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "lhs.res_mid.shl");
+ FinalRes = BuilderNoOverflowLHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "lhs.res_lo.or.mid");
+ IsOverflow = BuilderNoOverflowLHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "lhs.check.overflow");
+
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowLHS = PoisonValue::get(STy);
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, FinalRes, {0});
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, IsOverflow, {1});
+
+ BuilderNoOverflowLHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow.no:
+ auto *Mul = BuilderNoOverflowBB.CreateMul(LHS, RHS, "mul.no.overflow");
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflow = PoisonValue::get(STy);
+ StructValNoOverflow =
+ BuilderNoOverflowBB.CreateInsertValue(StructValNoOverflow, Mul, {0});
+ StructValNoOverflow = BuilderNoOverflowBB.CreateInsertValue(
+ StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+ BuilderNoOverflowBB.CreateBr(OverflowResBB);
+
+ // BB overflow.res:
+ auto *PHINode = BuilderOverflowResBB.CreatePHI(STy, 2);
+ PHINode->addIncoming(StructValNoOverflow, NoOverflowBB);
+ PHINode->addIncoming(StructValNoOverflowLHS, NoOverflowLHSonlyBB);
+ PHINode->addIncoming(StructValNoOverflowRHS, NoOverflowRHSonlyBB);
+
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by PHINode.
+ I->replaceAllUsesWith(PHINode);
+
+ // BB overflow:
+ PHINode->addIncoming(I, OverflowBB);
+ I->removeFromParent();
+ I->insertInto(OverflowBB, OverflowBB->end());
+ IRBuilder<>(OverflowBB, OverflowBB->end()).CreateBr(OverflowResBB);
+
+ // return false to stop reprocessing the function.
+ return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ auto *SignLoRHS =
+ BuilderEntryBB.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ auto *HiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = BuilderEntryBB.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ auto *SignLoLHS =
+ BuilderEntryBB.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ auto *HiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = BuilderEntryBB.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is within 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ // check sign of RHS:
+ auto *IsNegRHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(RHS, "rhs.isneg");
+ auto *AbsRHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, RHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.rhs");
+ auto *AbsRHS = BuilderNoOverflowRHSonlyBB.CreateSelect(
+ IsNegRHS, AbsRHSIntr, RHS, "lo.abs.rhs.select");
+
+ // check sign of LHS:
+ auto *IsNegLHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(LHS, "lhs.isneg");
+ auto *AbsLHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, LHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.lhs");
+ auto *AbsLHS = BuilderNoOverflowRHSonlyBB.CreateSelect(IsNegLHS, AbsLHSIntr,
+ LHS, "abs.lhs.select");
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateAnd(
+ AbsLHS,
+ ConstantInt::get(Ty, APInt::getLowBitsSet(VTBitWidth, VTHalfBitWidth)),
+ "lo.abs.lhs");
+ HiLHS = BuilderNoOverflowRHSonlyBB.CreateLShr(AbsLHS, VTHalfBitWidth,
+ "hi.abs.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, HiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ // sign handling:
+ auto *IsNeg = BuilderNoOverflowRHSonlyBB.CreateXor(IsNegRHS, IsNegLHS); // i1
+ auto *Mask =
+ BuilderNoOverflowRHSonlyBB.CreateSExt(IsNeg, LegalTy, "rhs.sign.mask");
+ auto *Add_1 =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(IsNeg, LegalTy, "rhs.add.1");
+ auto *ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateXor(P0Lo, Mask, "rhs.res_lo.xor.mask");
+ ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(ResLo, Add_1, "rhs.res_lo.add.1");
+
+ Carry = BuilderNoOverflowRHSonlyBB.CreateCmp(ICmpInst::ICMP_ULT, ResLo, Add_1,
+ "rhs.check.res_lo.c...
[truncated]
|
; AARCH-NEXT: mov w9, wzr | ||
; AARCH-NEXT: madd x8, x0, x3, x8 | ||
; AARCH-NEXT: mul x0, x0, x2 | ||
; AARCH-NEXT: madd x8, x1, x2, x8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the 2 instruction of madd
, there should be extra patch to fold them away as x3 and x1 are expected to be 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I question the value of this optimization, especially as implemented right now. This might (and I'm really not convinced) make more sense if you only specialize for the case of both high halves being zero without covering mixed cases, which blows up code size a lot.
I don't think that's a good idea as this is currently implemented and I'd like to see some benchmark results. In our database system, we make heavy use of 128-bit mul-with-overflow. Having spent some time optimizing this specific case:
For x86-64, the full out-of-line function is hand-optimized assembly code, which performs slightly better than LLVM's expansion (e.g., uses one less register, exploits some x86-specific flags tricks, and has optimized scheduling for some recent uArches) (also note: GCC's expansion is (was?) horrible with lots of data-dependent branches). For AArch64, we use the compiler-rt function, we haven't felt the need to look closer into this so far (most of our benchmarks target x86-64 :-) ). |
FWIW, the reason why LLVM expands this inline and does not use the compiler-rt builtin is that we, unfortunately, have to cater to the lowest common denominator, which is libgcc, which does not support the __muloti4 builtin :( |