diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index e63633b8a1e1ab..2c481d15be5bde 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) { return false; } -static bool expandAbs(CallInst *Orig) { +static Value *expandAbs(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -64,14 +64,11 @@ static bool expandAbs(CallInst *Orig) { ConstantInt::get(EltTy, 0)) : ConstantInt::get(EltTy, 0); auto *V = Builder.CreateSub(Zero, X); - auto *MaxCall = - Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max"); - Orig->replaceAllUsesWith(MaxCall); - Orig->eraseFromParent(); - return true; + return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, + "dx.max"); } -static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { +static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { assert(DotIntrinsic == Intrinsic::dx_sdot || DotIntrinsic == Intrinsic::dx_udot); Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot @@ -97,12 +94,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { ArrayRef{Elt0, Elt1, Result}, nullptr, "dx.mad"); } - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Result; } -static bool expandExpIntrinsic(CallInst *Orig) { +static Value *expandExpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -119,23 +114,21 @@ static bool expandExpIntrinsic(CallInst *Orig) { Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2"); Exp2Call->setTailCall(Orig->isTailCall()); Exp2Call->setAttributes(Orig->getAttributes()); - Orig->replaceAllUsesWith(Exp2Call); - Orig->eraseFromParent(); - return true; + return Exp2Call; } -static bool expandAnyIntrinsic(CallInst *Orig) { +static Value *expandAnyIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); + Value *Result = nullptr; if (!Ty->isVectorTy()) { - Value *Cond = EltTy->isFloatingPointTy() - ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0)) - : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0)); - Orig->replaceAllUsesWith(Cond); + Result = EltTy->isFloatingPointTy() + ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0)) + : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0)); } else { auto *XVec = dyn_cast(Ty); Value *Cond = @@ -148,18 +141,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) { X, ConstantVector::getSplat( ElementCount::getFixed(XVec->getNumElements()), ConstantInt::get(EltTy, 0))); - Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0); + Result = Builder.CreateExtractElement(Cond, (uint64_t)0); for (unsigned I = 1; I < XVec->getNumElements(); I++) { Value *Elt = Builder.CreateExtractElement(Cond, I); Result = Builder.CreateOr(Result, Elt); } - Orig->replaceAllUsesWith(Result); } - Orig->eraseFromParent(); - return true; + return Result; } -static bool expandLengthIntrinsic(CallInst *Orig) { +static Value *expandLengthIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -182,15 +173,11 @@ static bool expandLengthIntrinsic(CallInst *Orig) { Value *Mul = Builder.CreateFMul(Elt, Elt); Sum = Builder.CreateFAdd(Sum, Mul); } - Value *Result = Builder.CreateIntrinsic( - EltTy, Intrinsic::sqrt, ArrayRef{Sum}, nullptr, "elt.sqrt"); - - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef{Sum}, + nullptr, "elt.sqrt"); } -static bool expandLerpIntrinsic(CallInst *Orig) { +static Value *expandLerpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Value *Y = Orig->getOperand(1); Value *S = Orig->getOperand(2); @@ -198,14 +185,11 @@ static bool expandLerpIntrinsic(CallInst *Orig) { Builder.SetInsertPoint(Orig); auto *V = Builder.CreateFSub(Y, X); V = Builder.CreateFMul(S, V); - auto *Result = Builder.CreateFAdd(X, V, "dx.lerp"); - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateFAdd(X, V, "dx.lerp"); } -static bool expandLogIntrinsic(CallInst *Orig, - float LogConstVal = numbers::ln2f) { +static Value *expandLogIntrinsic(CallInst *Orig, + float LogConstVal = numbers::ln2f) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); Builder.SetInsertPoint(Orig); @@ -221,16 +205,13 @@ static bool expandLogIntrinsic(CallInst *Orig, Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); Log2Call->setTailCall(Orig->isTailCall()); Log2Call->setAttributes(Orig->getAttributes()); - auto *Result = Builder.CreateFMul(Ln2Const, Log2Call); - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateFMul(Ln2Const, Log2Call); } -static bool expandLog10Intrinsic(CallInst *Orig) { +static Value *expandLog10Intrinsic(CallInst *Orig) { return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f); } -static bool expandNormalizeIntrinsic(CallInst *Orig) { +static Value *expandNormalizeIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Type *Ty = Orig->getType(); Type *EltTy = Ty->getScalarType(); @@ -245,11 +226,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) { report_fatal_error(Twine("Invalid input scalar: length is zero"), /* gen_crash_diag=*/false); } - Value *Result = Builder.CreateFDiv(X, X); - - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateFDiv(X, X); } unsigned XVecSize = XVec->getNumElements(); @@ -291,14 +268,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) { nullptr, "dx.rsqrt"); Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand); - Value *Result = Builder.CreateFMul(X, MultiplicandVec); - - Orig->replaceAllUsesWith(Result); - Orig->eraseFromParent(); - return true; + return Builder.CreateFMul(X, MultiplicandVec); } -static bool expandPowIntrinsic(CallInst *Orig) { +static Value *expandPowIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Value *Y = Orig->getOperand(1); @@ -313,9 +286,7 @@ static bool expandPowIntrinsic(CallInst *Orig) { Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2"); Exp2Call->setTailCall(Orig->isTailCall()); Exp2Call->setAttributes(Orig->getAttributes()); - Orig->replaceAllUsesWith(Exp2Call); - Orig->eraseFromParent(); - return true; + return Exp2Call; } static Intrinsic::ID getMaxForClamp(Type *ElemTy, @@ -344,7 +315,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy, return Intrinsic::minnum; } -static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) { +static Value *expandClampIntrinsic(CallInst *Orig, + Intrinsic::ID ClampIntrinsic) { Value *X = Orig->getOperand(0); Value *Min = Orig->getOperand(1); Value *Max = Orig->getOperand(2); @@ -353,41 +325,54 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) { Builder.SetInsertPoint(Orig); auto *MaxCall = Builder.CreateIntrinsic( Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max"); - auto *MinCall = - Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic), - {MaxCall, Max}, nullptr, "dx.min"); - - Orig->replaceAllUsesWith(MinCall); - Orig->eraseFromParent(); - return true; + return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic), + {MaxCall, Max}, nullptr, "dx.min"); } static bool expandIntrinsic(Function &F, CallInst *Orig) { + Value *Result = nullptr; switch (F.getIntrinsicID()) { case Intrinsic::abs: - return expandAbs(Orig); + Result = expandAbs(Orig); + break; case Intrinsic::exp: - return expandExpIntrinsic(Orig); + Result = expandExpIntrinsic(Orig); + break; case Intrinsic::log: - return expandLogIntrinsic(Orig); + Result = expandLogIntrinsic(Orig); + break; case Intrinsic::log10: - return expandLog10Intrinsic(Orig); + Result = expandLog10Intrinsic(Orig); + break; case Intrinsic::pow: - return expandPowIntrinsic(Orig); + Result = expandPowIntrinsic(Orig); + break; case Intrinsic::dx_any: - return expandAnyIntrinsic(Orig); + Result = expandAnyIntrinsic(Orig); + break; case Intrinsic::dx_uclamp: case Intrinsic::dx_clamp: - return expandClampIntrinsic(Orig, F.getIntrinsicID()); + Result = expandClampIntrinsic(Orig, F.getIntrinsicID()); + break; case Intrinsic::dx_lerp: - return expandLerpIntrinsic(Orig); + Result = expandLerpIntrinsic(Orig); + break; case Intrinsic::dx_length: - return expandLengthIntrinsic(Orig); + Result = expandLengthIntrinsic(Orig); + break; case Intrinsic::dx_normalize: - return expandNormalizeIntrinsic(Orig); + Result = expandNormalizeIntrinsic(Orig); + break; case Intrinsic::dx_sdot: case Intrinsic::dx_udot: - return expandIntegerDot(Orig, F.getIntrinsicID()); + Result = expandIntegerDot(Orig, F.getIntrinsicID()); + break; + } + + if (Result) { + Orig->replaceAllUsesWith(Result); + Orig->eraseFromParent(); + return true; } return false; }