Skip to content

Commit 7e297e6

Browse files
[RISC-V] Utilize Zba extension instructions (#113999)
* [RISC-V] Group arrRef with index * [RISC-V] Add Zba instructions * [RISC-V] Create Shxadd GenTree * [RISC-V] Lower ADD(LSH) node to SH(X)ADD(.UW) node * [RISC-V] Utilize SH(X)ADD instruction for GT_INDEX_ADDR * [RISC-V] Fix build error: correct format & add preprocessor directives * [RISC-V] Update conditions for transforming ADD(LSH) into SHXADD * [RISC-V] Update GT_SHXADD* register liveliness * [RISC-V] Guard SHXADD instruction usage with extension check * [RISC-V] Add description comments to SHXADD node and struct * [RISC-V] Add more JIT dumps * [RISC-V] Remove GenTreeShxadd and create separate nodes (SH1ADD, SH1ADD_UW, etc.) * [RISC-V] Only use SH(X)ADD when ADD is expected, not ADDW. * [RISC-V] Support add.uw instruction. * [RISC-V] Utilize ADD.UW for zero extension * [RISC-V] Support slli.uw instruction * [RISC-V] Refactor * [RISC-V] Fix missed optimization: contain slli.uw into sh(x)add.uw * [RISC-V] Remove repeated directive condition Co-authored-by: Bruce Forstall <[email protected]> * [RISC-V] Update comment to reflect changes --------- Co-authored-by: Bruce Forstall <[email protected]>
1 parent 1e3c798 commit 7e297e6

11 files changed

+637
-31
lines changed

src/coreclr/jit/codegen.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,13 @@ class CodeGen final : public CodeGenInterface
841841
int scale RISCV64_ARG(regNumber scaleTempReg));
842842
#endif // TARGET_ARMARCH || TARGET_LOONGARCH64 || TARGET_RISCV64
843843

844+
#if defined(TARGET_RISCV64)
845+
void genCodeForShxadd(GenTreeOp* tree);
846+
void genCodeForAddUw(GenTreeOp* tree);
847+
void genCodeForSlliUw(GenTreeOp* tree);
848+
instruction getShxaddVariant(int scale, bool useUnsignedVariant);
849+
#endif
850+
844851
#if defined(TARGET_ARMARCH)
845852
void genCodeForMulLong(GenTreeOp* mul);
846853
#endif // TARGET_ARMARCH

src/coreclr/jit/codegenriscv64.cpp

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2775,6 +2775,30 @@ instruction CodeGen::genGetInsForOper(GenTree* treeNode)
27752775
}
27762776
break;
27772777

2778+
case GT_SH1ADD:
2779+
ins = INS_sh1add;
2780+
break;
2781+
2782+
case GT_SH2ADD:
2783+
ins = INS_sh2add;
2784+
break;
2785+
2786+
case GT_SH3ADD:
2787+
ins = INS_sh3add;
2788+
break;
2789+
2790+
case GT_SH1ADD_UW:
2791+
ins = INS_sh1add_uw;
2792+
break;
2793+
2794+
case GT_SH2ADD_UW:
2795+
ins = INS_sh2add_uw;
2796+
break;
2797+
2798+
case GT_SH3ADD_UW:
2799+
ins = INS_sh3add_uw;
2800+
break;
2801+
27782802
case GT_XOR_NOT:
27792803
assert(compiler->compOpportunisticallyDependsOn(InstructionSet_Zbb));
27802804
assert(!isImmed(treeNode));
@@ -4584,6 +4608,23 @@ void CodeGen::genCodeForTreeNode(GenTree* treeNode)
45844608
// Do nothing; these nodes are simply markers for debug info.
45854609
break;
45864610

