Skip to content

Commit a570eac

Browse files
MacDue authored and mattarde committed
[AArch64] Support lowering smaller than legal LOOP_DEP_MASKs to whilewr/rw (llvm#171982)
This adds support for lowering smaller-than-legal masks such as: ``` <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1) ``` to a whilewr + unpack. It also slightly simplifies the lowering.
1 parent a51be1e commit a570eac

File tree

3 files changed

+100
-40
lines changed

3 files changed

+100
-40
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5439,55 +5439,43 @@ static MVT getSVEContainerType(EVT ContentTy);
54395439
SDValue
54405440
AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
54415441
SelectionDAG &DAG) const {
5442+
assert((Subtarget->hasSVE2() ||
5443+
(Subtarget->hasSME() && Subtarget->isStreaming())) &&
5444+
"Lowering loop_dependence_raw_mask or loop_dependence_war_mask "
5445+
"requires SVE or SME");
5446+
54425447
SDLoc DL(Op);
54435448
EVT VT = Op.getValueType();
5444-
SDValue EltSize = Op.getOperand(2);
5445-
switch (EltSize->getAsZExtVal()) {
5446-
case 1:
5447-
if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
5448-
return SDValue();
5449-
break;
5450-
case 2:
5451-
if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
5452-
return SDValue();
5453-
break;
5454-
case 4:
5455-
if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
5456-
return SDValue();
5457-
break;
5458-
case 8:
5459-
if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
5460-
return SDValue();
5461-
break;
5462-
default:
5463-
// Other element sizes are incompatible with whilewr/rw, so expand instead
5464-
return SDValue();
5465-
}
5449+
unsigned LaneOffset = Op.getConstantOperandVal(3);
5450+
unsigned NumElements = VT.getVectorMinNumElements();
5451+
uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);
54665452

5467-
SDValue LaneOffset = Op.getOperand(3);
5468-
if (LaneOffset->getAsZExtVal())
5453+
// Lane offsets and other element sizes are not supported by whilewr/rw.
5454+
if (LaneOffset != 0 || !is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
54695455
return SDValue();
54705456

5471-
SDValue PtrA = Op.getOperand(0);
5472-
SDValue PtrB = Op.getOperand(1);
5457+
EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);
5458+
EVT PredVT = getPackedSVEVectorVT(EltVT).changeElementType(MVT::i1);
54735459

5474-
if (VT.isScalableVT())
5475-
return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, EltSize, LaneOffset);
5460+
// Legal whilewr/rw (lowered by tablegen matcher).
5461+
if (PredVT == VT)
5462+
return Op;
54765463

5477-
// We can use the SVE whilewr/whilerw instruction to lower this
5478-
// intrinsic by creating the appropriate sequence of scalable vector
5479-
// operations and then extracting a fixed-width subvector from the scalable
5480-
// vector. Scalable vector variants are already legal.
5481-
EVT ContainerVT =
5482-
EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
5483-
VT.getVectorNumElements(), true);
5484-
EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
5464+
// Expand if this mask needs splitting (this will produce a whilelo).
5465+
if (NumElements > PredVT.getVectorMinNumElements())
5466+
return SDValue();
54855467

54865468
SDValue Mask =
5487-
DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, EltSize, LaneOffset);
5488-
SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
5489-
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
5490-
DAG.getVectorIdxConstant(0, DL));
5469+
DAG.getNode(Op.getOpcode(), DL, PredVT, to_vector(Op->op_values()));
5470+
5471+
if (VT.isFixedLengthVector()) {
5472+
EVT WidePredVT = PredVT.changeElementType(VT.getScalarType());
5473+
SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, WidePredVT, Mask);
5474+
return convertFromScalableVector(DAG, VT, MaskAsInt);
5475+
}
5476+
5477+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Mask,
5478+
DAG.getConstant(0, DL, MVT::i64));
54915479
}
54925480

54935481
SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,

