@@ -7193,15 +7193,19 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
 }
 
 // Recurse to find a LoadSDNode source and the accumulated ByteOffest.
-static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
-  if (ISD::isNON_EXTLoad(Elt.getNode())) {
-    auto *BaseLd = cast<LoadSDNode>(Elt);
-    if (!BaseLd->isSimple())
-      return false;
+static bool findEltLoadSrc(SDValue Elt, MemSDNode *&Ld, int64_t &ByteOffset) {
+  if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
     Ld = BaseLd;
     ByteOffset = 0;
     return true;
-  }
+  } else if (auto *BaseLd = dyn_cast<LoadSDNode>(Elt))
+    if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      if (!BaseLd->isSimple())
+        return false;
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
 
   switch (Elt.getOpcode()) {
   case ISD::BITCAST:
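
Aside: a minimal standalone sketch (not part of the patch, and not the real LLVM class hierarchy) of the idea behind widening the out-parameter from LoadSDNode* to the common MemSDNode base: the same recursion can now report either a plain load or an atomic load. MemNode, LoadNode, AtomicNode and findLoadSrc below are hypothetical stand-ins; LLVM dispatches with dyn_cast<> rather than dynamic_cast.

#include <cstdio>

struct MemNode { virtual ~MemNode() = default; }; // stand-in for MemSDNode
struct LoadNode : MemNode { bool Simple = true; }; // stand-in for LoadSDNode
struct AtomicNode : MemNode {};                    // stand-in for AtomicSDNode

// Mirrors the shape of the new entry checks: atomic loads are accepted
// unconditionally, plain loads only when they are simple loads.
static bool findLoadSrc(MemNode *Elt, MemNode *&Ld) {
  if (auto *A = dynamic_cast<AtomicNode *>(Elt)) {
    Ld = A;
    return true;
  }
  if (auto *L = dynamic_cast<LoadNode *>(Elt)) {
    if (!L->Simple)
      return false;
    Ld = L;
    return true;
  }
  return false;
}

int main() {
  AtomicNode A;
  LoadNode L;
  MemNode *Found = nullptr;
  std::printf("atomic accepted: %d\n", findLoadSrc(&A, Found));      // prints 1
  std::printf("simple load accepted: %d\n", findLoadSrc(&L, Found)); // prints 1
}
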
@@ -7230,6 +7234,20 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
       }
     }
     break;
+  case ISD::EXTRACT_ELEMENT:
+    if (auto *IdxC = dyn_cast<ConstantSDNode>(Elt.getOperand(1))) {
+      SDValue Src = Elt.getOperand(0);
+      unsigned SrcSizeInBits = Src.getScalarValueSizeInBits();
+      unsigned DstSizeInBits = Elt.getScalarValueSizeInBits();
+      if (2 * DstSizeInBits == SrcSizeInBits && (SrcSizeInBits % 8) == 0 &&
+          findEltLoadSrc(Src, Ld, ByteOffset)) {
+        uint64_t Idx = IdxC->getZExtValue();
+        if (Idx == 1) // Get the upper half.
+          ByteOffset += SrcSizeInBits / 8 / 2;
+        return true;
+      }
+    }
+    break;
   }
 
   return false;
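
Aside: the new ISD::EXTRACT_ELEMENT case only accepts sources exactly twice the destination width, so the offset arithmetic reduces to "the upper half adds half the source's byte size". A minimal standalone sketch of that arithmetic (plain integers, hypothetical addHalfOffset helper; not LLVM code):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Idx selects the half of a source that is exactly twice the destination
// width: 0 = lower half (offset unchanged), 1 = upper half.
static int64_t addHalfOffset(int64_t ByteOffset, unsigned SrcSizeInBits,
                             uint64_t Idx) {
  assert(SrcSizeInBits % 8 == 0 && "source must be byte sized");
  if (Idx == 1) // Get the upper half.
    ByteOffset += SrcSizeInBits / 8 / 2;
  return ByteOffset;
}

int main() {
  // An i64 load split into two i32 elements (little-endian layout):
  std::printf("lower half offset: %lld\n", (long long)addHalfOffset(0, 64, 0)); // 0
  std::printf("upper half offset: %lld\n", (long long)addHalfOffset(0, 64, 1)); // 4
}
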
@@ -7254,7 +7272,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   APInt ZeroMask = APInt::getZero(NumElems);
   APInt UndefMask = APInt::getZero(NumElems);
 
-  SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
+  SmallVector<MemSDNode *, 8> Loads(NumElems, nullptr);
   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
 
   // For each element in the initializer, see if we've found a load, zero or an
@@ -7304,7 +7322,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   EVT EltBaseVT = EltBase.getValueType();
   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
          "Register/Memory size mismatch");
-  LoadSDNode *LDBase = Loads[FirstLoadedElt];
+  MemSDNode *LDBase = Loads[FirstLoadedElt];
   assert(LDBase && "Did not find base load for merging consecutive loads");
   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7318,8 +7336,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
 
   // Check to see if the element's load is consecutive to the base load
   // or offset from a previous (already checked) load.
-  auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
-    LoadSDNode *Ld = Loads[EltIdx];
+  auto CheckConsecutiveLoad = [&](MemSDNode *Base, int EltIdx) {
+    MemSDNode *Ld = Loads[EltIdx];
     int64_t ByteOffset = ByteOffsets[EltIdx];
     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
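
Aside: a small worked example of the BaseIdx arithmetic in the lambda above (standalone, not LLVM code): an element whose recorded ByteOffset is a whole multiple of the element size is checked against the earlier element that owns the same underlying load.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t BaseSizeInBytes = 4; // 32-bit vector elements
  const int EltIdx = 3;              // element currently being checked
  const int64_t ByteOffset = 8;      // offset recorded by findEltLoadSrc

  if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
    int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
    // Element 3 sits 8 bytes into the load already claimed by element 1.
    std::printf("element %d maps back to element %lld\n", EltIdx,
                (long long)BaseIdx);
  }
}
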
@@ -7347,7 +7365,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     }
   }
 
-  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
+  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, MemSDNode *LDBase) {
     auto MMOFlags = LDBase->getMemOperand()->getFlags();
     assert(LDBase->isSimple() &&
            "Cannot merge volatile or atomic loads.");
@@ -9452,8 +9470,9 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   {
     SmallVector<SDValue, 64> Ops(Op->ops().take_front(NumElems));
     if (SDValue LD =
-            EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
+            EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false)) {
       return LD;
+    }
   }
 
   // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -60381,6 +60400,35 @@ static SDValue combineINTRINSIC_VOID(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue combineVZEXT_LOAD(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // Find the TokenFactor to locate the associated AtomicLoad.
+  SDNode *ALD = nullptr;
+  for (auto &TF : DAG.allnodes())
+    if (TF.getOpcode() == ISD::TokenFactor) {
+      SDValue L = TF.getOperand(0);
+      SDValue R = TF.getOperand(1);
+      if (L.getNode() == N)
+        ALD = R.getNode();
+      else if (R.getNode() == N)
+        ALD = L.getNode();
+    }
+
+  if (!ALD)
+    return SDValue();
+  if (!isa<AtomicSDNode>(ALD))
+    return SDValue();
+
+  // Replace the VZEXT_LOAD with the AtomicLoad.
+  SDLoc dl(N);
+  SDValue SV = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl,
+                           N->getValueType(0).changeTypeToInteger(),
+                           SDValue(ALD, 0));
+  SDValue BC = DAG.getNode(ISD::BITCAST, dl, N->getValueType(0), SV);
+  BC = DCI.CombineTo(N, BC, SDValue(ALD, 1));
+  return BC;
+}
+
 SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
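
Aside: a minimal standalone model of the partner search in combineVZEXT_LOAD above (Node and findTokenFactorPartner are hypothetical stand-ins, not LLVM code): the combine scans all nodes for a two-operand TokenFactor, takes whichever operand is not the VZEXT_LOAD as the candidate atomic load, and then replaces the VZEXT_LOAD's value and chain results with the atomic load's.

#include <cstdio>
#include <vector>

struct Node {
  const char *Name;
  bool IsTokenFactor;
  Node *Ops[2];
};

// Stand-in for the loop over DAG.allnodes(): return the TokenFactor operand
// paired with N, or nullptr if no two-operand TokenFactor uses N.
static Node *findTokenFactorPartner(const std::vector<Node *> &AllNodes,
                                    Node *N) {
  Node *Partner = nullptr;
  for (Node *TF : AllNodes)
    if (TF->IsTokenFactor) {
      if (TF->Ops[0] == N)
        Partner = TF->Ops[1];
      else if (TF->Ops[1] == N)
        Partner = TF->Ops[0];
    }
  return Partner;
}

int main() {
  Node VZextLoad{"vzext_load", false, {nullptr, nullptr}};
  Node AtomicLoad{"atomic_load", false, {nullptr, nullptr}};
  Node TF{"TokenFactor", true, {&VZextLoad, &AtomicLoad}};
  std::vector<Node *> AllNodes{&VZextLoad, &AtomicLoad, &TF};

  Node *P = findTokenFactorPartner(AllNodes, &VZextLoad);
  std::printf("partner: %s\n", P ? P->Name : "none"); // prints: atomic_load
}
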
@@ -60577,6 +60625,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
+  case X86ISD::VZEXT_LOAD: return combineVZEXT_LOAD(N, DAG, DCI);
   // clang-format on
   }
 