4611+
case GT_SH1ADD:
4612+
case GT_SH1ADD_UW:
4613+
case GT_SH2ADD:
4614+
case GT_SH2ADD_UW:
4615+
case GT_SH3ADD:
4616+
case GT_SH3ADD_UW:
4617+
genCodeForShxadd(treeNode->AsOp());
4618+
break;
4619+
4620+
case GT_ADD_UW:
4621+
genCodeForAddUw(treeNode->AsOp());
4622+
break;
4623+
4624+
case GT_SLLI_UW:
4625+
genCodeForSlliUw(treeNode->AsOp());
4626+
break;
4627+
45874628
default:
45884629
{
45894630
#ifdef DEBUG
@@ -5624,7 +5665,16 @@ void CodeGen::genCodeForIndexAddr(GenTreeIndexAddr* node)
56245665
// dest = base + (index << scale)
56255666
if (node->gtElemSize <= 64)
56265667
{
5627-
genScaledAdd(attr, node->GetRegNum(), base->GetRegNum(), index->GetRegNum(), scale, tempReg);
5668+
instruction shxaddIns = getShxaddVariant(scale, (genTypeSize(index) == 4));
5669+
5670+
if (compiler->compOpportunisticallyDependsOn(InstructionSet_Zba) && (shxaddIns != INS_none))
5671+
{
5672+
GetEmitter()->emitIns_R_R_R(shxaddIns, attr, node->GetRegNum(), index->GetRegNum(), base->GetRegNum());
5673+
}
5674+
else
5675+
{
5676+
genScaledAdd(attr, node->GetRegNum(), base->GetRegNum(), index->GetRegNum(), scale, tempReg);
5677+
}
56285678
}
56295679
else
56305680
{
@@ -6447,9 +6497,15 @@ void CodeGen::genIntToIntCast(GenTreeCast* cast)
64476497
}
64486498

64496499
case GenIntCastDesc::ZERO_EXTEND_INT:
6450-
6451-
emit->emitIns_R_R_I(INS_slli, EA_PTRSIZE, dstReg, srcReg, 32);
6452-
emit->emitIns_R_R_I(INS_srli, EA_PTRSIZE, dstReg, dstReg, 32);
6500+
if (compiler->compOpportunisticallyDependsOn(InstructionSet_Zba))
6501+
{
6502+
emit->emitIns_R_R_R(INS_add_uw, EA_PTRSIZE, dstReg, srcReg, REG_R0);
6503+
}
6504+
else
6505+
{
6506+
emit->emitIns_R_R_I(INS_slli, EA_PTRSIZE, dstReg, srcReg, 32);
6507+
emit->emitIns_R_R_I(INS_srli, EA_PTRSIZE, dstReg, dstReg, 32);
6508+
}
64536509
break;
64546510
case GenIntCastDesc::SIGN_EXTEND_INT:
64556511
emit->emitIns_R_R_I(INS_slliw, EA_4BYTE, dstReg, srcReg, 0);
@@ -6737,6 +6793,83 @@ void CodeGen::genLeaInstruction(GenTreeAddrMode* lea)
67376793
genProduceReg(lea);
67386794
}
67396795

6796+
instruction CodeGen::getShxaddVariant(int scale, bool useUnsignedVariant)
6797+
{
6798+
if (useUnsignedVariant)
6799+
{
6800+
switch (scale)
6801+
{
6802+
case 1:
6803+
return INS_sh1add_uw;
6804+
case 2:
6805+
return INS_sh2add_uw;
6806+
case 3:
6807+
return INS_sh3add_uw;
6808+
}
6809+
}
6810+
else
6811+
{
6812+
switch (scale)
6813+
{
6814+
case 1:
6815+
return INS_sh1add;
6816+
case 2:
6817+
return INS_sh2add;
6818+
case 3:
6819+
return INS_sh3add;
6820+
}
6821+
}
6822+
return INS_none;
6823+
}
6824+
6825+
void CodeGen::genCodeForShxadd(GenTreeOp* tree)
6826+
{
6827+
instruction ins = genGetInsForOper(tree);
6828+
6829+
assert(ins == INS_sh1add || ins == INS_sh2add || ins == INS_sh3add || ins == INS_sh1add_uw ||
6830+
ins == INS_sh2add_uw || ins == INS_sh3add_uw);
6831+
6832+
genConsumeOperands(tree);
6833+
6834+
emitAttr attr = emitActualTypeSize(tree);
6835+
6836+
GetEmitter()->emitIns_R_R_R(ins, attr, tree->GetRegNum(), tree->gtOp1->GetRegNum(), tree->gtOp2->GetRegNum());
6837+
6838+
genProduceReg(tree);
6839+
}
6840+
6841+
void CodeGen::genCodeForAddUw(GenTreeOp* tree)
6842+
{
6843+
assert(tree->gtOper == GT_ADD_UW);
6844+
6845+
genConsumeOperands(tree);
6846+
6847+
emitAttr attr = emitActualTypeSize(tree);
6848+
6849+
GetEmitter()->emitIns_R_R_R(INS_add_uw, attr, tree->GetRegNum(), tree->gtOp1->GetRegNum(),
6850+
tree->gtOp2->GetRegNum());
6851+
6852+
genProduceReg(tree);
6853+
}
6854+
6855+
void CodeGen::genCodeForSlliUw(GenTreeOp* tree)
6856+
{
6857+
assert(tree->gtOper == GT_SLLI_UW);
6858+
6859+
genConsumeOperands(tree);
6860+
6861+
emitAttr attr = emitActualTypeSize(tree);
6862+
GenTree* shiftBy = tree->gtOp2;
6863+
6864+
assert(shiftBy->IsCnsIntOrI());
6865+
6866+
unsigned shamt = (unsigned)shiftBy->AsIntCon()->gtIconVal;
6867+
6868+
GetEmitter()->emitIns_R_R_I(INS_slli_uw, attr, tree->GetRegNum(), tree->gtOp1->GetRegNum(), shamt);
6869+
6870+
genProduceReg(tree);
6871+
}
6872+
67406873
//------------------------------------------------------------------------
67416874
// genEstablishFramePointer: Set up the frame pointer by adding an offset to the stack pointer.
67426875
//