llvm/test/CodeGen/AArch64/alias_mask.ll

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,40 @@ entry:
563563
%0 = call <1 x i1> @llvm.loop.dependence.raw.mask.v1i1(ptr %a, ptr %b, i64 8)
564564
ret <1 x i1> %0
565565
}
566+
567+
define <8 x i1> @whilewr_extract_v8i1(ptr %a, ptr %b) {
568+
; CHECK-LABEL: whilewr_extract_v8i1:
569+
; CHECK: // %bb.0: // %entry
570+
; CHECK-NEXT: whilewr p0.b, x0, x1
571+
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
572+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
573+
; CHECK-NEXT: ret
574+
entry:
575+
%0 = call <8 x i1> @llvm.loop.dependence.war.mask.v8i1(ptr %a, ptr %b, i64 1)
576+
ret <8 x i1> %0
577+
}
578+
579+
define <4 x i1> @whilewr_extract_v4i1(ptr %a, ptr %b) {
580+
; CHECK-LABEL: whilewr_extract_v4i1:
581+
; CHECK: // %bb.0: // %entry
582+
; CHECK-NEXT: whilewr p0.b, x0, x1
583+
; CHECK-NEXT: punpklo p0.h, p0.b
584+
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
585+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
586+
; CHECK-NEXT: ret
587+
entry:
588+
%0 = call <4 x i1> @llvm.loop.dependence.war.mask.v4i1(ptr %a, ptr %b, i64 1)
589+
ret <4 x i1> %0
590+
}
591+
592+
define <2 x i1> @whilewr_extract_v2i1(ptr %a, ptr %b) {
593+
; CHECK-LABEL: whilewr_extract_v2i1:
594+
; CHECK: // %bb.0: // %entry
595+
; CHECK-NEXT: whilewr p0.s, x0, x1
596+
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
597+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
598+
; CHECK-NEXT: ret
599+
entry:
600+
%0 = call <2 x i1> @llvm.loop.dependence.war.mask.v2i1(ptr %a, ptr %b, i64 4)
601+
ret <2 x i1> %0
602+
}

llvm/test/CodeGen/AArch64/alias_mask_scalable.ll

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,38 @@ entry:
309309
%0 = call <vscale x 16 x i1> @llvm.loop.dependence.war.mask.nxv16i1(ptr %a, ptr %b, i64 3)
310310
ret <vscale x 16 x i1> %0
311311
}
312+
313+
define <vscale x 8 x i1> @whilewr_extract_nxv8i1(ptr %a, ptr %b) {
314+
; CHECK-LABEL: whilewr_extract_nxv8i1:
315+
; CHECK: // %bb.0: // %entry
316+
; CHECK-NEXT: whilewr p0.b, x0, x1
317+
; CHECK-NEXT: punpklo p0.h, p0.b
318+
; CHECK-NEXT: ret
319+
entry:
320+
%0 = call <vscale x 8 x i1> @llvm.loop.dependence.war.mask.nxv8i1(ptr %a, ptr %b, i64 1)
321+
ret <vscale x 8 x i1> %0
322+
}
323+
324+
define <vscale x 4 x i1> @whilewr_extract_nxv4i1(ptr %a, ptr %b) {
325+
; CHECK-LABEL: whilewr_extract_nxv4i1:
326+
; CHECK: // %bb.0: // %entry
327+
; CHECK-NEXT: whilewr p0.b, x0, x1
328+
; CHECK-NEXT: punpklo p0.h, p0.b
329+
; CHECK-NEXT: punpklo p0.h, p0.b
330+
; CHECK-NEXT: ret
331+
entry:
332+
%0 = call <vscale x 4 x i1> @llvm.loop.dependence.war.mask.nxv4i1(ptr %a, ptr %b, i64 1)
333+
ret <vscale x 4 x i1> %0
334+
}
335+
336+
337+
define <vscale x 2 x i1> @whilewr_extract_nxv2i1(ptr %a, ptr %b) {
338+
; CHECK-LABEL: whilewr_extract_nxv2i1:
339+
; CHECK: // %bb.0: // %entry
340+
; CHECK-NEXT: whilewr p0.s, x0, x1
341+
; CHECK-NEXT: punpklo p0.h, p0.b
342+
; CHECK-NEXT: ret
343+
entry:
344+
%0 = call <vscale x 2 x i1> @llvm.loop.dependence.war.mask.nxv2i1(ptr %a, ptr %b, i64 4)
345+
ret <vscale x 2 x i1> %0
346+
}

0 commit comments

Comments (0)