Skip to content

Commit c8cf546

Browse files
gkluczek authored and gfxbot committed
Bfrev pattern match
Change-Id: I0a022e64828990818864b51339f5fb522e88cfe8
1 parent 630d16b commit c8cf546

File tree

2 files changed

+165
-11
lines changed

2 files changed

+165
-11
lines changed

IGC/Compiler/CustomSafeOptPass.cpp

Lines changed: 163 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -871,25 +871,177 @@ void GenSpecificPattern::visitSDiv(llvm::BinaryOperator& I)
871871
}
872872
}
873873

874+
/*
875+
Optimizes bit reversing pattern:
876+
877+
%and = shl i32 %0, 1
878+
%shl = and i32 %and, 0xAAAAAAAA
879+
%and2 = lshr i32 %0, 1
880+
%shr = and i32 %and2, 0x55555555
881+
%or = or i32 %shl, %shr
882+
%and3 = shl i32 %or, 2
883+
%shl4 = and i32 %and3, 0xCCCCCCCC
884+
%and5 = lshr i32 %or, 2
885+
%shr6 = and i32 %and5, 0x33333333
886+
%or7 = or i32 %shl4, %shr6
887+
%and8 = shl i32 %or7, 4
888+
%shl9 = and i32 %and8, 0xF0F0F0F0
889+
%and10 = lshr i32 %or7, 4
890+
%shr11 = and i32 %and10, 0x0F0F0F0F
891+
%or12 = or i32 %shl9, %shr11
892+
%and13 = shl i32 %or12, 8
893+
%shl14 = and i32 %and13, 0xFF00FF00
894+
%and15 = lshr i32 %or12, 8
895+
%shr16 = and i32 %and15, 0x00FF00FF
896+
%or17 = or i32 %shl14, %shr16
897+
%shl19 = shl i32 %or17, 16
898+
%shr21 = lshr i32 %or17, 16
899+
%or22 = or i32 %shl19, %shr21
900+
901+
into:
902+
903+
%or22 = call i32 @llvm.genx.GenISA.bfrev.i32(i32 %0)
904+
905+
And similarly for patterns reversing 16 and 64 bit type values.
906+
*/
907+
template <typename MaskType>
908+
void GenSpecificPattern::matchReverse(BinaryOperator &I)
909+
{
910+
using namespace llvm::PatternMatch;
911+
assert(I.getType()->isIntegerTy());
912+
Value *nextOrShl, *nextOrShr;
913+
uint64_t currentShiftShl = 0, currentShiftShr = 0;
914+
uint64_t currentMaskShl = 0, currentMaskShr = 0;
915+
auto patternBfrevFirst =
916+
m_Or(
917+
m_Shl(m_Value(nextOrShl), m_ConstantInt(currentShiftShl)),
918+
m_LShr(m_Value(nextOrShr), m_ConstantInt(currentShiftShr)));
919+
920+
auto patternBfrev =
921+
m_Or(
922+
m_And(
923+
m_Shl(m_Value(nextOrShl), m_ConstantInt(currentShiftShl)),
924+
m_ConstantInt(currentMaskShl)),
925+
m_And(
926+
m_LShr(m_Value(nextOrShr), m_ConstantInt(currentShiftShr)),
927+
m_ConstantInt(currentMaskShr)));
928+
929+
unsigned int bitWidth = std::numeric_limits<MaskType>::digits;
930+
assert(bitWidth == 16 || bitWidth == 32 || bitWidth == 64);
931+
932+
unsigned int currentShift = bitWidth / 2;
933+
// First mask is a value with all upper half bits present.
934+
MaskType mask = std::numeric_limits<MaskType>::max() << currentShift;
935+
936+
bool isBfrevMatchFound = false;
937+
nextOrShl = &I;
938+
if (match(nextOrShl, patternBfrevFirst) &&
939+
nextOrShl == nextOrShr &&
940+
currentShiftShl == currentShift &&
941+
currentShiftShr == currentShift)
942+
{
943+
// NextOrShl is assigned to next one by match().
944+
currentShift /= 2;
945+
// Constructing next mask to match.
946+
mask ^= mask >> currentShift;
947+
}
948+
949+
while (currentShift > 0)
950+
{
951+
if (match(nextOrShl, patternBfrev) &&
952+
nextOrShl == nextOrShr &&
953+
currentShiftShl == currentShift &&
954+
currentShiftShr == currentShift &&
955+
currentMaskShl == mask &&
956+
currentMaskShr == (MaskType)~mask)
957+
{
958+
// NextOrShl is assigned to next one by match().
959+
if (currentShift == 1)
960+
{
961+
isBfrevMatchFound = true;
962+
break;
963+
}
964+
965+
currentShift /= 2;
966+
// Constructing next mask to match.
967+
mask ^= mask >> currentShift;
968+
}
969+
else
970+
{
971+
break;
972+
}
973+
}
974+
975+
if (isBfrevMatchFound)
976+
{
977+
llvm::IRBuilder<> builder(&I);
978+
Function *bfrevFunc = GenISAIntrinsic::getDeclaration(
979+
I.getParent()->getParent()->getParent(), GenISAIntrinsic::GenISA_bfrev, builder.getInt32Ty());
980+
if (bitWidth == 16)
981+
{
982+
Value* zext = builder.CreateZExt(nextOrShl, builder.getInt32Ty());
983+
Value* bfrev = builder.CreateCall(bfrevFunc, zext);
984+
Value* lshr = builder.CreateLShr(bfrev, 16);
985+
Value* trunc = builder.CreateTrunc(lshr, I.getType());
986+
I.replaceAllUsesWith(trunc);
987+
}
988+
else if (bitWidth == 32)
989+
{
990+
Value* bfrev = builder.CreateCall(bfrevFunc, nextOrShl);
991+
I.replaceAllUsesWith(bfrev);
992+
}
993+
else
994+
{ // bitWidth == 64
995+
Value* int32Source = builder.CreateBitCast(nextOrShl, llvm::VectorType::get(builder.getInt32Ty(), 2));
996+
Value* extractElement0 = builder.CreateExtractElement(int32Source, builder.getInt32(0));
997+
Value* extractElement1 = builder.CreateExtractElement(int32Source, builder.getInt32(1));
998+
Value* bfrevLow = builder.CreateCall(bfrevFunc, extractElement0);
999+
Value* bfrevHigh = builder.CreateCall(bfrevFunc, extractElement1);
1000+
Value* bfrev64Result = llvm::UndefValue::get(int32Source->getType());
1001+
bfrev64Result = builder.CreateInsertElement(bfrev64Result, bfrevHigh, builder.getInt32(0));
1002+
bfrev64Result = builder.CreateInsertElement(bfrev64Result, bfrevLow, builder.getInt32(1));
1003+
Value* bfrevBitcast = builder.CreateBitCast(bfrev64Result, I.getType());
1004+
I.replaceAllUsesWith(bfrevBitcast);
1005+
}
1006+
}
1007+
}
1008+
8741009
void GenSpecificPattern::visitBinaryOperator(BinaryOperator &I)
8751010
{
876-
/*
877-
llvm changes ADD to OR when possible, and this optimization changes it back and allow 2 ADDs to merge.
878-
This can avoid scattered read for constant buffer when the index is calculated by shl + or + add.
1011+
if (I.getOpcode() == Instruction::Or)
1012+
{
1013+
using namespace llvm::PatternMatch;
8791014

880-
ex:
881-
from
1015+
if (I.getType()->isIntegerTy())
1016+
{
1017+
unsigned int bitWidth = cast<IntegerType>(I.getType())->getBitWidth();
1018+
switch (bitWidth)
1019+
{
1020+
case 16:
1021+
matchReverse<unsigned short>(I);
1022+
break;
1023+
case 32:
1024+
matchReverse<unsigned int>(I);
1025+
break;
1026+
case 64:
1027+
matchReverse<unsigned long long>(I);
1028+
break;
1029+
}
1030+
}
1031+
1032+
/*
1033+
llvm changes ADD to OR when possible, and this optimization changes it back and allow 2 ADDs to merge.
1034+
This can avoid scattered read for constant buffer when the index is calculated by shl + or + add.
1035+
1036+
ex:
1037+
from
8821038
%22 = shl i32 %14, 2
8831039
%23 = or i32 %22, 3
8841040
%24 = add i32 %23, 16
885-
to
1041+
to
8861042
%22 = shl i32 %14, 2
8871043
%23 = add i32 %22, 19
888-
*/
889-
890-
if (I.getOpcode() == Instruction::Or)
891-
{
892-
using namespace llvm::PatternMatch;
1044+
*/
8931045
Value *AndOp1, *EltOp1;
8941046
auto pattern1 = m_Or(
8951047
m_And(m_Value(AndOp1), m_SpecificInt(0xFFFFFFFF)),

IGC/Compiler/CustomSafeOptPass.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ namespace IGC
127127
void visitIntToPtr(llvm::IntToPtrInst& I);
128128
void visitSDiv(llvm::BinaryOperator& I);
129129
void visitTruncInst(llvm::TruncInst &I);
130+
131+
template <typename MaskType> void matchReverse(llvm::BinaryOperator &I);
130132
};
131133

132134
class IGCConstProp : public llvm::FunctionPass

0 commit comments

Comments
 (0)