src/coreclr/jit/emitriscv64.cpp

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ void emitter::emitIns_R_R_I(
728728

729729
if ((INS_addi <= ins && INS_srai >= ins) || (INS_addiw <= ins && INS_sraiw >= ins) ||
730730
(INS_lb <= ins && INS_lhu >= ins) || INS_ld == ins || INS_lw == ins || INS_jalr == ins || INS_fld == ins ||
731-
INS_flw == ins || INS_rori == ins || INS_roriw == ins)
731+
INS_flw == ins || INS_slli_uw == ins || INS_rori == ins || INS_roriw == ins)
732732
{
733733
assert(isGeneralRegister(reg2));
734734
code |= (reg1 & 0x1f) << 7; // rd
@@ -827,7 +827,7 @@ void emitter::emitIns_R_R_R(
827827
(INS_addw <= ins && ins <= INS_sraw) || (INS_fadd_s <= ins && ins <= INS_fmax_s) ||
828828
(INS_fadd_d <= ins && ins <= INS_fmax_d) || (INS_feq_s <= ins && ins <= INS_fle_s) ||
829829
(INS_feq_d <= ins && ins <= INS_fle_d) || (INS_lr_w <= ins && ins <= INS_amomaxu_d) ||
830-
(INS_rol <= ins && ins <= INS_maxu))
830+
(INS_sh1add <= ins && ins <= INS_sh3add_uw) || (INS_rol <= ins && ins <= INS_maxu))
831831
{
832832
#ifdef DEBUG
833833
switch (ins)
@@ -915,6 +915,14 @@ void emitter::emitIns_R_R_R(
915915
case INS_amomaxu_w:
916916
case INS_amomaxu_d:
917917

918+
case INS_sh1add:
919+
case INS_sh2add:
920+
case INS_sh3add:
921+
case INS_add_uw:
922+
case INS_sh1add_uw:
923+
case INS_sh2add_uw:
924+
case INS_sh3add_uw:
925+
918926
case INS_rol:
919927
case INS_rolw:
920928
case INS_ror:
@@ -3978,28 +3986,36 @@ void emitter::emitDispInsName(
39783986
emitDispImmediate(imm12, !willPrintLoadImmValue);
39793987
}
39803988
return;
3981-
case 0x1:
3989+
case 0x1: // SLLIW, SLLI.UW, CLZW, CTZW, & CPOPW
39823990
{
3991+
static constexpr unsigned kSlliwFunct7 = 0b0000000;
3992+
static constexpr unsigned kSlliUwFunct6 = 0b000010;
3993+
39833994
unsigned funct7 = (imm12 >> 5) & 0x7f;
3984-
unsigned shamt = imm12 & 0x1f; // 5 BITS for SHAMT in RISCV64
3985-
switch (funct7)
3995+
unsigned funct6 = (imm12 >> 6) & 0x3f;
3996+
// SLLIW's instruction code's upper 7 bits have to be equal to zero
3997+
if (funct7 == kSlliwFunct7)
39863998
{
3987-
case 0b0110000:
3988-
{
3989-
static const char* names[] = {"clzw ", "ctzw ", "cpopw"};
3990-
// shift amount is treated as funct additional opcode bits
3991-
if (shamt >= ARRAY_SIZE(names))
3992-
return emitDispIllegalInstruction(code);
3993-
3994-
printf("%s %s, %s\n", names[shamt], rd, rs1);
3995-
return;
3996-
}
3997-
case 0b0000000:
3998-
printf("slliw %s, %s, %d\n", rd, rs1, shamt);
3999-
return;
4000-
4001-
default:
3999+
printf("slliw %s, %s, %d\n", rd, rs1, imm12 & 0x1f); // 5 BITS for SHAMT in RISCV64
4000+
}
4001+
// SLLI.UW's instruction code's upper 6 bits have to be equal to 0b000010
4002+
else if (funct6 == kSlliUwFunct6)
4003+
{
4004+
printf("slli.uw %s, %s, %d\n", rd, rs1, imm12 & 0x3f); // 6 BITS for SHAMT in RISCV64
4005+
}
4006+
else if (funct7 == 0b0110000)
4007+
{
4008+
static const char* names[] = {"clzw ", "ctzw ", "cpopw"};
4009+
// shift amount is treated as funct additional opcode bits
4010+
unsigned shamt = imm12 & 0x1f; // 5 BITS for SHAMT in RISCV64
4011+
if (shamt >= ARRAY_SIZE(names))
40024012
return emitDispIllegalInstruction(code);
4013+
4014+
printf("%s %s, %s\n", names[shamt], rd, rs1);
4015+
}
4016+
else
4017+
{
4018+
emitDispIllegalInstruction(code);
40034019
}
40044020
}
40054021
return;
@@ -4121,6 +4137,20 @@ void emitter::emitDispInsName(
41214137
return emitDispIllegalInstruction(code);
41224138
}
41234139
return;
4140+
case 0b0010000:
4141+
switch (opcode3)
4142+
{
4143+
case 0x2: // SH1ADD
4144+
printf("sh1add %s, %s, %s\n", rd, rs1, rs2);
4145+
return;
4146+
case 0x4: // SH2ADD
4147+
printf("sh2add %s, %s, %s\n", rd, rs1, rs2);
4148+
return;
4149+
case 0x6: // SH3ADD
4150+
printf("sh3add %s, %s, %s\n", rd, rs1, rs2);
4151+
return;
4152+
}
4153+
return;
41244154
case 0b0110000:
41254155
switch (opcode3)
41264156
{
@@ -4209,6 +4239,22 @@ void emitter::emitDispInsName(
42094239
return emitDispIllegalInstruction(code);
42104240
}
42114241
return;
4242+
case 0b0010000:
4243+
switch (opcode3)
4244+
{
4245+
case 0x2: // SH1ADD.UW
4246+
printf("sh1add.uw %s, %s, %s\n", rd, rs1, rs2);
4247+
return;
4248+
case 0x4: // SH2ADD.UW
4249+
printf("sh2add.uw %s, %s, %s\n", rd, rs1, rs2);
4250+
return;
4251+
case 0x6: // SH3ADD.UW
4252+
printf("sh3add.uw %s, %s, %s\n", rd, rs1, rs2);
4253+
return;
4254+
default:
4255+
return emitDispIllegalInstruction(code);
4256+
}
4257+
return;
42124258
case 0b0110000:
42134259
switch (opcode3)
42144260
{
@@ -4223,12 +4269,28 @@ void emitter::emitDispInsName(
42234269
}
42244270
return;
42254271
case 0b0000100:
4226-
// Currently only zext.h for this opcode2.
4227-
// Note: zext.h is encoded as a pseudo for 'packw rd, rs1, zero' which is not in Zbb.
4228-
if (opcode3 != 0b100 || rs2Num != REG_ZERO)
4229-
return emitDispIllegalInstruction(code);
4272+
switch (opcode3)
4273+
{
4274+
case 0b000: // ZEXT.W & ADD.UW
4275+
if (rs2Num == REG_ZERO)
4276+
{
4277+
printf("zext.w %s, %s\n", rd, rs1);
4278+
}
4279+
else
4280+
{
4281+
printf("add.uw %s, %s, %s\n", rd, rs1, rs2);
4282+
}
4283+
return;
4284+
case 0b100: // ZEXT.H
4285+
// Note: zext.h is encoded as a pseudo for 'packw rd, rs1, zero' which is not in Zbb.
4286+
if (rs2Num != REG_ZERO)
4287+
return emitDispIllegalInstruction(code);
42304288

4231-
printf("zext.h %s, %s\n", rd, rs1);
4289+
printf("zext.h %s, %s\n", rd, rs1);
4290+
return;
4291+
default:
4292+
return emitDispIllegalInstruction(code);
4293+
}
42324294
return;
42334295

42344296
default:

src/coreclr/jit/gentree.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6526,6 +6526,9 @@ unsigned GenTree::GetScaledIndex()
65266526
case GT_MUL:
65276527
return AsOp()->gtOp2->GetScaleIndexMul();
65286528

6529+
#ifdef TARGET_RISCV64
6530+
case GT_SLLI_UW:
6531+
#endif
65296532
case GT_LSH:
65306533
return AsOp()->gtOp2->GetScaleIndexShf();
65316534

@@ -12839,7 +12842,6 @@ void Compiler::gtDispTree(GenTree* tree,
1283912842
InsCflagsToString(tree->AsCCMP()->gtFlagsVal));
1284012843
}
1284112844
#endif
12842-
1284312845
gtDispCommonEndLine(tree);
1284412846

1284512847
if (!topOnly)

0 commit comments

Comments
 (0)