Skip to content

Commit

Permalink
[NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller (l…
Browse files Browse the repository at this point in the history
…lvm#104626)

All expansions end with replacing the previous inrinsic with the new
expansion and erasing the old one. By moving this operation to the
caller, these expansion functions can be called in more contexts and a
small amount of duplicated code is consolidated.

Pre-req for llvm#88056
  • Loading branch information
pow2clk authored Aug 17, 2024
1 parent 981191a commit cd3f48d
Showing 1 changed file with 61 additions and 76 deletions.
137 changes: 61 additions & 76 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
return false;
}

static bool expandAbs(CallInst *Orig) {
static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -64,14 +64,11 @@ static bool expandAbs(CallInst *Orig) {
ConstantInt::get(EltTy, 0))
: ConstantInt::get(EltTy, 0);
auto *V = Builder.CreateSub(Zero, X);
auto *MaxCall =
Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
Orig->replaceAllUsesWith(MaxCall);
Orig->eraseFromParent();
return true;
return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
"dx.max");
}

static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
assert(DotIntrinsic == Intrinsic::dx_sdot ||
DotIntrinsic == Intrinsic::dx_udot);
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
Expand All @@ -97,12 +94,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
ArrayRef<Value *>{Elt0, Elt1, Result},
nullptr, "dx.mad");
}
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Result;
}

static bool expandExpIntrinsic(CallInst *Orig) {
static Value *expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -119,23 +114,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
Exp2Call->setTailCall(Orig->isTailCall());
Exp2Call->setAttributes(Orig->getAttributes());
Orig->replaceAllUsesWith(Exp2Call);
Orig->eraseFromParent();
return true;
return Exp2Call;
}

static bool expandAnyIntrinsic(CallInst *Orig) {
static Value *expandAnyIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();

Value *Result = nullptr;
if (!Ty->isVectorTy()) {
Value *Cond = EltTy->isFloatingPointTy()
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
Orig->replaceAllUsesWith(Cond);
Result = EltTy->isFloatingPointTy()
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
} else {
auto *XVec = dyn_cast<FixedVectorType>(Ty);
Value *Cond =
Expand All @@ -148,18 +141,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
X, ConstantVector::getSplat(
ElementCount::getFixed(XVec->getNumElements()),
ConstantInt::get(EltTy, 0)));
Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
Value *Elt = Builder.CreateExtractElement(Cond, I);
Result = Builder.CreateOr(Result, Elt);
}
Orig->replaceAllUsesWith(Result);
}
Orig->eraseFromParent();
return true;
return Result;
}

static bool expandLengthIntrinsic(CallInst *Orig) {
static Value *expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -182,30 +173,23 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
Value *Mul = Builder.CreateFMul(Elt, Elt);
Sum = Builder.CreateFAdd(Sum, Mul);
}
Value *Result = Builder.CreateIntrinsic(
EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
nullptr, "elt.sqrt");
}

static bool expandLerpIntrinsic(CallInst *Orig) {
static Value *expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Value *S = Orig->getOperand(2);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
auto *V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFAdd(X, V, "dx.lerp");
}

static bool expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
static Value *expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -221,16 +205,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
Log2Call->setTailCall(Orig->isTailCall());
Log2Call->setAttributes(Orig->getAttributes());
auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFMul(Ln2Const, Log2Call);
}
static bool expandLog10Intrinsic(CallInst *Orig) {
static Value *expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}

static bool expandNormalizeIntrinsic(CallInst *Orig) {
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
Expand All @@ -245,11 +226,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
report_fatal_error(Twine("Invalid input scalar: length is zero"),
/* gen_crash_diag=*/false);
}
Value *Result = Builder.CreateFDiv(X, X);

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFDiv(X, X);
}

unsigned XVecSize = XVec->getNumElements();
Expand Down Expand Up @@ -291,14 +268,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
nullptr, "dx.rsqrt");

Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
Value *Result = Builder.CreateFMul(X, MultiplicandVec);

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFMul(X, MultiplicandVec);
}

static bool expandPowIntrinsic(CallInst *Orig) {
static Value *expandPowIntrinsic(CallInst *Orig) {

Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Expand All @@ -313,9 +286,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
Exp2Call->setTailCall(Orig->isTailCall());
Exp2Call->setAttributes(Orig->getAttributes());
Orig->replaceAllUsesWith(Exp2Call);
Orig->eraseFromParent();
return true;
return Exp2Call;
}

static Intrinsic::ID getMaxForClamp(Type *ElemTy,
Expand Down Expand Up @@ -344,7 +315,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
return Intrinsic::minnum;
}

static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
static Value *expandClampIntrinsic(CallInst *Orig,
Intrinsic::ID ClampIntrinsic) {
Value *X = Orig->getOperand(0);
Value *Min = Orig->getOperand(1);
Value *Max = Orig->getOperand(2);
Expand All @@ -353,41 +325,54 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
Builder.SetInsertPoint(Orig);
auto *MaxCall = Builder.CreateIntrinsic(
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
auto *MinCall =
Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
{MaxCall, Max}, nullptr, "dx.min");

Orig->replaceAllUsesWith(MinCall);
Orig->eraseFromParent();
return true;
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
{MaxCall, Max}, nullptr, "dx.min");
}

static bool expandIntrinsic(Function &F, CallInst *Orig) {
Value *Result = nullptr;
switch (F.getIntrinsicID()) {
case Intrinsic::abs:
return expandAbs(Orig);
Result = expandAbs(Orig);
break;
case Intrinsic::exp:
return expandExpIntrinsic(Orig);
Result = expandExpIntrinsic(Orig);
break;
case Intrinsic::log:
return expandLogIntrinsic(Orig);
Result = expandLogIntrinsic(Orig);
break;
case Intrinsic::log10:
return expandLog10Intrinsic(Orig);
Result = expandLog10Intrinsic(Orig);
break;
case Intrinsic::pow:
return expandPowIntrinsic(Orig);
Result = expandPowIntrinsic(Orig);
break;
case Intrinsic::dx_any:
return expandAnyIntrinsic(Orig);
Result = expandAnyIntrinsic(Orig);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
return expandClampIntrinsic(Orig, F.getIntrinsicID());
Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
break;
case Intrinsic::dx_lerp:
return expandLerpIntrinsic(Orig);
Result = expandLerpIntrinsic(Orig);
break;
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
Result = expandLengthIntrinsic(Orig);
break;
case Intrinsic::dx_normalize:
return expandNormalizeIntrinsic(Orig);
Result = expandNormalizeIntrinsic(Orig);
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
Result = expandIntegerDot(Orig, F.getIntrinsicID());
break;
}

if (Result) {
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
}
return false;
}
Expand Down

0 comments on commit cd3f48d

Please sign in to comment.