Skip to content

Handle byval, allow overallocating max cache, and improve caching #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
// was created outside a function (e.g. global, constant), that is allowed
assert(Val);
if (auto I = dyn_cast<Instruction>(Val)) {
if (TR.info.Function != I->getParent()->getParent()) {
llvm::errs() << *TR.info.Function << "\n";
llvm::errs() << *I << "\n";
}
assert(TR.info.Function == I->getParent()->getParent());
}
if (auto Arg = dyn_cast<Argument>(Val)) {
Expand Down Expand Up @@ -487,7 +491,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
}

// All arguments must be marked constant/nonconstant ahead of time
if (isa<Argument>(Val)) {
if (isa<Argument>(Val) && !cast<Argument>(Val)->hasByValAttr()) {
llvm::errs() << *(cast<Argument>(Val)->getParent()) << "\n";
llvm::errs() << *Val << "\n";
assert(0 && "must've put arguments in constant/nonconstant");
Expand Down Expand Up @@ -656,14 +660,16 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
if (directions & UP) {
// If we are derived from an argument our activity is equal to the
// activity of the argument by definition
if (isa<Argument>(TmpOrig)) {
bool res = isConstantValue(TR, TmpOrig);
if (res) {
ConstantValues.insert(Val);
} else {
ActiveValues.insert(Val);
if (auto arg = dyn_cast<Argument>(TmpOrig)) {
if (!arg->hasByValAttr()) {
bool res = isConstantValue(TR, TmpOrig);
if (res) {
ConstantValues.insert(Val);
} else {
ActiveValues.insert(Val);
}
return res;
}
return res;
}

UpHypothesis =
Expand Down Expand Up @@ -1017,7 +1023,8 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
// UpHypothesis.ConstantValues.insert(val);
UpHypothesis->insertConstantsFrom(*Hypothesis);
assert(directions & UP);
bool ActiveUp = !UpHypothesis->isInstructionInactiveFromOrigin(TR, Val);
bool ActiveUp = !isa<Argument>(Val) &&
!UpHypothesis->isInstructionInactiveFromOrigin(TR, Val);

// Case b) can occur if:
// 1) this memory is used as part of an active return
Expand Down Expand Up @@ -1598,7 +1605,8 @@ bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults &TR,
(isa<CallInst>(inst) && AA.onlyReadsMemory(cast<CallInst>(inst)))) {
// if not written to memory and returning a known constant, this
// cannot be actively returned/stored
if (isConstantValue(TR, a)) {
if (inst->getParent()->getParent() == TR.info.Function &&
isConstantValue(TR, a)) {
continue;
}
// if not written to memory and returning a value itself
Expand Down
124 changes: 94 additions & 30 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ llvm::cl::opt<bool>
EnzymePrintPerf("enzyme-print-perf", cl::init(false), cl::Hidden,
cl::desc("Enable Enzyme to print performance info"));

llvm::cl::opt<bool> EfficientMaxCache(
"enzyme-max-cache", cl::init(false), cl::Hidden,
cl::desc(
"Avoid reallocs when possible by potentially overallocating cache"));

CacheUtility::~CacheUtility() {}

/// Erase this instruction both from LLVM modules and any local data-structures
Expand All @@ -51,7 +56,8 @@ void CacheUtility::erase(Instruction *I) {
assert(ctx.second.var != I);
assert(ctx.second.incvar != I);
assert(ctx.second.antivaralloc != I);
assert(ctx.second.limit != I);
assert(ctx.second.trueLimit != I);
assert(ctx.second.maxLimit != I);
}
for (const auto &pair : scopeMap) {
if (pair.second.first == I) {
Expand Down Expand Up @@ -116,8 +122,11 @@ void CacheUtility::erase(Instruction *I) {
/// Replace this instruction both in LLVM modules and any local data-structures
void CacheUtility::replaceAWithB(Value *A, Value *B, bool storeInCache) {
for (auto &ctx : loopContexts) {
if (ctx.second.limit == A) {
ctx.second.limit = B;
if (ctx.second.maxLimit == A) {
ctx.second.maxLimit = B;
}
if (ctx.second.trueLimit == A) {
ctx.second.trueLimit = B;
}
}

Expand Down Expand Up @@ -467,8 +476,8 @@ bool CacheUtility::getContext(BasicBlock *BB, LoopContext &loopContext) {
SCEVUnionPredicate BackedgePred;

const SCEV *Limit = nullptr;
const SCEV *MaxIterations = nullptr;
{

const SCEV *MayExitMaxBECount = nullptr;

SmallVector<BasicBlock *, 8> ExitingBlocks;
Expand Down Expand Up @@ -500,33 +509,55 @@ bool CacheUtility::getContext(BasicBlock *BB, LoopContext &loopContext) {

ScalarEvolution::ExitLimit EL =
SE.computeExitLimit(L, ExitingBlock, /*AllowPredicates*/ true);
if (MayExitMaxBECount != SE.getCouldNotCompute()) {
if (!MayExitMaxBECount || EL.ExactNotTaken == SE.getCouldNotCompute())
MayExitMaxBECount = EL.ExactNotTaken;
else {
if (MayExitMaxBECount != EL.ExactNotTaken) {
llvm::errs() << MayExitMaxBECount << "\n";
if (EnzymePrintPerf)
llvm::errs() << "Missed cache optimization opportunity! could "
"allocate max!\n";
MayExitMaxBECount = SE.getCouldNotCompute();
break;

bool seenHeaders = false;
SmallPtrSet<BasicBlock *, 4> Seen;
std::deque<BasicBlock *> Todo = {ExitingBlock};
while (Todo.size()) {
auto cur = Todo.front();
Todo.pop_front();
if (Seen.count(cur))
continue;
if (!L->contains(cur))
continue;
if (cur == loopContexts[L].header) {
seenHeaders = true;
break;
}
for (auto S : successors(cur)) {
Todo.push_back(S);
}
}
if (seenHeaders) {
if (MaxIterations == nullptr ||
MaxIterations == SE.getCouldNotCompute()) {
MaxIterations = EL.ExactNotTaken;
}
if (MaxIterations != SE.getCouldNotCompute()) {
if (EL.ExactNotTaken != SE.getCouldNotCompute()) {
MaxIterations =
SE.getUMaxFromMismatchedTypes(MaxIterations, EL.ExactNotTaken);
}
}

if (MayExitMaxBECount == nullptr ||
EL.ExactNotTaken == SE.getCouldNotCompute())
MayExitMaxBECount = EL.ExactNotTaken;

MayExitMaxBECount = SE.getUMaxFromMismatchedTypes(MayExitMaxBECount,
EL.ExactNotTaken);
if (EL.ExactNotTaken != MayExitMaxBECount) {
MayExitMaxBECount = SE.getCouldNotCompute();
}
} else {
MayExitMaxBECount = SE.getCouldNotCompute();
}
}
if (ExitingBlocks.size() == 0) {
if (MayExitMaxBECount == nullptr) {
MayExitMaxBECount = SE.getCouldNotCompute();
}
if (MaxIterations == nullptr) {
MaxIterations = SE.getCouldNotCompute();
}
Limit = MayExitMaxBECount;
}
assert(Limit);

Value *LimitVar = nullptr;

if (SE.getCouldNotCompute() != Limit) {
Expand All @@ -548,11 +579,14 @@ bool CacheUtility::getContext(BasicBlock *BB, LoopContext &loopContext) {
LimitVar = Exp.expandCodeFor(Limit, CanonicalIV->getType(),
loopContexts[L].preheader->getTerminator());
loopContexts[L].dynamic = false;
loopContexts[L].maxLimit = LimitVar;
} else {
if (EnzymePrintPerf)
llvm::errs() << "SE could not compute loop limit of "
<< L->getHeader()->getName() << " "
<< L->getHeader()->getParent()->getName() << "\n";
<< L->getHeader()->getParent()->getName()
<< " lim: " << *Limit << " maxlim: " << *MaxIterations
<< "\n";

LimitVar = createCacheForScope(LimitContext(loopContexts[L].preheader),
CanonicalIV->getType(), "loopLimit",
Expand All @@ -574,8 +608,27 @@ bool CacheUtility::getContext(BasicBlock *BB, LoopContext &loopContext) {
cast<AllocaInst>(LimitVar));
}
loopContexts[L].dynamic = true;
loopContexts[L].maxLimit = nullptr;
}
loopContexts[L].trueLimit = LimitVar;
if (EfficientMaxCache && loopContexts[L].dynamic &&
SE.getCouldNotCompute() != MaxIterations) {
if (MaxIterations->getType() != CanonicalIV->getType())
MaxIterations =
SE.getZeroExtendExpr(MaxIterations, CanonicalIV->getType());

#if LLVM_VERSION_MAJOR >= 12
SCEVExpander Exp(SE, BB->getParent()->getParent()->getDataLayout(),
"enzyme");
#else
fake::SCEVExpander Exp(SE, BB->getParent()->getParent()->getDataLayout(),
"enzyme");
#endif

loopContexts[L].maxLimit =
Exp.expandCodeFor(MaxIterations, CanonicalIV->getType(),
loopContexts[L].preheader->getTerminator());
}
loopContexts[L].limit = LimitVar;

loopContext = loopContexts.find(L)->second;
return true;
Expand Down Expand Up @@ -670,7 +723,7 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,

StoreInst *storealloc = nullptr;
// Statically allocate memory for all iterations if possible
if (!sublimits[i].second.back().first.dynamic) {
if (sublimits[i].second.back().first.maxLimit) {
auto firstallocation = CallInst::CreateMalloc(
&allocationBuilder.GetInsertBlock()->back(), size->getType(),
myType, byteSizeOfType, size, nullptr, name + "_malloccache");
Expand Down Expand Up @@ -892,7 +945,8 @@ CacheUtility::SubLimitType CacheUtility::getSubLimits(LimitContext ctx) {
idx.var = nullptr; // = zero;
idx.incvar = nullptr;
idx.antivaralloc = nullptr;
idx.limit = zero;
idx.trueLimit = zero;
idx.maxLimit = zero;
idx.header = subctx;
idx.preheader = subctx;
idx.dynamic = false;
Expand All @@ -914,7 +968,8 @@ CacheUtility::SubLimitType CacheUtility::getSubLimits(LimitContext ctx) {
blk = idx.preheader;
}
if (ompTrueLimit && contexts.size()) {
contexts.back().limit = ompTrueLimit;
contexts.back().trueLimit = ompTrueLimit;
contexts.back().maxLimit = ompTrueLimit;
}

// Legal preheaders for loop i (indexed from inner => outer)
Expand All @@ -928,7 +983,7 @@ CacheUtility::SubLimitType CacheUtility::getSubLimits(LimitContext ctx) {
// outside the loop nest
if ((unsigned)i == contexts.size() - 1) {
allocationPreheaders[i] = contexts[i].preheader;
} else if (contexts[i].dynamic) {
} else if (!contexts[i].maxLimit) {
// For dynamic loops, the preheader is now forced to be the preheader
// of that loop
allocationPreheaders[i] = contexts[i].preheader;
Expand All @@ -941,7 +996,7 @@ CacheUtility::SubLimitType CacheUtility::getSubLimits(LimitContext ctx) {
// Dynamic loops are considered to have a limit of one for allocation
// purposes This is because we want to allocate 1 x (# of iterations inside
// chunk) inside every dynamic iteration
if (contexts[i].dynamic) {
if (!contexts[i].maxLimit) {
limits[i] =
ConstantInt::get(Type::getInt64Ty(ctx.Block->getContext()), 1);
} else {
Expand All @@ -968,17 +1023,26 @@ CacheUtility::SubLimitType CacheUtility::getSubLimits(LimitContext ctx) {

// Attempt to compute the limit of this loop at the corresponding
// allocation preheader. This is null if it was not legal to compute
limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, prevMap,
limitMinus1 = unwrapM(contexts[i].maxLimit, allocationBuilder, prevMap,
UnwrapMode::AttemptFullUnwrap);

// We have a loop with static bounds, but whose limit is not available
// to be computed at the current loop preheader (such as the innermost
// loop of triangular iteration domain) Handle this case like a dynamic
// loop and create a new chunk.
if (limitMinus1 == nullptr) {
if (EnzymePrintPerf)
EmitWarning("NoOuterLimit",
cast<Instruction>(contexts[i].maxLimit)->getDebugLoc(),
newFunc,
cast<Instruction>(contexts[i].maxLimit)->getParent(),
"Could not compute outermost loop limit by moving value ",
*contexts[i].maxLimit, " computed at block",
contexts[i].header->getName(), " function ",
contexts[i].header->getParent()->getName());
allocationPreheaders[i] = contexts[i].preheader;
allocationBuilder.SetInsertPoint(&allocationPreheaders[i]->back());
limitMinus1 = unwrapM(contexts[i].limit, allocationBuilder, prevMap,
limitMinus1 = unwrapM(contexts[i].maxLimit, allocationBuilder, prevMap,
UnwrapMode::AttemptFullUnwrap);
}
assert(limitMinus1 != nullptr);
Expand Down
6 changes: 4 additions & 2 deletions enzyme/Enzyme/CacheUtility.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ struct LoopContext {

/// limit is last value of a canonical induction variable
/// iters is number of times loop is run (thus iters = limit + 1)
llvm::Value *limit;
llvm::Value *maxLimit;

llvm::Value *trueLimit;

/// All blocks this loop exits too
llvm::SmallPtrSet<llvm::BasicBlock *, 8> exitBlocks;
Expand Down Expand Up @@ -141,7 +143,7 @@ class CacheUtility {
bool isInstructionUsedInLoopInduction(llvm::Instruction &I) {
for (auto &context : loopContexts) {
if (context.second.var == &I || context.second.incvar == &I ||
context.second.limit == &I) {
context.second.maxLimit == &I || context.second.trueLimit == &I) {
return true;
}
}
Expand Down
8 changes: 6 additions & 2 deletions enzyme/Enzyme/DifferentialUseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,13 @@ bool is_value_needed_in_reverse(
// write in reverse) or we need this value for the reverse pass (we
// conservatively assume that if legal it is recomputed and not
// stored)
IRBuilder<> IB(gutils->getNewFromOriginal(ci->getParent()));
if (!gutils->isConstantInstruction(ci) ||
!gutils->isConstantValue(
const_cast<Value *>((const Value *)ci)) ||
(ci->mayWriteToMemory() && topLevel) ||
(gutils->legalRecompute(ci, ValueToValueMapTy(), nullptr) &&
(gutils->legalRecompute(ci, ValueToValueMapTy(), &IB,
/*reverse*/ true) &&
is_value_needed_in_reverse<VT>(TR, gutils, ci, topLevel, seen,
oldUnreachable))) {
return seen[idx] = true;
Expand Down Expand Up @@ -295,10 +297,12 @@ bool is_value_needed_in_reverse(
// it may write memory and is topLevel (and thus we need to do the write
// in reverse) or we need this value for the reverse pass (we
// conservatively assume that if legal it is recomputed and not stored)
IRBuilder<> IB(gutils->getNewFromOriginal(ci->getParent()));
if (!gutils->isConstantInstruction(ci) ||
!gutils->isConstantValue(const_cast<Value *>((const Value *)ci)) ||
(ci->mayWriteToMemory() && topLevel) ||
(gutils->legalRecompute(ci, ValueToValueMapTy(), nullptr) &&
(gutils->legalRecompute(ci, ValueToValueMapTy(), &IB,
/*reverse*/ true) &&
is_value_needed_in_reverse<VT>(TR, gutils, ci, topLevel, seen,
oldUnreachable))) {
return seen[idx] = true;
Expand Down
8 changes: 3 additions & 5 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,12 @@ bool is_load_uncacheable(
#endif
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) ||
isDeallocationFunction(*fn, TLI)) {
called = fn;
}
called = fn;
}
}
}
if (called && isCertainMallocOrFree(called)) {
if (called && (isCertainMallocOrFree(called) ||
isMemFreeLibMFunction(called->getName()))) {
return false;
}
}
Expand Down
Loading