-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[NVPTX] Lower LLVM masked vector stores to PTX using new sink symbol syntax #159387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[NVPTX] Lower LLVM masked vector stores to PTX using new sink symbol syntax #159387
Conversation
@llvm/pr-subscribers-backend-hexagon @llvm/pr-subscribers-backend-arm Author: Drew Kersnar (dakersnar) ChangesThis backend support will allow the LoadStoreVectorizer, in certain cases, to fill in gaps when creating store vectors and generate LLVM masked stores (https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics), which get lowered to PTX using the new sink symbol syntax (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st). To accomplish this, changes are separated into two parts. This first part has the backend lowering and TTI changes, and a follow up PR will have the LSV generate these intrinsics: [INSERT] TTI changes are needed because NVPTX only supports masked stores with constant masks. If the masked stores make it to the NVPTX backend without being scalarized, they are handled by the following:
For example,
Patch is 35.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159387.diff 20 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 41ff54f0781a2..e7886537379bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,9 +810,13 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
- /// Return true if the target supports masked store.
+ /// Return true if the target supports masked store. A value of false for
+ /// IsMaskConstant indicates that the mask could either be variable or
+ /// constant. This is for targets that only support masked store with a
+ /// constant mask.
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ unsigned AddressSpace,
+ bool IsMaskConstant = false) const;
/// Return true if the target supports masked load.
LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 566e1cf51631a..33705e1dd5f98 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 09b50c5270e57..838712e55d0dd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace, bool IsMaskConstant) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override {
+ unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
}
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..dc6b631d33451 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+ raw_ostream &O,
+ const char *Modifier) {
+ const MCOperand &Op = MI->getOperand(OpNum);
+ if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+ O << "_";
+ else
+ printOperand(MI, OpNum, O);
+}
+
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..6810b6008d8cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
return Or;
}
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Val = N->getOperand(1);
+ SDValue BasePtr = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Mask = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ValVT = Val.getValueType();
+ MemSDNode *MemSD = cast<MemSDNode>(N);
+ assert(ValVT.isVector() && "Masked vector store must have vector type");
+ assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+ "Unexpected alignment for masked store");
+
+ unsigned Opcode = 0;
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unexpected masked vector store type");
+ case MVT::v4i64:
+ case MVT::v4f64: {
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ }
+ case MVT::v8i32:
+ case MVT::v8f32: {
+ Opcode = NVPTXISD::StoreV8;
+ break;
+ }
+ }
+
+ SmallVector<SDValue, 8> Ops;
+
+ // Construct the new SDNode. First operand is the chain.
+ Ops.push_back(Chain);
+
+ // The next N operands are the values to store. Encode the mask into the
+ // values using the sentinel register 0 to represent a masked-off element.
+ assert(Mask.getValueType().isVector() &&
+ Mask.getValueType().getVectorElementType() == MVT::i1 &&
+ "Mask must be a vector of i1");
+ assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+ "Mask expected to be a BUILD_VECTOR");
+ assert(Mask.getValueType().getVectorNumElements() ==
+ ValVT.getVectorNumElements() &&
+ "Mask size must be the same as the vector size");
+ for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+ assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+ "Mask elements must be constants");
+ if (Mask->getConstantOperandVal(I) == 0) {
+ // Append a sentinel register 0 to the Ops vector to represent a masked
+ // off element, this will be handled in tablegen
+ Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+ ValVT.getVectorElementType()));
+ } else {
+ // Extract the element from the vector to store
+ SDValue ExtVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+ Val, DAG.getIntPtrConstant(I, DL));
+ Ops.push_back(ExtVal);
+ }
+ }
+
+ // Next, the pointer operand.
+ Ops.push_back(BasePtr);
+
+ // Finally, the offset operand. We expect this to always be undef, and it will
+ // be ignored in lowering, but to mirror the handling of the other vector
+ // store instructions we include it in the new SDNode.
+ assert(Offset.getOpcode() == ISD::UNDEF &&
+ "Offset operand expected to be undef");
+ Ops.push_back(Offset);
+
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+ return NewSt;
+}
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE: {
+ assert(STI.has256BitVectorLoadStore(
+ cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+ "Masked store vector not supported on subtarget.");
+ return lowerMSTORE(Op, DAG);
+ }
case ISD::LOAD:
return LowerLOAD(Op, DAG);
case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..a8d6ff60c9b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1500,6 +1500,10 @@ def ADDR : Operand<pAny> {
let MIOperandInfo = (ops ADDR_base, i32imm);
}
+def RegOrSink : Operand<Any> {
+ let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
def AtomicCode : Operand<i32> {
let PrintMethod = "printAtomicCode";
}
@@ -1806,7 +1810,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
"\t[$addr], {{$src1, $src2}};">;
def _v4 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1814,8 +1818,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
if support_v8 then
def _v8 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
- O:$src5, O:$src6, O:$src7, O:$src8,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+ RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f4f89613b358d..88b13cb38d67b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -597,6 +597,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const {
+
+ if (!IsMaskConstant)
+ return false;
+
+ // We currently only support this feature for 256-bit vectors, so the
+ // alignment must be at least 32
+ if (Alignment < 32)
+ return false;
+
+ if (!ST->has256BitVectorLoadStore(AddrSpace))
+ return false;
+
+ auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+ if (!VTy)
+ return false;
+
+ auto *ScalarTy = VTy->getScalarType();
+ if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+ (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+ return true;
+
+ return false;
+}
+
unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
// 256 bit loads/stores are currently only supported for global address space
if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b32d931bd3074..9e5500966fe10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const override;
+
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 47e0a250d285a..80f10eb29bca4 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant = false) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 42d6680c3cb7d..412c1b04cdf3a 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1137,7 +1137,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getArgOperand(0)->getType(),
cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
cast<PointerType>(CI->getArgOperand(1)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(3))))
return false;
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<9>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT: ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT: ld.param.b8 %rs5, [global_var...
[truncated]
|
@llvm/pr-subscribers-backend-nvptx Author: Drew Kersnar (dakersnar) ChangesThis backend support will allow the LoadStoreVectorizer, in certain cases, to fill in gaps when creating store vectors and generate LLVM masked stores (https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics), which get lowered to PTX using the new sink symbol syntax (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st). To accomplish this, changes are separated into two parts. This first part has the backend lowering and TTI changes, and a follow up PR will have the LSV generate these intrinsics: [INSERT] TTI changes are needed because NVPTX only supports masked stores with constant masks. If the masked stores make it to the NVPTX backend without being scalarized, they are handled by the following:
For example,
Patch is 35.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159387.diff 20 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 41ff54f0781a2..e7886537379bc 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -810,9 +810,13 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
- /// Return true if the target supports masked store.
+ /// Return true if the target supports masked store. A value of false for
+ /// IsMaskConstant indicates that the mask could either be variable or
+ /// constant. This is for targets that only support masked store with a
+ /// constant mask.
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const;
+ unsigned AddressSpace,
+ bool IsMaskConstant = false) const;
/// Return true if the target supports masked load.
LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const;
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 566e1cf51631a..33705e1dd5f98 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
}
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
return false;
}
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 09b50c5270e57..838712e55d0dd 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const {
- return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+ unsigned AddressSpace, bool IsMaskConstant) const {
+ return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
}
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index fe2e849258e3f..e40631d88748c 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 0810c5532ed91..ee4f72552d90d 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const override {
+ unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
}
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index 171e2949366ad..c989bf77a9d51 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
- unsigned /*AddressSpace*/) const {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index dbf16c99c314c..e2674bb9cdad7 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09935330..dc6b631d33451 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
+void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
+ raw_ostream &O,
+ const char *Modifier) {
+ const MCOperand &Op = MI->getOperand(OpNum);
+ if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
+ O << "_";
+ else
+ printOperand(MI, OpNum, O);
+}
+
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 92155b01464e8..d373668aa591f 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..6810b6008d8cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
for (MVT VT : MVT::fixedlen_vector_valuetypes())
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
- setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
+ setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
// Custom legalization for LDU intrinsics.
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
return Or;
}
+static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Val = N->getOperand(1);
+ SDValue BasePtr = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Mask = N->getOperand(4);
+
+ SDLoc DL(N);
+ EVT ValVT = Val.getValueType();
+ MemSDNode *MemSD = cast<MemSDNode>(N);
+ assert(ValVT.isVector() && "Masked vector store must have vector type");
+ assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
+ "Unexpected alignment for masked store");
+
+ unsigned Opcode = 0;
+ switch (ValVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Unexpected masked vector store type");
+ case MVT::v4i64:
+ case MVT::v4f64: {
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ }
+ case MVT::v8i32:
+ case MVT::v8f32: {
+ Opcode = NVPTXISD::StoreV8;
+ break;
+ }
+ }
+
+ SmallVector<SDValue, 8> Ops;
+
+ // Construct the new SDNode. First operand is the chain.
+ Ops.push_back(Chain);
+
+ // The next N operands are the values to store. Encode the mask into the
+ // values using the sentinel register 0 to represent a masked-off element.
+ assert(Mask.getValueType().isVector() &&
+ Mask.getValueType().getVectorElementType() == MVT::i1 &&
+ "Mask must be a vector of i1");
+ assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
+ "Mask expected to be a BUILD_VECTOR");
+ assert(Mask.getValueType().getVectorNumElements() ==
+ ValVT.getVectorNumElements() &&
+ "Mask size must be the same as the vector size");
+ for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
+ assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+ "Mask elements must be constants");
+ if (Mask->getConstantOperandVal(I) == 0) {
+ // Append a sentinel register 0 to the Ops vector to represent a masked
+ // off element, this will be handled in tablegen
+ Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+ ValVT.getVectorElementType()));
+ } else {
+ // Extract the element from the vector to store
+ SDValue ExtVal =
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType(),
+ Val, DAG.getIntPtrConstant(I, DL));
+ Ops.push_back(ExtVal);
+ }
+ }
+
+ // Next, the pointer operand.
+ Ops.push_back(BasePtr);
+
+ // Finally, the offset operand. We expect this to always be undef, and it will
+ // be ignored in lowering, but to mirror the handling of the other vector
+ // store instructions we include it in the new SDNode.
+ assert(Offset.getOpcode() == ISD::UNDEF &&
+ "Offset operand expected to be undef");
+ Ops.push_back(Offset);
+
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+ return NewSt;
+}
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE: {
+ assert(STI.has256BitVectorLoadStore(
+ cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+ "Masked store vector not supported on subtarget.");
+ return lowerMSTORE(Op, DAG);
+ }
case ISD::LOAD:
return LowerLOAD(Op, DAG);
case ISD::SHL_PARTS:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..a8d6ff60c9b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1500,6 +1500,10 @@ def ADDR : Operand<pAny> {
let MIOperandInfo = (ops ADDR_base, i32imm);
}
+def RegOrSink : Operand<Any> {
+ let PrintMethod = "printRegisterOrSinkSymbol";
+}
+
def AtomicCode : Operand<i32> {
let PrintMethod = "printAtomicCode";
}
@@ -1806,7 +1810,7 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
"\t[$addr], {{$src1, $src2}};">;
def _v4 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
@@ -1814,8 +1818,8 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
if support_v8 then
def _v8 : NVPTXInst<
(outs),
- (ins O:$src1, O:$src2, O:$src3, O:$src4,
- O:$src5, O:$src6, O:$src7, O:$src8,
+ (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4,
+ RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8,
AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth,
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f4f89613b358d..88b13cb38d67b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -597,6 +597,32 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}
+bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const {
+
+ if (!IsMaskConstant)
+ return false;
+
+ // We currently only support this feature for 256-bit vectors, so the
+ // alignment must be at least 32
+ if (Alignment < 32)
+ return false;
+
+ if (!ST->has256BitVectorLoadStore(AddrSpace))
+ return false;
+
+ auto *VTy = dyn_cast<FixedVectorType>(DataTy);
+ if (!VTy)
+ return false;
+
+ auto *ScalarTy = VTy->getScalarType();
+ if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
+ (ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
+ return true;
+
+ return false;
+}
+
unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const {
// 256 bit loads/stores are currently only supported for global address space
if (ST->has256BitVectorLoadStore(AddrSpace))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b32d931bd3074..9e5500966fe10 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -181,6 +181,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
Intrinsic::ID IID) const override;
+ bool isLegalMaskedStore(Type *DataType, Align Alignment,
+ unsigned AddrSpace, bool IsMaskConstant) const override;
+
unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override;
Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 47e0a250d285a..80f10eb29bca4 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -287,7 +287,7 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index 5c0ddca62c761..4971d9148b747 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -139,7 +139,7 @@ class VETTIImpl final : public BasicTTIImplBase<VETTIImpl> {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned /*AddressSpace*/) const override {
+ unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
return isVectorLaneType(*getLaneType(DataType));
}
bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 3d8d0a236a3c1..b16a2a593df03 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -6330,7 +6330,7 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
}
bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
- unsigned AddressSpace) const {
+ unsigned AddressSpace, bool IsMaskConstant) const {
Type *ScalarTy = DataTy->getScalarType();
// The backend can't handle a single element vector w/o CFCMOV.
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 133b3668a46c8..7f6ff65d427ed 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -271,7 +271,7 @@ class X86TTIImpl final : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedStore(Type *DataType, Align Alignment,
- unsigned AddressSpace) const override;
+ unsigned AddressSpace, bool IsMaskConstant = false) const override;
bool isLegalNTLoad(Type *DataType, Align Alignment) const override;
bool isLegalNTStore(Type *DataType, Align Alignment) const override;
bool isLegalBroadcastLoad(Type *ElementTy,
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index 42d6680c3cb7d..412c1b04cdf3a 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -1137,7 +1137,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getArgOperand(0)->getType(),
cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
cast<PointerType>(CI->getArgOperand(1)->getType())
- ->getAddressSpace()))
+ ->getAddressSpace(),
+ isConstantIntVector(CI->getArgOperand(3))))
return false;
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..7d8f65b25bb02
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<9>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT: ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT: ld.param.b8 %rs5, [global_var...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment, | ||
unsigned AddressSpace) const; | ||
unsigned AddressSpace, | ||
bool IsMaskConstant = false) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, if the purpose of the check is to check validity of X, we should pass X itself. Passing derived into instead of the actual object we're checking makes everyone do potentially unnecessary checks and limits what property of X we get to check. While a bool indicating contness of the masked store is sufficient for us here and now, it's not the best choice for a generic API that potentially needs to work for different architectures with different requirements.
; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0]; | ||
; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1]; | ||
; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1]; | ||
; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any downsides of issuing a masked store with only one element enabled?
A potential future optimization would be to extend our load/store vectorization on lowering to take advantage of the masked stores. Right now we're looking for contiguous stores, but if we can find a sequence of properly aligned disjoint stores, that would be an easy win. Not sure though if higher level LLVM optimizations would leave us many opportunities, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I believe that there are some potential register pressure issues that can arise if we are too aggressive with generating masked stores. That's why I tuned the heuristic in the LoadStoreVectorizer in the way that I did, to only fill gaps of 1-2 elements: #159388
Right now we're looking for contiguous stores, but if we can find a sequence of properly aligned disjoint stores, that would be an easy win. Not sure though if higher level LLVM optimizations would leave us many opportunities, though.
Is this the same as the LSV changes linked above, or were you imagining something in else?
This backend support will allow the LoadStoreVectorizer, in certain cases, to fill in gaps when creating store vectors and generate LLVM masked stores (https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics), which get lowered to PTX using the new sink symbol syntax (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st). To accomplish this, changes are separated into two parts. This first part has the backend lowering and TTI changes, and a follow up PR will have the LSV generate these intrinsics: #159388.
TTI changes are needed because NVPTX only supports masked stores with constant masks.
ScalarizeMaskedMemIntrin.cpp
is adjusted to check that the mask is constant and pass that result into the TTI check. Behavior shouldn't change for non-NVPTX targets, which do not care whether the mask is variable or constant when determining legality.If the masked stores make it to the NVPTX backend without being scalarized, they are handled by the following:
NVPTXISelLowering.cpp
- Sets up a custom operation action and handles it in lowerMSTORE. Similar handling to normal store vectors, except we read the mask and place a sentinel register$noreg
in each position where the mask reads as false.For example,
NVPTXInstInfo.td
- changes the definition of store vectors to allow for a mix of sink symbols and registers.NVPXInstPrinter.h/.cpp
- Handles the$noreg
case by printing "_".