Skip to content

Commit 129a4ef

Browse files
david-armMDevereau
andcommitted
[LV] Add support for partial reductions without a binary op
Consider IR such as this: for.body: %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] %gep.a = getelementptr i8, ptr %a, i64 %iv %load.a = load i8, ptr %gep.a, align 1 %ext.a = zext i8 %load.a to i32 %add = add i32 %ext.a, %accum %iv.next = add i64 %iv, 1 %exitcond.not = icmp eq i64 %iv.next, 1025 br i1 %exitcond.not, label %for.exit, label %for.body Conceptually we can vectorise this using partial reductions too, although the current loop vectoriser implementation requires the accumulation of a multiply. For AArch64 this is easily done with a udot or sdot with an identity operand, i.e. a vector of (i16 1). In order to do this I had to teach getScaledReductions that the accumulated value may come from a unary op, hence there is only one extension to consider. Similarly, I updated the vplan and AArch64 TTI cost model to understand the possible unary op. Co-authored-by: Matt Devereau <[email protected]>
1 parent de93f08 commit 129a4ef

File tree

11 files changed

+344
-282
lines changed

11 files changed

+344
-282
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,9 +1325,21 @@ class TargetTransformInfo {
13251325
/// \return The cost of a partial reduction, which is a reduction from a
13261326
/// vector to another vector with fewer elements of larger size. They are
13271327
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1328-
/// takes an accumulator and a binary operation operand that itself is fed by
1329-
/// two extends. An example of an operation that uses a partial reduction is a
1330-
/// dot product, which reduces two vectors to another of 4 times fewer and 4
1328+
/// takes an accumulator of type \p AccumType and a second vector operand to
1329+
/// be accumulated, whose element count is specified by \p VF. The type of
1330+
/// reduction is specified by \p Opcode. The second operand passed to the
1331+
/// intrinsic could be the result of an extend, such as sext or zext. In
1332+
/// this case \p BinOp is nullopt, \p InputTypeA represents the type being
1333+
/// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
1334+
/// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
1335+
/// Alternatively, the second operand could be the result of a binary
1336+
/// operation performed on two extends, i.e.
1337+
/// mul(zext i8 %a -> i32, zext i8 %b -> i32).
1338+
/// In this case \p BinOp may specify the opcode of the binary operation,
1339+
/// \p InputTypeA and \p InputTypeB the types being extended, and
1340+
/// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
1341+
/// operation that uses a partial reduction is a dot product, which reduces
1342+
/// two vectors in binary mul operation to another of 4 times fewer and 4
13311343
/// times larger elements.
13321344
LLVM_ABI InstructionCost getPartialReductionCost(
13331345
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5405,11 +5405,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
54055405

54065406
// Sub opcodes currently only occur in chained cases.
54075407
// Independent partial reduction subtractions are still costed as an add
5408-
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
5408+
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
5409+
OpAExtend == TTI::PR_None)
54095410
return Invalid;
54105411

5411-
if (InputTypeA != InputTypeB)
5412+
// We only support multiply binary operations for now, and for muls we
5413+
// require the types being extended to be the same.
5414+
// NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5415+
// only if the i8mm or sve/streaming features are available.
5416+
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5417+
OpBExtend == TTI::PR_None ||
5418+
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5419+
!ST->isSVEorStreamingSVEAvailable())))
54125420
return Invalid;
5421+
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5422+
"Unexpected values for OpBExtend or InputTypeB");
54135423

54145424
EVT InputEVT = EVT::getEVT(InputTypeA);
54155425
EVT AccumEVT = EVT::getEVT(AccumType);
@@ -5456,15 +5466,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
54565466
} else
54575467
return Invalid;
54585468

5459-
// AArch64 supports lowering mixed fixed-width extensions to a usdot but only
5460-
// if the i8mm feature is available.
5461-
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
5462-
(OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
5463-
return Invalid;
5464-
5465-
if (!BinOp || *BinOp != Instruction::Mul)
5466-
return Invalid;
5467-
54685469
return Cost;
54695470
}
54705471

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8148,15 +8148,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
81488148
// something that isn't another partial reduction. This is because the
81498149
// extends are intended to be lowered along with the reduction itself.
81508150

8151-
// Build up a set of partial reduction bin ops for efficient use checking.
8152-
SmallSet<User *, 4> PartialReductionBinOps;
8151+
// Build up a set of partial reduction ops for efficient use checking.
8152+
SmallSet<User *, 4> PartialReductionOps;
81538153
for (const auto &[PartialRdx, _] : PartialReductionChains)
8154-
PartialReductionBinOps.insert(PartialRdx.BinOp);
8154+
PartialReductionOps.insert(PartialRdx.ExtendUser);
81558155

