Skip to content

Commit 72d7f6f

Browse files
PawelJurekigcbot
authored andcommitted
Handle bfloat fneg instruction
Modifier is not supported with :bf type in vISA and HW. This change adds legalization for this case.
1 parent eddfc14 commit 72d7f6f

11 files changed

+58
-37
lines changed

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

+13-13
Original file line numberDiff line numberDiff line change
@@ -3015,11 +3015,11 @@ namespace IGC
30153015
if (cmpInst->getOperand(0)->getType()->getPrimitiveSizeInBits() == I.getType()->getPrimitiveSizeInBits())
30163016
{
30173017
CmpSextPattern* pattern = new (m_allocator) CmpSextPattern();
3018-
bool supportModifer = SupportsModifier(cmpInst);
3018+
bool supportModifier = SupportsModifier(cmpInst, m_Platform);
30193019

30203020
pattern->inst = cmpInst;
3021-
pattern->sources[0] = GetSource(cmpInst->getOperand(0), supportModifer, false, IsSourceOfSample(&I));
3022-
pattern->sources[1] = GetSource(cmpInst->getOperand(1), supportModifer, false, IsSourceOfSample(&I));
3021+
pattern->sources[0] = GetSource(cmpInst->getOperand(0), supportModifier, false, IsSourceOfSample(&I));
3022+
pattern->sources[1] = GetSource(cmpInst->getOperand(1), supportModifier, false, IsSourceOfSample(&I));
30233023
AddPattern(pattern);
30243024
match = true;
30253025
}
@@ -3362,21 +3362,21 @@ namespace IGC
33623362
return SkipCanonicalize(src);
33633363
});
33643364
}
3365-
bool supportModiferSrc0 = SupportsModifier(&I);
3365+
bool supportModifierSrc0 = SupportsModifier(&I, m_Platform);
33663366
bool supportRegioning = SupportsRegioning(&I);
33673367
llvm::Instruction* src0Inst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(0));
33683368
if (I.getOpcode() == llvm::Instruction::UDiv && src0Inst && src0Inst->getOpcode() == llvm::Instruction::Sub) {
3369-
supportModiferSrc0 = false;
3369+
supportModifierSrc0 = false;
33703370
}
3371-
pattern->sources[0] = GetSource(sources[0], supportModiferSrc0 && SupportSrc0Mod, supportRegioning, IsSourceOfSample(&I));
3371+
pattern->sources[0] = GetSource(sources[0], supportModifierSrc0 && SupportSrc0Mod, supportRegioning, IsSourceOfSample(&I));
33723372
if (nbSources > 1)
33733373
{
3374-
bool supportModiferSrc1 = SupportsModifier(&I);
3374+
bool supportModifierSrc1 = SupportsModifier(&I, m_Platform);
33753375
llvm::Instruction* src1Inst = llvm::dyn_cast<llvm::Instruction>(I.getOperand(1));
33763376
if (I.getOpcode() == llvm::Instruction::UDiv && src1Inst && src1Inst->getOpcode() == llvm::Instruction::Sub) {
3377-
supportModiferSrc1 = false;
3377+
supportModifierSrc1 = false;
33783378
}
3379-
pattern->sources[1] = GetSource(sources[1], supportModiferSrc1, supportRegioning, IsSourceOfSample(&I));
3379+
pattern->sources[1] = GetSource(sources[1], supportModifierSrc1, supportRegioning, IsSourceOfSample(&I));
33803380

33813381
// add df imm to constant pool for binary/ternary inst
33823382
// we do 64-bit int imm bigger than 32 bits, since smaller may fit in D/W
@@ -3953,7 +3953,7 @@ namespace IGC
39533953
{
39543954
GenericPointersCmpPattern* pattern = new (m_allocator) GenericPointersCmpPattern();
39553955

3956-
bool supportsMod = SupportsModifier(&I);
3956+
bool supportsMod = SupportsModifier(&I, m_Platform);
39573957
pattern->cmpSources[0] = GetSource(I.getOperand(0), supportsMod, false, IsSourceOfSample(&I));
39583958
pattern->cmpSources[1] = GetSource(I.getOperand(1), supportsMod, false, IsSourceOfSample(&I));
39593959
pattern->cmp = &I;
@@ -4540,9 +4540,9 @@ namespace IGC
45404540

45414541
CmpSelectPattern* pattern = new (m_allocator) CmpSelectPattern();
45424542
pattern->predicate = cmp->getPredicate();
4543-
bool supportsModifer = SupportsModifier(cmp);
4544-
pattern->cmpSources[0] = GetSource(cmp->getOperand(0), supportsModifer, false, IsSourceOfSample(&I));
4545-
pattern->cmpSources[1] = GetSource(cmp->getOperand(1), supportsModifer, false, IsSourceOfSample(&I));
4543+
bool supportsModifier = SupportsModifier(cmp, m_Platform);
4544+
pattern->cmpSources[0] = GetSource(cmp->getOperand(0), supportsModifier, false, IsSourceOfSample(&I));
4545+
pattern->cmpSources[1] = GetSource(cmp->getOperand(1), supportsModifier, false, IsSourceOfSample(&I));
45464546

45474547
pattern->bfnSources[1] = GetSource(selSources[0], false, false, IsSourceOfSample(&I));
45484548
pattern->bfnSources[2] = GetSource(selSources[1], false, false, IsSourceOfSample(&I));

IGC/Compiler/CISACodeGen/Platform.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -1613,6 +1613,7 @@ bool supportTriggerLargeGRFRetry() const
16131613
}
16141614

16151615

1616+
16161617
bool EnableCSWalkerPass() const
16171618
{
16181619
return isCoreChildOf(IGFX_XE2_LPG_CORE);

IGC/Compiler/CISACodeGen/helper.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,7 @@ namespace IGC
16271627
#define DECLARE_OPCODE(instName, llvmType, name, modifiers, sat, pred, condMod, mathIntrinsic, atomicIntrinsic, regioning) \
16281628
case name:\
16291629
return modifiers;
1630-
bool SupportsModifier(llvm::Instruction* inst)
1630+
bool SupportsModifier(llvm::Instruction* inst, const IGC::CPlatform& platform)
16311631
{
16321632
// Special cases
16331633
switch (inst->getOpcode())
@@ -1647,6 +1647,9 @@ namespace IGC
16471647
default:
16481648
break;
16491649
}
1650+
if (IGCLLVM::isBFloatTy(inst->getType())) {
1651+
return false;
1652+
}
16501653

16511654
switch (GetOpCode(inst))
16521655
{

IGC/Compiler/CISACodeGen/helper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ namespace IGC
9696
#undef DECLARE_OPCODE
9797

9898
EOPCODE GetOpCode(const llvm::Instruction* inst);
99-
bool SupportsModifier(llvm::Instruction* inst);
99+
bool SupportsModifier(llvm::Instruction* inst, const CPlatform& platform);
100100
bool SupportsSaturate(llvm::Instruction* inst);
101101
bool SupportsPredicate(llvm::Instruction* inst);
102102
bool SupportsCondModifier(llvm::Instruction* inst);

IGC/Compiler/CustomSafeOptPass.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -3609,9 +3609,12 @@ void GenSpecificPattern::visitFNeg(llvm::UnaryOperator& I)
36093609
// and adds source modifier for this region/value.
36103610

36113611
IRBuilder<> builder(&I);
3612-
36133612
Value* fsub = nullptr;
36143613

3614+
if (IGCLLVM::isBFloatTy(I.getType())) {
3615+
return;
3616+
}
3617+
36153618
if (!I.getType()->isVectorTy())
36163619
{
36173620
fsub = builder.CreateFSub(ConstantFP::get(I.getType(), 0.0f), I.getOperand(0));

IGC/Compiler/HandleFRemInstructions.cpp

+5-7
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,14 @@ void HandleFRemInstructions::visitFRem(llvm::BinaryOperator& I)
5454
auto TypeWidth = ScalarType->getScalarSizeInBits();
5555
FpTypeStr = "f" + std::to_string(TypeWidth);
5656
}
57-
#if LLVM_VERSION_MAJOR >= 14
58-
else if (ScalarType->isBFloatTy())
57+
else if (IGCLLVM::isBFloatTy(ScalarType))
5958
{
6059
FpTypeStr = "f32";
6160
Type *FloatTy = Type::getFloatTy(I.getContext());
6261
ValType = ValType->isVectorTy()
63-
? VectorType::get(FloatTy, cast<VectorType>(ValType))
62+
? IGCLLVM::FixedVectorType::get(
63+
FloatTy,
64+
(unsigned)cast<IGCLLVM::FixedVectorType>(ValType)->getNumElements())
6465
: FloatTy;
6566

6667
auto Val1Float = new FPExtInst(Val1, ValType, "", &I);
@@ -70,7 +71,6 @@ void HandleFRemInstructions::visitFRem(llvm::BinaryOperator& I)
7071
Val1 = Val1Float;
7172
Val2 = Val2Float;
7273
}
73-
#endif
7474
else
7575
{
7676
IGC_ASSERT_MESSAGE(0, "Unsupported type");
@@ -96,13 +96,11 @@ void HandleFRemInstructions::visitFRem(llvm::BinaryOperator& I)
9696
auto Callee = m_module->getOrInsertFunction(FuncName, FT);
9797
SmallVector<Value*, 2> FuncArgs{ Val1, Val2 };
9898
Instruction* NewFRem = CallInst::Create(Callee, FuncArgs, "");
99-
#if LLVM_VERSION_MAJOR >= 14
100-
if (ScalarType->isBFloatTy()) {
99+
if (IGCLLVM::isBFloatTy(ScalarType)) {
101100
NewFRem->insertBefore(&I);
102101
NewFRem->setDebugLoc(I.getDebugLoc());
103102
NewFRem = new FPTruncInst(NewFRem, I.getOperand(0)->getType());
104103
}
105-
#endif
106104
ReplaceInstWithInst(&I, NewFRem);
107105
m_changed = true;
108106
}

IGC/Compiler/LegalizationPass.cpp

+23-10
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,26 @@ void Legalization::visitInstruction(llvm::Instruction& I)
177177
m_ctx->m_instrTypes.numLocalInsts++;
178178
}
179179

180+
#if LLVM_VERSION_MAJOR >= 14
181+
void Legalization::visitFNeg(llvm::UnaryOperator &I) {
182+
if (IGCLLVM::isBFloatTy(I.getType())) {
183+
m_builder->SetInsertPoint(&I);
184+
auto ExtendedOp = m_builder->CreateFPExt(
185+
I.getOperand(0), Type::getFloatTy(I.getContext()));
186+
auto FloatFneg = m_builder->CreateFNeg(ExtendedOp);
187+
auto Res = m_builder->CreateFPTrunc(FloatFneg, I.getType());
188+
189+
cast<Instruction>(ExtendedOp)->setDebugLoc(I.getDebugLoc());
190+
cast<Instruction>(FloatFneg)->setDebugLoc(I.getDebugLoc());
191+
cast<Instruction>(Res)->setDebugLoc(I.getDebugLoc());
192+
193+
I.replaceAllUsesWith(Res);
194+
m_instructionsToRemove.push_back(&I);
195+
}
196+
m_ctx->m_instrTypes.numInsts++;
197+
}
198+
#endif // LLVM_VERSION_MAJOR >= 14
199+
180200
void Legalization::visitBinaryOperator(llvm::BinaryOperator& I)
181201
{
182202
if (I.getOpcode() == Instruction::FRem)
@@ -2680,11 +2700,7 @@ static bool isCandidateFDiv(Instruction* Inst)
26802700
return false;
26812701

26822702
Type* Ty = Inst->getType();
2683-
if (!Ty->isFloatTy() && !Ty->isHalfTy()
2684-
#if LLVM_VERSION_MAJOR >= 14
2685-
&& !Ty->isBFloatTy()
2686-
#endif
2687-
)
2703+
if (!Ty->isFloatTy() && !Ty->isHalfTy() && !IGCLLVM::isBFloatTy(Ty))
26882704
return false;
26892705

26902706
auto Op = dyn_cast<FPMathOperator>(Inst);
@@ -2739,11 +2755,8 @@ bool IGC::expandFDIVInstructions(llvm::Function& F)
27392755
Value* Y = Inst->getOperand(1);
27402756
Value* V = nullptr;
27412757

2742-
if (Inst->getType()->isHalfTy()
2743-
#if LLVM_VERSION_MAJOR >= 14
2744-
|| Inst->getType()->isBFloatTy()
2745-
#endif
2746-
) {
2758+
if (Inst->getType()->isHalfTy() ||
2759+
IGCLLVM::isBFloatTy(Inst->getType())) {
27472760
if (Inst->hasAllowReciprocal()) {
27482761
APFloat Val(1.0f);
27492762
bool ignored;

IGC/Compiler/LegalizationPass.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ namespace IGC
8181
void visitBitCastInst(llvm::BitCastInst& I);
8282
void visitBasicBlock(llvm::BasicBlock& BB);
8383
void visitTruncInst(llvm::TruncInst&);
84+
#if LLVM_VERSION_MAJOR >= 14
85+
void visitFNeg(llvm::UnaryOperator &I);
86+
#endif
8487
void visitBinaryOperator(llvm::BinaryOperator& I);
8588
void visitAddrSpaceCastInst(llvm::AddrSpaceCastInst&);
8689

visa/Common_BinaryEncoding.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -2402,7 +2402,7 @@ inline bool BinaryEncodingBase::uncompactOneInstruction(G4_INST *inst) {
24022402
uint32_t subRegIndex100 = CompactSubRegTable.GetBits_100_096(subRegIndex);
24032403
uint32_t subRegIndex68 = CompactSubRegTable.GetBits_068_064(subRegIndex);
24042404
uint32_t subRegIndex52 = CompactSubRegTable.GetBits_052_048(subRegIndex);
2405-
uint32_t condModifer = GetCondModifier(mybin);
2405+
uint32_t condModifier = GetCondModifier(mybin);
24062406
uint32_t accWrCtrl = GetCmpAccWrCtrl(mybin);
24072407
uint32_t flagSubRegNum = GetCmpFlagSubRegNum(mybin);
24082408
uint32_t bits88 = CompactSourceTable.GetBits_088_077(src0Index);
@@ -2420,7 +2420,7 @@ inline bool BinaryEncodingBase::uncompactOneInstruction(G4_INST *inst) {
24202420
mybin->SetBits(100, 96, subRegIndex100);
24212421
mybin->SetBits(68, 64, subRegIndex68);
24222422
mybin->SetBits(52, 48, subRegIndex52);
2423-
SetCondModifier(mybin, condModifer);
2423+
SetCondModifier(mybin, condModifier);
24242424
SetAccWrCtrl(mybin, accWrCtrl);
24252425
SetFlagRegNum(mybin, flagSubRegNum);
24262426
SetCompactCtrl(mybin, 0); // uncompaction

visa/G4_Operand.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class G4_Operand {
181181
// lb = bit offset of the first flag bit
182182
// rb = bit offset of the last flag bit
183183
// (rb - lb) < 32 always holds for flags
184-
// for predicate and conditonal modifers, the bounds are also effected by the
184+
// for predicate and conditonal modifiers, the bounds are also effected by the
185185
// quarter control
186186
uint16_t left_bound;
187187
uint16_t right_bound;

visa/HWConformity.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ bool HWConformity::fixMathInst(INST_LIST_ITER it, G4_BB *bb) {
589589
G4_SrcRegRegion *srcRegion = src->asSrcRegRegion();
590590
const RegionDesc *rd = srcRegion->getRegion();
591591
if (srcRegion->getModifier() != Mod_src_undef && isIntDivide) {
592-
// no source modifer for int divide
592+
// no source modifier for int divide
593593
return true;
594594
} else if (srcRegion->getRegAccess() != Direct) {
595595
return true;

0 commit comments

Comments
 (0)