Skip to content

[X86] Move the AVX512 VSELECT(COND, 0, X) -> VSELECT(!COND, X, 0) fold to DAGToDAG #145724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1139,24 +1139,62 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
break;
}
case ISD::VSELECT: {
// Replace VSELECT with non-mask conditions with with BLENDV/VPTERNLOG.
EVT EleVT = N->getOperand(0).getValueType().getVectorElementType();
if (EleVT == MVT::i1)
break;

assert(Subtarget->hasSSE41() && "Expected SSE4.1 support!");
assert(N->getValueType(0).getVectorElementType() != MVT::i16 &&
"We can't replace VSELECT with BLENDV in vXi16!");
SDValue Cond = N->getOperand(0);
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
EVT CondVT = Cond.getValueType();
EVT CondSVT = CondVT.getVectorElementType();
EVT VT = N->getValueType(0);
SDLoc DL(N);
SDValue R;
if (Subtarget->hasVLX() && CurDAG->ComputeNumSignBits(N->getOperand(0)) ==
EleVT.getSizeInBits()) {
R = CurDAG->getNode(X86ISD::VPTERNLOG, SDLoc(N), N->getValueType(0),
N->getOperand(0), N->getOperand(1), N->getOperand(2),
CurDAG->getTargetConstant(0xCA, SDLoc(N), MVT::i8));

if (CondSVT == MVT::i1) {
assert(Subtarget->hasAVX512() && "Expected AVX512 support!");
if (!Cond->hasOneUse() || !ISD::isBuildVectorAllZeros(LHS.getNode()) ||
ISD::isBuildVectorAllZeros(RHS.getNode()))
break;
// If this is an avx512 target we can improve the use of zero masking by
// swapping the operands and inverting the condition.
// vselect cond, zero, op = vselect not(cond), op, zero
auto InverseCondition = [this](SDValue Cond, const SDLoc &DL) {
EVT CondVT = Cond.getValueType();
if (Cond.getOpcode() == ISD::SETCC &&
!ISD::isBuildVectorAllZeros(Cond.getOperand(0).getNode())) {
ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
CC = ISD::getSetCCInverse(CC, Cond.getOperand(0).getValueType());
return CurDAG->getSetCC(DL, CondVT, Cond.getOperand(0),
Cond.getOperand(1), CC);
}
if (Cond.getOpcode() == X86ISD::CMPM ||
Cond.getOpcode() == X86ISD::FSETCCM) {
unsigned CC = Cond.getConstantOperandVal(2);
return CurDAG->getNode(
Cond.getOpcode(), DL, CondVT, Cond.getOperand(0),
Cond.getOperand(1),
CurDAG->getTargetConstant(CC ^ 4, DL, MVT::i8));
}
return CurDAG->getNOT(DL, Cond, CondVT);
};
if (Cond.getOpcode() == ISD::INSERT_SUBVECTOR &&
Cond.getOperand(0).isUndef())
R = CurDAG->getNode(
ISD::INSERT_SUBVECTOR, DL, CondVT, Cond.getOperand(0),
InverseCondition(Cond.getOperand(1), DL), Cond.getOperand(2));
else
R = InverseCondition(Cond, DL);
R = CurDAG->getSelect(DL, VT, R, RHS, LHS);
} else {
R = CurDAG->getNode(X86ISD::BLENDV, SDLoc(N), N->getValueType(0),
N->getOperand(0), N->getOperand(1),
N->getOperand(2));
// Replace VSELECT with non-mask conditions with BLENDV/VPTERNLOG.
assert(Subtarget->hasSSE41() && "Expected SSE4.1 support!");
assert(VT.getVectorElementType() != MVT::i16 &&
"We can't replace VSELECT with BLENDV in vXi16!");
if (Subtarget->hasVLX() &&
CurDAG->ComputeNumSignBits(Cond) == CondSVT.getSizeInBits()) {
R = CurDAG->getNode(X86ISD::VPTERNLOG, DL, VT, Cond, LHS, RHS,
CurDAG->getTargetConstant(0xCA, DL, MVT::i8));
} else {
R = CurDAG->getNode(X86ISD::BLENDV, DL, VT, Cond, LHS, RHS);
}
}
--I;
CurDAG->ReplaceAllUsesWith(N, R.getNode());
Expand Down
61 changes: 25 additions & 36 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48039,19 +48039,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
}
}