81568156
auto ExtendIsOnlyUsedByPartialReductions =
8157-
[&PartialReductionBinOps](Instruction *Extend) {
8157+
[&PartialReductionOps](Instruction *Extend) {
81588158
return all_of(Extend->users(), [&](const User *U) {
8159-
return PartialReductionBinOps.contains(U);
8159+
return PartialReductionOps.contains(U);
81608160
});
81618161
};
81628162

@@ -8165,15 +8165,14 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
81658165
for (auto Pair : PartialReductionChains) {
81668166
PartialReductionChain Chain = Pair.first;
81678167
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8168-
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8168+
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
81698169
ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
81708170
}
81718171
}
81728172

81738173
bool VPRecipeBuilder::getScaledReductions(
81748174
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
81758175
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8176-
81778176
if (!CM.TheLoop->contains(RdxExitInstr))
81788177
return false;
81798178

@@ -8202,43 +8201,75 @@ bool VPRecipeBuilder::getScaledReductions(
82028201
if (PhiOp != PHI)
82038202
return false;
82048203

8205-
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8206-
if (!BinOp || !BinOp->hasOneUse())
8207-
return false;
8208-
82098204
using namespace llvm::PatternMatch;
8210-
// Use the side-effect of match to replace BinOp only if the pattern is
8211-
// matched, we don't care at this point whether it actually matched.
8212-
match(BinOp, m_Neg(m_BinOp(BinOp)));
82138205

8214-
Value *A, *B;
8215-
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8216-
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8217-
return false;
8206+
// If the update is a binary operator, check both of its operands to see if
8207+
// they are extends. Otherwise, see if the update comes directly from an
8208+
// extend.
8209+
Instruction *Exts[2] = {nullptr};
8210+
BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8211+
std::optional<unsigned> BinOpc;
8212+
Type *ExtOpTypes[2] = {nullptr};
8213+
8214+
auto CollectExtInfo = [&Exts,
8215+
&ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
8216+
unsigned I = 0;
8217+
for (Value *OpI : Ops) {
8218+
Value *ExtOp;
8219+
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
8220+
return false;
8221+
Exts[I] = cast<Instruction>(OpI);
8222+
ExtOpTypes[I] = ExtOp->getType();
8223+
I++;
8224+
}
8225+
return true;
8226+
};
8227+
8228+
if (ExtendUser) {
8229+
if (!ExtendUser->hasOneUse())
8230+
return false;
82188231

8219-
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8220-
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8232+
// Use the side-effect of match to replace BinOp only if the pattern is
8233+
// matched, we don't care at this point whether it actually matched.
8234+
match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
8235+
8236+
SmallVector<Value *> Ops(ExtendUser->operands());
8237+
if (!CollectExtInfo(Ops))
8238+
return false;
8239+
8240+
BinOpc = std::make_optional(ExtendUser->getOpcode());
8241+
} else if (match(Update, m_Add(m_Value(), m_Value()))) {
8242+
// We already know the operands for Update are Op and PhiOp.
8243+
SmallVector<Value *> Ops({Op});
8244+
if (!CollectExtInfo(Ops))
8245+
return false;
8246+
8247+
ExtendUser = Update;
8248+
BinOpc = std::nullopt;
8249+
} else
8250+
return false;
82218251

82228252
TTI::PartialReductionExtendKind OpAExtend =
8223-
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8253+
TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
82248254
TTI::PartialReductionExtendKind OpBExtend =
8225-
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8226-
8227-
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
8255+
Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
8256+
: TargetTransformInfo::PR_None;
8257+
PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
82288258

82298259
TypeSize PHISize = PHI->getType()->getPrimitiveSizeInBits();
8230-
TypeSize ASize = A->getType()->getPrimitiveSizeInBits();
8231-
8260+
TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits();
82328261
if (!PHISize.hasKnownScalarFactor(ASize))
82338262
return false;
82348263

8235-
unsigned TargetScaleFactor = PHISize.getKnownScalarFactor(ASize);
8264+
unsigned TargetScaleFactor =
8265+
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8266+
ExtOpTypes[0]->getPrimitiveSizeInBits());
82368267

82378268
if (LoopVectorizationPlanner::getDecisionAndClampRange(
82388269
[&](ElementCount VF) {
82398270
InstructionCost Cost = TTI->getPartialReductionCost(
8240-
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8241-
VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
8271+
Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1], PHI->getType(),
8272+
VF, OpAExtend, OpBExtend, BinOpc, CM.CostKind);
82428273
return Cost.isValid();
82438274
},
82448275
Range)) {

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,23 @@ struct HistogramInfo;
2424
struct VFRange;
2525

2626
/// A chain of instructions that form a partial reduction.
27-
/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
28-
/// accumulator).
27+
/// Designed to match either:
28+
/// reduction_bin_op (extend (A), accumulator), or
29+
/// reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
2930
struct PartialReductionChain {
3031
PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
31-
Instruction *ExtendB, Instruction *BinOp)
32-
: Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
33-
}
32+
Instruction *ExtendB, Instruction *ExtendUser)
33+
: Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
34+
ExtendUser(ExtendUser) {}
3435
/// The top-level binary operation that forms the reduction to a scalar
3536
/// after the loop body.
3637
Instruction *Reduction;
3738
/// The extension of each of the inner binary operation's operands.
3839
Instruction *ExtendA;
3940
Instruction *ExtendB;
4041

41-
/// The binary operation using the extends that is then reduced.
42-
Instruction *BinOp;
42+
/// The user of the extend that is then reduced.
43+
Instruction *ExtendUser;
4344
};
4445

4546
/// Helper class to create VPRecipies from IR instructions.

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -296,34 +296,21 @@ bool VPRecipeBase::isScalarCast() const {
296296
InstructionCost
297297
VPPartialReductionRecipe::computeCost(ElementCount VF,
298298
VPCostContext &Ctx) const {
299-
std::optional<unsigned> Opcode = std::nullopt;
300-
VPValue *BinOp = getOperand(1);
299+
std::optional<unsigned> Opcode;
300+
VPValue *Op = getOperand(0);
301+
VPRecipeBase *OpR = Op->getDefiningRecipe();
301302

302-
// If the partial reduction is predicated, a select will be operand 0 rather
303-
// than the binary op
303+
// If the partial reduction is predicated, a select will be operand 0
304304
using namespace llvm::VPlanPatternMatch;
305-
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
306-
BinOp = BinOp->getDefiningRecipe()->getOperand(1);
307-
308-
// If BinOp is a negation, use the side effect of match to assign the actual
309-
// binary operation to BinOp
310-
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
311-
VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
312-
313-
if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
314-
Opcode = std::make_optional(WidenR->getOpcode());
315-
316-
VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
317-
VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
305+
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
306+
OpR = Op->getDefiningRecipe();
307+
}
318308

319-
auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
320-
auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
321-
: BinOpR->getOperand(0));
322-
auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
323-
: BinOpR->getOperand(1));
309+
Type *InputTypeA = nullptr, *InputTypeB = nullptr;
310+
TTI::PartialReductionExtendKind ExtAType = TargetTransformInfo::PR_None,
311+
ExtBType = TargetTransformInfo::PR_None;
324312

325313
auto GetExtendKind = [](VPRecipeBase *R) {
326-
// The extend could come from outside the plan.
327314
if (!R)
328315
return TargetTransformInfo::PR_None;
329316
auto *WidenCastR = dyn_cast<VPWidenCastRecipe>(R);
@@ -336,9 +323,42 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
336323
return TargetTransformInfo::PR_None;
337324
};
338325

326+
// Pick out opcode, type/ext information and use sub side effects from a widen recipe.
327+
auto HandleWiden = [&](VPWidenRecipe* Widen){
328+
if (match(Widen,
329+
m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)))) {
330+
Widen = dyn_cast<VPWidenRecipe>(Op->getDefiningRecipe());
331+
}
332+
Opcode = Widen->getOpcode();
333+
VPRecipeBase *ExtAR = Widen->getOperand(0)->getDefiningRecipe();
334+
VPRecipeBase *ExtBR = Widen->getOperand(1)->getDefiningRecipe();
335+
InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
336+
: Widen->getOperand(0));
337+
InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
338+
: Widen->getOperand(1));
339+
ExtAType = GetExtendKind(ExtAR);
340+
ExtBType = GetExtendKind(ExtBR);
341+
};
342+
343+
if (isa<VPWidenCastRecipe>(OpR)) {
344+
InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
345+
ExtAType = GetExtendKind(OpR);
346+
} else if (isa<VPReductionPHIRecipe>(OpR)) {
347+
auto RedPhiOp1R = getOperand(1)->getDefiningRecipe();
348+
if (isa<VPWidenCastRecipe>(RedPhiOp1R)) {
349+
InputTypeA = Ctx.Types.inferScalarType(RedPhiOp1R->getOperand(0));
350+
ExtAType = GetExtendKind(RedPhiOp1R);
351+
} else if (auto Widen = dyn_cast<VPWidenRecipe>(RedPhiOp1R))
352+
HandleWiden(Widen);
353+
} else if (auto Widen = dyn_cast<VPWidenRecipe>(OpR)) {
354+
HandleWiden(Widen);
355+
} else if (auto Reduction = dyn_cast<VPPartialReductionRecipe>(OpR)) {
356+
return Reduction->computeCost(VF, Ctx);
357+
}
358+
auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
339359
return Ctx.TTI.getPartialReductionCost(
340-
getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
341-
GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
360+
getOpcode(), InputTypeA, InputTypeB, PhiType, VF, ExtAType,
361+
ExtBType, Opcode, Ctx.CostKind);
342362
}
343363

344364
void VPPartialReductionRecipe::execute(VPTransformState &State) {

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,4 +1255,3 @@ entry:
12551255
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <8 x i64> %input.wide)
12561256
ret <2 x i64> %partial.reduce
12571257
}
1258-

0 commit comments

Comments
 (0)