// Check if the first operand is all zeros and Cond type is vXi1.
// If this an avx512 target we can improve the use of zero masking by
// swapping the operands and inverting the condition.
if (N->getOpcode() == ISD::VSELECT && Cond.hasOneUse() &&
Subtarget.hasAVX512() && CondVT.getVectorElementType() == MVT::i1 &&
ISD::isBuildVectorAllZeros(LHS.getNode()) &&
!ISD::isBuildVectorAllZeros(RHS.getNode())) {
// Invert the cond to not(cond) : xor(op,allones)=not(op)
SDValue CondNew = DAG.getNOT(DL, Cond, CondVT);
// Vselect cond, op1, op2 = Vselect not(cond), op2, op1
return DAG.getSelect(DL, VT, CondNew, RHS, LHS);
}

// Attempt to convert a (vXi1 bitcast(iX Cond)) selection mask before it might
// get split by legalization.
if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::BITCAST &&
Expand Down Expand Up @@ -48115,33 +48102,35 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
return V;

// select(~Cond, X, Y) -> select(Cond, Y, X)
if (CondVT.getScalarType() != MVT::i1) {
// Limit vXi1 cases to AVX512 canonicalization of zero mask to the RHS.
if (CondVT.getScalarType() != MVT::i1 ||
(ISD::isBuildVectorAllZeros(LHS.getNode()) &&
!ISD::isBuildVectorAllZeros(RHS.getNode())))
if (SDValue CondNot = IsNOT(Cond, DAG))
return DAG.getNode(N->getOpcode(), DL, VT,
DAG.getBitcast(CondVT, CondNot), RHS, LHS);

// select(pcmpeq(and(X,Pow2),0),A,B) -> select(pcmpeq(and(X,Pow2),Pow2),B,A)
if (Cond.getOpcode() == X86ISD::PCMPEQ &&
Cond.getOperand(0).getOpcode() == ISD::AND &&
ISD::isBuildVectorAllZeros(Cond.getOperand(1).getNode()) &&
isConstantPowerOf2(Cond.getOperand(0).getOperand(1),
Cond.getScalarValueSizeInBits(),
/*AllowUndefs=*/true) &&
Cond.hasOneUse()) {
Cond = DAG.getNode(X86ISD::PCMPEQ, DL, CondVT, Cond.getOperand(0),
Cond.getOperand(0).getOperand(1));
return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
}

// pcmpgt(X, -1) -> pcmpgt(0, X) to help select/blendv just use the
// signbit.
if (Cond.getOpcode() == X86ISD::PCMPGT &&
ISD::isBuildVectorAllOnes(Cond.getOperand(1).getNode()) &&
Cond.hasOneUse()) {
Cond = DAG.getNode(X86ISD::PCMPGT, DL, CondVT,
DAG.getConstant(0, DL, CondVT), Cond.getOperand(0));
return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
}
// select(pcmpeq(and(X,Pow2),0),A,B) -> select(pcmpeq(and(X,Pow2),Pow2),B,A)
if (Cond.getOpcode() == X86ISD::PCMPEQ &&
Cond.getOperand(0).getOpcode() == ISD::AND &&
ISD::isBuildVectorAllZeros(Cond.getOperand(1).getNode()) &&
isConstantPowerOf2(Cond.getOperand(0).getOperand(1),
Cond.getScalarValueSizeInBits(),
/*AllowUndefs=*/true) &&
Cond.hasOneUse()) {
Cond = DAG.getNode(X86ISD::PCMPEQ, DL, CondVT, Cond.getOperand(0),
Cond.getOperand(0).getOperand(1));
return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
}

// pcmpgt(X, -1) -> pcmpgt(0, X) to help select/blendv just use the
// signbit.
if (Cond.getOpcode() == X86ISD::PCMPGT &&
ISD::isBuildVectorAllOnes(Cond.getOperand(1).getNode()) &&
Cond.hasOneUse()) {
Cond = DAG.getNode(X86ISD::PCMPGT, DL, CondVT,
DAG.getConstant(0, DL, CondVT), Cond.getOperand(0));
return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
}

// Try to optimize vXi1 selects if both operands are either all constants or
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/X86/psubus.ll
Original file line number Diff line number Diff line change
Expand Up @@ -981,9 +981,9 @@ define <16 x i8> @test14(<16 x i8> %x, <16 x i32> %y) nounwind {
; AVX512-LABEL: test14:
; AVX512: # %bb.0: # %vector.ph
; AVX512-NEXT: vpmovzxbd {{.*#+}} zmm2 = xmm0[0],zero,zero,zero,xmm0[1],zero,zero,zero,xmm0[2],zero,zero,zero,xmm0[3],zero,zero,zero,xmm0[4],zero,zero,zero,xmm0[5],zero,zero,zero,xmm0[6],zero,zero,zero,xmm0[7],zero,zero,zero,xmm0[8],zero,zero,zero,xmm0[9],zero,zero,zero,xmm0[10],zero,zero,zero,xmm0[11],zero,zero,zero,xmm0[12],zero,zero,zero,xmm0[13],zero,zero,zero,xmm0[14],zero,zero,zero,xmm0[15],zero,zero,zero
; AVX512-NEXT: vpmovdb %zmm1, %xmm3
; AVX512-NEXT: vpcmpnltud %zmm2, %zmm1, %k1
; AVX512-NEXT: vpmovdb %zmm1, %xmm1
; AVX512-NEXT: vpsubb %xmm0, %xmm1, %xmm0 {%k1} {z}
; AVX512-NEXT: vpsubb %xmm0, %xmm3, %xmm0 {%k1} {z}
; AVX512-NEXT: vzeroupper
; AVX512-NEXT: retq
vector.ph:
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/CodeGen/X86/var-permute-256.ll
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ define <4 x i64> @var_shuffle_zero_v4i64(<4 x i64> %v, <4 x i64> %indices) nounw
; AVX512-NEXT: # kill: def $ymm1 killed $ymm1 def $zmm1
; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 def $zmm0
; AVX512-NEXT: vpbroadcastq {{.*#+}} ymm2 = [3,3,3,3]
; AVX512-NEXT: vpcmpnleuq %zmm2, %zmm1, %k1
; AVX512-NEXT: vpcmpeqd %ymm3, %ymm3, %ymm3
; AVX512-NEXT: vpblendmq %zmm3, %zmm1, %zmm3 {%k1}
; AVX512-NEXT: vpcmpleuq %zmm2, %zmm1, %k1
; AVX512-NEXT: vpcmpnleuq %zmm2, %zmm1, %k2
; AVX512-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
; AVX512-NEXT: vmovdqa64 %zmm2, %zmm1 {%k2}
; AVX512-NEXT: vpermq %zmm0, %zmm1, %zmm0 {%k1} {z}
; AVX512-NEXT: vpermq %zmm0, %zmm3, %zmm0 {%k1} {z}
; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 killed $zmm0
; AVX512-NEXT: retq
;
Expand Down Expand Up @@ -1192,11 +1192,11 @@ define <4 x double> @var_shuffle_zero_v4f64(<4 x double> %v, <4 x i64> %indices)
; AVX512-NEXT: # kill: def $ymm1 killed $ymm1 def $zmm1
; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 def $zmm0
; AVX512-NEXT: vpbroadcastq {{.*#+}} ymm2 = [3,3,3,3]
; AVX512-NEXT: vpcmpnleuq %zmm2, %zmm1, %k1
; AVX512-NEXT: vpcmpeqd %ymm3, %ymm3, %ymm3
; AVX512-NEXT: vpblendmq %zmm3, %zmm1, %zmm3 {%k1}
; AVX512-NEXT: vpcmpleuq %zmm2, %zmm1, %k1
; AVX512-NEXT: vpcmpnleuq %zmm2, %zmm1, %k2
; AVX512-NEXT: vpcmpeqd %ymm2, %ymm2, %ymm2
; AVX512-NEXT: vmovdqa64 %zmm2, %zmm1 {%k2}
; AVX512-NEXT: vpermpd %zmm0, %zmm1, %zmm0 {%k1} {z}
; AVX512-NEXT: vpermpd %zmm0, %zmm3, %zmm0 {%k1} {z}
; AVX512-NEXT: # kill: def $ymm0 killed $ymm0 killed $zmm0
; AVX512-NEXT: retq
;
Expand Down