From cb18b16485d1fb544d26eaa45de43023e94ef893 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 22 Apr 2024 10:39:41 +0900 Subject: [PATCH] Add debug info to fprt runtime calls Also log op name Fix tests Fix file location Fix older llvm vers Handle casts and emit warnings only emit warning/error if trunc from type is used Clean up fprt headers Add unary op [Truncate] Clean up FPRT headers Fix tracing runtime [Truncate] Fix bug with fcmp handling We did not pass the truncated values to the runtime [Truncate] Corrently handle constant returns Unify logging fix Refactor printing in Trace.cpp Better diagnostics PHINode fixes Only need them in mem mode Trace inputs Fix tracing Handle stores of const floats fprt improvements ADAPT-style analysis Limit fps to handle for propagation Assign dbg info to inlineable callsites Set correct fprt_original linkage Fix debug metadata issues Always provide valid location string to FPRT Explicitly remap originalToNew when we dont RAUW Add missing checks for fp type Fix wrong fma function name Add __*_finite versions of derivatives fix Some fixes Add used attribute so that mpfr.h can be compiled as a library --- enzyme/Enzyme/Enzyme.cpp | 3 +- enzyme/Enzyme/EnzymeLogic.cpp | 292 ++++++++--- enzyme/Enzyme/EnzymeLogic.h | 6 + enzyme/Enzyme/Runtimes/FPRT/Trace.cpp | 461 +++++++++++++++--- enzyme/include/enzyme/fprt/flops.def | 4 + enzyme/include/enzyme/fprt/fprt.h | 56 +++ enzyme/include/enzyme/fprt/mpfr-test.h | 271 ++++++++++ enzyme/include/enzyme/fprt/mpfr.h | 105 ++-- enzyme/test/Enzyme/ForwardMode/hypot.ll | 13 + enzyme/test/Enzyme/Truncate/cmp.ll | 2 +- enzyme/test/Enzyme/Truncate/const.ll | 6 +- enzyme/test/Enzyme/Truncate/intrinsic.ll | 24 +- enzyme/test/Enzyme/Truncate/simple.ll | 6 +- enzyme/test/Enzyme/Truncate/value.ll | 4 +- enzyme/test/Integration/Truncate/simple.cpp | 64 ++- .../Truncate/truncate-all-header.h | 15 + .../Integration/Truncate/truncate-all.cpp | 35 +- enzyme/test/Integration/Truncate/warnings.cpp | 62 +++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 15 +- 19 files changed, 1215 insertions(+), 229 deletions(-) create mode 100644 enzyme/include/enzyme/fprt/fprt.h create mode 100644 enzyme/include/enzyme/fprt/mpfr-test.h create mode 100644 enzyme/test/Integration/Truncate/truncate-all-header.h create mode 100644 enzyme/test/Integration/Truncate/warnings.cpp diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index ede1cbd1a52e..58ee7145da8c 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -24,6 +24,7 @@ // //===----------------------------------------------------------------------===// #include +#include #include #if LLVM_VERSION_MAJOR >= 16 @@ -2217,7 +2218,7 @@ class EnzymeBase { #endif RemapFunction(F, Mapping, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - TruncatedFunc->deleteBody(); + TruncatedFunc->eraseFromParent(); } return true; } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 7cb75a72932d..d5f011495dd0 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -42,6 +42,10 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/Support/ErrorHandling.h" #include +#include +#include +#include +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -1724,7 +1728,7 @@ void clearFunctionAttributes(Function *f) { } Attribute::AttrKind attrs[] = { #if LLVM_VERSION_MAJOR >= 17 - Attribute::NoFPClass, + Attribute::NoFPClass, #endif Attribute::NoUndef, Attribute::NonNull, @@ -2553,7 +2557,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( llvm::Attribute::AttrKind attrs[] = { #if LLVM_VERSION_MAJOR >= 17 - llvm::Attribute::NoFPClass, + llvm::Attribute::NoFPClass, #endif llvm::Attribute::NoAlias, llvm::Attribute::NoUndef, @@ -4923,6 +4927,8 @@ class TruncateUtils { Type *fromType; Type *toType; LLVMContext &ctx; + EnzymeLogic &Logic; + Value *UnknownLoc; private: std::string getOriginalFPRTName(std::string Name) { @@ -4946,7 +4952,7 @@ class TruncateUtils { ArgTypes.push_back(Arg->getType()); FunctionType *FnTy = FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); - F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, M); + F = Function::Create(FnTy, Function::WeakODRLinkage, MangledName, M); } if (F->isDeclaration()) { BasicBlock *Entry = BasicBlock::Create(F->getContext(), "entry", F); @@ -4955,6 +4961,9 @@ class TruncateUtils { ClonedI->setOperand(It, F->getArg(It)); auto Return = ReturnInst::Create(F->getContext(), ClonedI, Entry); ClonedI->insertBefore(Return); + F->setLinkage(GlobalValue::WeakODRLinkage); + // Clear invalidated debug metadata now that we defined the function + F->clearMetadata(); } } @@ -4975,22 +4984,40 @@ class TruncateUtils { CallInst *createFPRTGeneric(llvm::IRBuilderBase &B, std::string Name, const SmallVectorImpl &ArgsIn, - llvm::Type *RetTy) { + llvm::Type *RetTy, Value *LocStr) { SmallVector Args(ArgsIn.begin(), ArgsIn.end()); Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); Args.push_back(B.getInt64(truncation.getTo().significandWidth)); Args.push_back(B.getInt64(truncation.getMode())); +#if LLVM_VERSION_MAJOR <= 14 + Args.push_back(B.CreateBitCast(LocStr, NullPtr->getType())); +#else + Args.push_back(LocStr); +#endif + auto FprtFunc = getFPRTFunc(Name, Args, RetTy); - return cast(B.CreateCall(FprtFunc, Args)); + // Explicitly assign a dbg location if it didn't exist, as the FPRT + // functions are inlineable and the backend fails if the callsite does not + // have dbg metadata + // TODO consider using InstrumentationIRBuilder + Function *ContainingF = B.GetInsertBlock()->getParent(); + if (!B.getCurrentDebugLocation() && ContainingF->getSubprogram()) + B.SetCurrentDebugLocation(DILocation::get(ContainingF->getContext(), 0, 0, + ContainingF->getSubprogram())); + auto *CI = cast(B.CreateCall(FprtFunc, Args)); + + return CI; } public: - TruncateUtils(FloatTruncation truncation, Module *M) - : truncation(truncation), M(M), ctx(M->getContext()) { + TruncateUtils(FloatTruncation truncation, Module *M, EnzymeLogic &Logic) + : truncation(truncation), M(M), ctx(M->getContext()), Logic(Logic) { fromType = truncation.getFromType(ctx); toType = truncation.getToType(ctx); if (fromType == toType) assert(truncation.isToFPRT()); + + UnknownLoc = getUniquedLocStr(nullptr); } Type *getFromType() { return fromType; } @@ -5001,23 +5028,54 @@ class TruncateUtils { assert(V->getType() == getFromType()); SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "const", Args, getToType()); + return createFPRTGeneric(B, "const", Args, getToType(), UnknownLoc); } CallInst *createFPRTNewCall(llvm::IRBuilderBase &B, Value *V) { assert(V->getType() == getFromType()); SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "new", Args, getToType()); + return createFPRTGeneric(B, "new", Args, getToType(), UnknownLoc); } CallInst *createFPRTGetCall(llvm::IRBuilderBase &B, Value *V) { SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "get", Args, getToType()); + return createFPRTGeneric(B, "get", Args, getToType(), UnknownLoc); } CallInst *createFPRTDeleteCall(llvm::IRBuilderBase &B, Value *V) { SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "delete", Args, B.getVoidTy()); + return createFPRTGeneric(B, "delete", Args, B.getVoidTy(), UnknownLoc); + } + // This will result in a unique string for each location, which means the + // runtime can check whether two operations are the same with a simple pointer + // comparison. However, we need LTO for this to be the case across different + // compilation units. + GlobalValue *getUniquedLocStr(Instruction *I) { + std::string FileName = "unknown"; + unsigned LineNo = 0; + unsigned ColNo = 0; + + if (I) { + DILocation *DL = I->getDebugLoc(); + if (DL) { + FileName = DL->getFilename(); + LineNo = DL->getLine(); + ColNo = DL->getColumn(); + } + } + + auto Key = std::make_tuple(FileName, LineNo, ColNo); + auto It = Logic.UniqDebugLocStrs.find(Key); + + if (It != Logic.UniqDebugLocStrs.end()) + return It->second; + + std::string LocStr = + FileName + ":" + std::to_string(LineNo) + ":" + std::to_string(ColNo); + auto GV = createPrivateGlobalForString(*M, LocStr, true); + Logic.UniqDebugLocStrs[Key] = GV; + + return GV; } CallInst *createFPRTOpCall(llvm::IRBuilderBase &B, llvm::Instruction &I, llvm::Type *RetTy, @@ -5040,21 +5098,28 @@ class TruncateUtils { "Unexpected indirect call inst for conversion to FPRT"); } else if (auto CI = dyn_cast(&I)) { Name = "fcmp_" + std::string(CI->getPredicateName(CI->getPredicate())); + } else if (auto UO = dyn_cast(&I)) { + Name = "unaryop_" + std::string(UO->getOpcodeName()); } else { llvm_unreachable("Unexpected instruction for conversion to FPRT"); } createOriginalFPRTFunc(I, Name, ArgsIn, RetTy); - return createFPRTGeneric(B, Name, ArgsIn, RetTy); + return createFPRTGeneric(B, Name, ArgsIn, RetTy, getUniquedLocStr(&I)); } }; +// TODO we need to handle cases where constant aggregates are used and they +// contain constant fp's in them. +// +// e.g. store {0 : i64, 1.0: f64} %ptr +// +// Currently in mem mode the float will remain unconverted and we will likely +// crash somewhere. class TruncateGenerator : public llvm::InstVisitor, public TruncateUtils { private: ValueToValueMapTy &originalToNewFn; FloatTruncation truncation; - Function *oldFunc; - Function *newFunc; TruncateMode mode; EnzymeLogic &Logic; LLVMContext &ctx; @@ -5063,36 +5128,43 @@ class TruncateGenerator : public llvm::InstVisitor, TruncateGenerator(ValueToValueMapTy &originalToNewFn, FloatTruncation truncation, Function *oldFunc, Function *newFunc, EnzymeLogic &Logic) - : TruncateUtils(truncation, newFunc->getParent()), + : TruncateUtils(truncation, newFunc->getParent(), Logic), originalToNewFn(originalToNewFn), truncation(truncation), - oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()), - Logic(Logic), ctx(newFunc->getContext()) {} + mode(truncation.getMode()), Logic(Logic), ctx(newFunc->getContext()) {} - void checkHandled(llvm::Instruction &inst) { - // TODO - // if (all_of(inst.getOperandList(), - // [&](Use *use) { return use->get()->getType() == fromType; })) - // todo(inst); - } + void todo(llvm::Instruction &I) { + if (all_of(I.operands(), + [&](Use &U) { return U.get()->getType() != fromType; }) && + I.getType() != fromType) + return; - // TODO - void handleTrunc(); - void hendleIntToFloat(); - void handleFloatToInt(); + switch (mode) { + case TruncMemMode: + llvm::errs() << I << "\n"; + EmitFailure("FPEscaping", I.getDebugLoc(), &I, "FP value escapes!"); + break; + case TruncOpMode: + case TruncOpFullModuleMode: + EmitWarning( + "UnhandledTrunc", I, + "Operation not handled - it will be executed in the original way.", + I); + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } - void visitInstruction(llvm::Instruction &inst) { + void visitInstruction(llvm::Instruction &I) { using namespace llvm; - // TODO explicitly handle all instructions rather than using the catch all - // below - - switch (inst.getOpcode()) { + switch (I.getOpcode()) { // #include "InstructionDerivatives.inc" default: break; } - checkHandled(inst); + todo(I); } Value *truncate(IRBuilder<> &B, Value *v) { @@ -5119,17 +5191,24 @@ class TruncateGenerator : public llvm::InstVisitor, llvm_unreachable("Unknown trunc mode"); } - void todo(llvm::Instruction &I) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot handle unknown instruction\n" << I; - if (CustomErrorHandler) { - IRBuilder<> Builder2(getNewFromOriginal(&I)); - CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, - this, nullptr, wrap(&Builder2)); + void visitUnaryOperator(UnaryOperator &I) { + switch (I.getOpcode()) { + case UnaryOperator::FNeg: { + if (I.getOperand(0)->getType() != getFromType()) + return; + + auto newI = getNewFromOriginal(&I); + IRBuilder<> B(newI); + SmallVector Args = {newI->getOperand(0)}; + auto nres = createFPRTOpCall(B, I, newI->getType(), Args); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(nres); + newI->eraseFromParent(); return; - } else { - EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); + } + default: + todo(I); return; } } @@ -5150,8 +5229,8 @@ class TruncateGenerator : public llvm::InstVisitor, auto truncRHS = truncate(B, RHS); SmallVector Args; - Args.push_back(LHS); - Args.push_back(RHS); + Args.push_back(truncLHS); + Args.push_back(truncRHS); Instruction *nres; if (truncation.isToFPRT()) nres = createFPRTOpCall(B, CI, B.getInt1Ty(), Args); @@ -5179,13 +5258,32 @@ class TruncateGenerator : public llvm::InstVisitor, SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), /*mask=*/nullptr); } + // TODO Is there a possibility we GEP a const and get a FP value? void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } - void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { + // TODO Try to follow fps through trunc/exts switch (mode) { case TruncMemMode: { - if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType()) - todo(CI); + auto newI = getNewFromOriginal(&CI); + auto newSrc = newI->getOperand(0); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + if (isa(newSrc)) + return; + newI->setOperand(0, createFPRTGetCall(B, newSrc)); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + } else if (CI.getDestTy() == getFromType()) { + IRBuilder<> B(newI->getNextNode()); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + auto nres = createFPRTNewCall(B, newI); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceUsesWithIf(nres, + [&](Use &U) { return U.getUser() != nres; }); + originalToNewFn[const_cast(cast(&CI))] = nres; + } return; } case TruncOpMode: @@ -5196,6 +5294,8 @@ class TruncateGenerator : public llvm::InstVisitor, void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { case TruncMemMode: { + if (SI.getType() != getFromType()) + return; auto newI = getNewFromOriginal(&SI); IRBuilder<> B(newI); auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); @@ -5329,11 +5429,31 @@ class TruncateGenerator : public llvm::InstVisitor, newI->eraseFromParent(); return true; } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { handleIntrinsic(II, II.getIntrinsicID()); } - void visitReturnInst(llvm::ReturnInst &I) { return; } + void visitReturnInst(llvm::ReturnInst &I) { + switch (mode) { + case TruncMemMode: { + if (I.getNumOperands() == 0) + return; + if (I.getReturnValue()->getType() != getFromType()) + return; + auto newI = cast(getNewFromOriginal(&I)); + IRBuilder<> B(newI); + if (isa(newI->getOperand(0))) + newI->setOperand(0, createFPRTConstCall(B, newI->getReturnValue())); + return; + } + case TruncOpMode: + case TruncOpFullModuleMode: + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } void visitBranchInst(llvm::BranchInst &I) { return; } void visitSwitchInst(llvm::SwitchInst &I) { return; } @@ -5348,6 +5468,23 @@ class TruncateGenerator : public llvm::InstVisitor, llvm::Value *orig_val, llvm::MaybeAlign prevalign, bool isVolatile, llvm::AtomicOrdering ordering, llvm::SyncScope::ID syncScope, llvm::Value *mask) { + switch (mode) { + case TruncMemMode: { + if (orig_val->getType() != getFromType()) + return; + if (!isa(orig_val)) + return; + auto newI = getNewFromOriginal(&I); + IRBuilder<> B(newI); + newI->setOperand(0, createFPRTConstCall(B, getNewFromOriginal(orig_val))); + return; + } + case TruncOpMode: + case TruncOpFullModuleMode: + break; + default: + llvm_unreachable("Unknown trunc mode"); + } return; } @@ -5435,17 +5572,55 @@ class TruncateGenerator : public llvm::InstVisitor, if (mode != TruncOpFullModuleMode) { RequestContext ctx(&CI, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); - newCall->setCalledOperand(val); + Function *Func = CI.getCalledFunction(); + if (Func && !Func->empty()) { + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); + newCall->setCalledOperand(val); + } else { + switch (mode) { + case TruncMemMode: + EmitWarning("FPNoFollow", CI, + "Will not follow FP through this function call as the " + "definition is not available.", + CI); + break; + case TruncOpMode: + case TruncOpFullModuleMode: + EmitWarning("FPNoFollow", CI, + "Will not truncate flops in this function call as the " + "definition is not available.", + CI); + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } } return; } - void visitFPTruncInst(FPTruncInst &I) { return; } - void visitFPExtInst(FPExtInst &I) { return; } - void visitFPToUIInst(FPToUIInst &I) { return; } - void visitFPToSIInst(FPToSIInst &I) { return; } - void visitUIToFPInst(UIToFPInst &I) { return; } - void visitSIToFPInst(SIToFPInst &I) { return; } + void visitPHINode(llvm::PHINode &PN) { + switch (mode) { + case TruncMemMode: { + if (PN.getType() != getFromType()) + return; + auto NewPN = cast(getNewFromOriginal(&PN)); + IRBuilder<> B( + NewPN->getParent()->getParent()->getEntryBlock().getFirstNonPHI()); + for (unsigned It = 0; It < NewPN->getNumIncomingValues(); It++) { + if (isa(NewPN->getIncomingValue(It))) { + NewPN->setOperand( + It, createFPRTConstCall(B, NewPN->getIncomingValue(It))); + } + } + break; + } + case TruncOpMode: + case TruncOpFullModuleMode: + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, @@ -5457,7 +5632,8 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, Value *converted = nullptr; auto truncation = FloatTruncation(from, to, TruncMemMode); - TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent()); + TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent(), + *this); if (isTruncate) converted = TU.createFPRTNewCall(B, v); else diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index dd9b877c5ade..6c3611c67a8e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -31,6 +31,7 @@ #define ENZYME_LOGIC_H #include +#include #include #include @@ -413,9 +414,14 @@ struct FloatTruncation { std::string mangleFrom() const { return from.to_string(); } }; +typedef std::map, + llvm::GlobalValue *> + UniqDebugLocStrsTy; + class EnzymeLogic { public: PreProcessCache PPC; + UniqDebugLocStrsTy UniqDebugLocStrs; /// \p PostOpt is whether to perform basic /// optimization of the function after synthesis diff --git a/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp b/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp index e06f02508d59..3166b1087da2 100644 --- a/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp +++ b/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp @@ -25,76 +25,367 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include #include +#include #include #include -#define __ENZYME_MPFR_ATTRIBUTES -#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES +#include +#include + +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) + +#ifndef ENZYME_FPRT_TRACE_PRINT +#define ENZYME_FPRT_TRACE_PRINT 1 +#endif + +static constexpr unsigned fp_max_inputs = 3; +static constexpr std::array arg_names = {"x", "y", "z"}; +static_assert(arg_names.size() == fp_max_inputs); extern "C" { +typedef struct __enzyme_fp { +private: + double result; + unsigned char input_num; + const char *loc; + __enzyme_fp *inputs[fp_max_inputs]; + double derivatives[fp_max_inputs]; +#if ENZYME_FPRT_TRACE_PRINT + const char *name; +#endif + +public: + size_t id; + + double getDerivative(unsigned no) const { return derivatives[no]; } + void setDerivative(unsigned no, double d) { derivatives[no] = d; } + + __enzyme_fp *getInput(unsigned no) const { return inputs[no]; } + void setInput(unsigned no, __enzyme_fp *i) { inputs[no] = i; } + + unsigned char getInputNum() const { return input_num; } + void setInputNum(unsigned char i) { input_num = i; } + + double getResult() const { return result; } + void setResult(double r) { result = r; } + + const char *getLoc() const { return loc; } + void setLoc(const char *l) { loc = l; } + +#if ENZYME_FPRT_TRACE_PRINT + const char *getName() const { return name; } + void setName(const char *l) { name = l; } +#endif -typedef struct { - double v; } __enzyme_fp; +} -// TODO ultimately we probably want a linked list of arrays or something like -// that for this -static std::list<__enzyme_fp> FPs; +static void print_enzyme_fp_derivatives(std::ostream &out, + const __enzyme_fp *fp) { + auto seen = false; + for (unsigned i = 0; i < fp->getInputNum(); i++) { + if (seen) + out << ", "; + seen = true; + out << "d" << arg_names[i] << " = " << fp->getDerivative(i); + } +} +static void print_enzyme_fp_value(std::ostream &out, const __enzyme_fp *fp) { + out << "[" << fp << ": " << fp->getResult() << "]"; +} +static void print_enzyme_fp_function(std::ostream &out, const __enzyme_fp *fp) { + std::cerr << fp->getName() << "("; + bool seen = false; + for (unsigned i = 0; i < fp->getInputNum(); i++) { + if (seen) + std::cerr << ", "; + seen = true; + __enzyme_fp *fpinput = fp->getInput(i); + print_enzyme_fp_value(std::cerr, fpinput); + } + std::cerr << ")"; +} +static void print_enzyme_fp(std::ostream &out, const __enzyme_fp *fp) { + print_enzyme_fp_function(out, fp); + out << " -> "; + print_enzyme_fp_value(out, fp); + out << " "; + print_enzyme_fp_derivatives(out, fp); + out << " at " << fp->getLoc(); + out << std::endl; +} + +template +static void __enzyme_fprt_trace_no_res_flop(std::array inputs, + const char *name, const char *loc) { + __enzyme_fp fp; + fp.setInputNum(NumInputs); + fp.setLoc(loc); + for (unsigned i = 0; i < inputs.size(); i++) { + __enzyme_fp *inputfp = __enzyme_fprt_double_to_ptr(inputs[i]); + fp.setInput(i, inputfp); + } + +#if ENZYME_FPRT_TRACE_PRINT + fp.setName(name); + print_enzyme_fp_function(std::cerr, &fp); + std::cerr << " at " << loc << std::endl; +#endif +} + +namespace { +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + return 0; + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + return __enzyme_fwddiff((fty)fn, enzyme_dup, inputs[0], 1.0); + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + return __enzyme_fwddiff((fty)fn, enzyme_dup, inputs[0], 1.0, + enzyme_const, inputs[1]); + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + return __enzyme_fwddiff((fty)fn, enzyme_const, inputs[0], enzyme_dup, + inputs[1], 1.0); + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + // clang-format off + return __enzyme_fwddiff((fty)fn, + enzyme_dup, inputs[0], 1.0, + enzyme_const, inputs[1], + enzyme_const, inputs[2] + ); + // clang-format on + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + // clang-format off + return __enzyme_fwddiff((fty)fn, + enzyme_const, inputs[0], + enzyme_dup, inputs[1], 1.0, + enzyme_const, inputs[2] + ); + // clang-format on + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + // clang-format off + return __enzyme_fwddiff((fty)fn, + enzyme_const, inputs[0], + enzyme_const, inputs[1], + enzyme_dup, inputs[2], 1.0 + ); + // clang-format on + } +}; +} // namespace + +template +__attribute__((always_inline)) static void +__enzyme_fprt_trace_flop(std::array _inputs, T output_val, + __enzyme_fp *outfp, void *fn, const char *name, + const char *loc) { + std::array<__enzyme_fp *, NumInputs> inputs; + std::array input_vals; + for (unsigned i = 0; i < _inputs.size(); i++) { + __enzyme_fp *inputfp = __enzyme_fprt_double_to_ptr(_inputs[i]); + inputs[i] = inputfp; + input_vals[i] = inputfp->getResult(); + } -static bool __enzyme_fprt_is_mem_mode(int64_t mode) { return mode & 0b0001; } -static bool __enzyme_fprt_is_op_mode(int64_t mode) { return mode & 0b0010; } + outfp->setResult(output_val); + outfp->setInputNum(inputs.size()); + outfp->setLoc(loc); + for (unsigned i = 0; i < inputs.size(); i++) { + outfp->setInput(i, inputs[i]); + T d; + static_assert(inputs.size() <= fp_max_inputs); + if (i == 0) + d = Derivative::get(fn, input_vals); + else if (i == 1) + d = Derivative::get(fn, input_vals); + else if (i == 2) + d = Derivative::get(fn, input_vals); + else + llvm_unreachable("impossible"); + outfp->setDerivative(i, d); + } -static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { - return *((double *)(&p)); +#if ENZYME_FPRT_TRACE_PRINT + outfp->setName(name); + print_enzyme_fp(std::cerr, outfp); +#endif } -static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { - return *((__enzyme_fp **)(&d)); + +// TODO ultimately we probably want a linked list of arrays or something like +// that for this (std::list probably is that but we may want our own impl) +struct { + std::list<__enzyme_fp> all; + std::list<__enzyme_fp *> outputs; + std::list<__enzyme_fp *> inputs; + std::list<__enzyme_fp *> consts; + void clear() { + all.clear(); + outputs.clear(); + inputs.clear(); + } +} FPs; + +extern "C" { + +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, + const char *loc) { + size_t id = FPs.all.size(); + FPs.all.push_back({}); + __enzyme_fp *a = &FPs.all.back(); + a->id = id; + return a; } -__ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); - return a->v; + FPs.outputs.push_back(a); + __enzyme_fprt_trace_no_res_flop({_a}, "get", loc); + return a->getResult(); } -__ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, - int64_t mode) { - FPs.push_back({_a}); - __enzyme_fp *a = &FPs.back(); - return __enzyme_fprt_ptr_to_double(a); + int64_t mode, const char *loc) { + __enzyme_fp *a = + __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode, loc); + FPs.inputs.push_back(a); + __enzyme_fprt_trace_flop({}, _a, a, nullptr, "new", loc); + auto ret = __enzyme_fprt_ptr_to_double(a); + return ret; } -__ENZYME_MPFR_ATTRIBUTES -__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, - int64_t significand, - int64_t mode) { - FPs.push_back({0}); - __enzyme_fp *a = &FPs.back(); - return a; +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, + const char *loc) { + // TODO This should really be called only once for an appearance in the code, + // currently it is called every time a flop uses a constant. + __enzyme_fp *a = + __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode, loc); + FPs.consts.push_back(a); + __enzyme_fprt_trace_flop({}, _a, a, nullptr, "const", loc); + auto ret = __enzyme_fprt_ptr_to_double(a); + return ret; } -__ENZYME_MPFR_ATTRIBUTES void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { // TODO + __enzyme_fprt_trace_no_res_flop({a}, "delete", loc); +} + +// Below sensitivity computation is taken frmo ADAPT +static double __enzyme_estimate_truncation_error(double a) { + return abs(a - (float)a); +} + +void __enzyme_fprt_delete_all() { + size_t size = FPs.all.size(); + size_t i = 0; + for (auto it = FPs.all.begin(); it != FPs.all.end(); i++, it++) { + // Do not truncate inputs + if (std::find(FPs.inputs.begin(), FPs.inputs.end(), &*it) != + FPs.inputs.end()) + continue; + // Or consts + if (std::find(FPs.consts.begin(), FPs.consts.end(), &*it) != + FPs.consts.end()) + continue; + + // Zero out all errors + // TODO is it faster to calloc each time or should we pre-allocate and + // memset? + double *errors = (double *)std::calloc(size, sizeof(*errors)); + // Introduce truncation error into the current op + // TODO we can probably re-run the original operation in the truncated + // precision thus get the real error and not an estimation + errors[i] = __enzyme_estimate_truncation_error(it->getResult()); + + size_t j = i; + for (auto jt = it; jt != FPs.all.end(); j++, jt++) + for (unsigned char k = 0; k < jt->getInputNum(); k++) + errors[j] += abs(jt->getDerivative(k) * errors[jt->getInput(k)->id]); + +#if ENZYME_FPRT_TRACE_PRINT + std::cerr << "For instance "; + print_enzyme_fp_value(std::cerr, &*it); + std::cerr << " when truncated from double to float:" << std::endl; + + for (__enzyme_fp *output : FPs.outputs) { + std::cerr << " wrt output "; + print_enzyme_fp_value(std::cerr, output); + std::cerr << " at " << output->getLoc() + << ", sensitivity = " << errors[output->id] << std::endl; + } +#endif + } + FPs.clear(); } #define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ - RET __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode); \ + RET __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(ARG1 a); \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode) { \ - RET res = \ - __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(a); \ - __enzyme_trace_flop({a}, ret, #LLVM_OP_NAME); \ - return res; \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME; \ + RET res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ } // TODO this is a bit sketchy if the user cast their float to int before calling @@ -107,53 +398,93 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ARG2 b); \ __ENZYME_MPFR_ATTRIBUTES RET \ __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, \ - int64_t mode) { \ - RET res = \ - __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(a, b); \ - __enzyme_trace_flop({a, b}, ret, #LLVM_OP_NAME); \ - return res; \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME; \ + RET res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ } #define __ENZYME_MPFR_BIN(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ MPFR_SET_ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ORIGINAL_ATTRIBUTES \ - RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(ARG1 a, ARG2 b); \ + RET __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(ARG1 a, \ + ARG2 b); \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ - RET res = \ - __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(a, b); \ - __enzyme_trace_flop({a, b}, ret, #LLVM_OP_NAME); \ - return res; \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME; \ + RET res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a, b}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ } -#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ - LLVM_TYPE, ROUNDING_MODE) \ +#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ + LLVM_TYPE, ROUNDING_MODE) \ + __ENZYME_MPFR_ORIGINAL_ATTRIBUTES \ + TYPE __enzyme_fprt_original_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c); \ + __ENZYME_MPFR_ATTRIBUTES \ + TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ + int64_t mode, const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE; \ + TYPE res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult(), \ + __enzyme_fprt_double_to_ptr(c)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a, b, c}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ + } + +#define __ENZYME_MPFR_FCMP_IMPL(NAME, ORDERED, CMP, FROM_TYPE, TYPE, MPFR_GET, \ + ROUNDING_MODE) \ __ENZYME_MPFR_ORIGINAL_ATTRIBUTES \ - RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(TYPE a, TYPE b, \ - TYPE c); \ + bool __enzyme_fprt_original_##FROM_TYPE##_fcmp_##NAME(TYPE a, TYPE b); \ __ENZYME_MPFR_ATTRIBUTES \ - TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ - TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ - int64_t mode) { \ - RET res = __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - a, b, c); \ - __enzyme_trace_flop({a, b, c}, ret, #LLVM_OP_NAME); \ + bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + bool res = __enzyme_fprt_original_##FROM_TYPE##_fcmp_##NAME( \ + __enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult()); \ + __enzyme_fprt_trace_no_res_flop({a, b}, "fcmp_" #NAME, loc); \ return res; \ } __ENZYME_MPFR_ORIGINAL_ATTRIBUTES bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests); -__ENZYME_MPFR_ATTRIBUTES bool -__enzyme_fprt_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests) { - return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(a, tests); +__ENZYME_MPFR_ATTRIBUTES bool __enzyme_fprt_64_52_intr_llvm_is_fpclass_f64( + double a, int32_t tests, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + __enzyme_fprt_trace_no_res_flop({a}, "llvm_is_fpclass_f64", loc); + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64( + __enzyme_fprt_double_to_ptr(a)->getResult(), tests); } -#include "enzyme/fprt/flops.def" +#include } // extern "C" - -#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ diff --git a/enzyme/include/enzyme/fprt/flops.def b/enzyme/include/enzyme/fprt/flops.def index a3c4d7fcbac6..62e4b48ba00d 100644 --- a/enzyme/include/enzyme/fprt/flops.def +++ b/enzyme/include/enzyme/fprt/flops.def @@ -1,3 +1,4 @@ +// -*- mode: c++ -*- #define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ ROUNDING_MODE) \ @@ -112,6 +113,9 @@ __ENZYME_MPFR_SINGOP_DOUBLE_FLOAT(lgamma, lngamma); // TODO This is not accurate (I think we cast int to double) __ENZYME_MPFR_SINGOP_DOUBLE_FLOAT(nearbyint, rint); +__ENZYME_MPFR_SINGOP(unaryop, fneg, neg, 64_52, double, d, double, + d, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + // Ternary operation __ENZYME_MPFR_FMULADD(llvm_fmuladd, 64_52, double, d, f64, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); diff --git a/enzyme/include/enzyme/fprt/fprt.h b/enzyme/include/enzyme/fprt/fprt.h new file mode 100644 index 000000000000..2796c76536ab --- /dev/null +++ b/enzyme/include/enzyme/fprt/fprt.h @@ -0,0 +1,56 @@ +#ifndef _ENZYME_FPRT_FPRT_H_ +#define _ENZYME_FPRT_FPRT_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// User-facing API +double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc); +double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc); +void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc); +double __enzyme_truncate_mem_value_d(double, int, int); +float __enzyme_truncate_mem_value_f(float, int, int); +double __enzyme_expand_mem_value_d(double, int, int); +float __enzyme_expand_mem_value_f(float, int, int); +void __enzyme_fprt_delete_all(); + +// For internal use +struct __enzyme_fp; +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, + const char *loc); +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, + const char *loc); + +[[maybe_unused]] static bool __enzyme_fprt_is_mem_mode(int64_t mode) { + return mode & 0b0001; +} +[[maybe_unused]] static bool __enzyme_fprt_is_op_mode(int64_t mode) { + return mode & 0b0010; +} +[[maybe_unused]] static double __enzyme_fprt_idx_to_double(uint64_t p) { + return *((double *)(&p)); +} +[[maybe_unused]] static uint64_t __enzyme_fprt_double_to_idx(double d) { + return *((uint64_t *)(&d)); +} +[[maybe_unused]] static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { + return *((double *)(&p)); +} +[[maybe_unused]] static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { + return *((__enzyme_fp **)(&d)); +} + +#ifdef __cplusplus +} +#endif + +#endif // _ENZYME_FPRT_FPRT_H_ diff --git a/enzyme/include/enzyme/fprt/mpfr-test.h b/enzyme/include/enzyme/fprt/mpfr-test.h new file mode 100644 index 000000000000..5a48977d256c --- /dev/null +++ b/enzyme/include/enzyme/fprt/mpfr-test.h @@ -0,0 +1,271 @@ +//===- fprt/mpfr - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include +#include + +#include "fprt.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN + +typedef struct __enzyme_fp { + mpfr_t result; +} __enzyme_fp; + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); + return mpfr_get_d(a->result, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); + mpfr_init2(a->result, significand); + mpfr_set_d(a->result, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + return __enzyme_fprt_ptr_to_double(a); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, + const char *loc) { + printf("%p, %s\n", loc, loc); + // TODO This should really be called only once for an appearance in the code, + // currently it is called every time a flop uses a constant. + return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); +} + +__ENZYME_MPFR_ATTRIBUTES +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, + const char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); + mpfr_init2(a->result, significand); + return a; +} + +__ENZYME_MPFR_ATTRIBUTES +void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + printf("%p, %s\n", loc, loc); + free(__enzyme_fprt_double_to_ptr(a)); +} + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +// TODO this is a bit sketchy if the user cast their float to int before calling +// this. We need to detect these patterns +#define __ENZYME_MPFR_BIN_INT(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, \ + FROM_TYPE, RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ARG2, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, b, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, b, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +#define __ENZYME_MPFR_BIN(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG2(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, mb->result, \ + ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ + LLVM_TYPE, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ + int64_t mode, const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb, mc, mmul, madd; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_init2(mmul, significand); \ + mpfr_init2(madd, significand); \ + mpfr_set_##MPFR_TYPE(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_TYPE(mb, b, ROUNDING_MODE); \ + mpfr_set_##MPFR_TYPE(mc, c, ROUNDING_MODE); \ + mpfr_mul(mmul, ma, mb, ROUNDING_MODE); \ + mpfr_add(madd, mmul, mc, ROUNDING_MODE); \ + TYPE res = mpfr_get_##MPFR_TYPE(madd, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + mpfr_clear(mmul); \ + mpfr_clear(madd); \ + return res; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + __enzyme_fp *mc = __enzyme_fprt_double_to_ptr(c); \ + double mmul = __enzyme_fprt_##FROM_TYPE##_binop_fmul( \ + __enzyme_fprt_ptr_to_double(ma), __enzyme_fprt_ptr_to_double(mb), \ + exponent, significand, mode, loc); \ + double madd = __enzyme_fprt_##FROM_TYPE##_binop_fadd( \ + mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode, \ + loc); \ + return madd; \ + } else { \ + abort(); \ + } \ + } + +// TODO This does not currently make distinctions between ordered/unordered. +#define __ENZYME_MPFR_FCMP_IMPL(NAME, ORDERED, CMP, FROM_TYPE, TYPE, MPFR_GET, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, "fcmp" #NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_set_##MPFR_GET(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_GET(mb, b, ROUNDING_MODE); \ + int ret = mpfr_cmp(ma, mb); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + return ret CMP; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + int ret = mpfr_cmp(ma->result, mb->result); \ + return ret CMP; \ + } else { \ + abort(); \ + } \ + } + +__ENZYME_MPFR_ORIGINAL_ATTRIBUTES +bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, + int32_t tests); +__ENZYME_MPFR_ATTRIBUTES bool __enzyme_fprt_64_52_intr_llvm_is_fpclass_f64( + double a, int32_t tests, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64( + __enzyme_fprt_64_52_get(a, exponent, significand, mode, loc), tests); +} + +#include "flops.def" + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ diff --git a/enzyme/include/enzyme/fprt/mpfr.h b/enzyme/include/enzyme/fprt/mpfr.h index a75cfbd84f15..58783f86242b 100644 --- a/enzyme/include/enzyme/fprt/mpfr.h +++ b/enzyme/include/enzyme/fprt/mpfr.h @@ -28,22 +28,12 @@ #include #include +#include "fprt.h" + #ifdef __cplusplus extern "C" { #endif -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// I dont think we intercept comparisons - we most definitely should. -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO - // TODO s // // (for MPFR ver. 2.1) @@ -73,66 +63,55 @@ extern "C" { // simulation: // [...] subnormal numbers are not implemented. // -// TODO maybe take debug info as parameter - then we can emit warnings or tie -// operations to source location -// // TODO we need to provide f32 versions, and also instrument the // truncation/expansion between f32/f64/etc -#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) -#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) __attribute__((used)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) __attribute__((used)) #define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN -static bool __enzyme_fprt_is_mem_mode(int64_t mode) { return mode & 0b0001; } -static bool __enzyme_fprt_is_op_mode(int64_t mode) { return mode & 0b0010; } - -typedef struct { - mpfr_t v; +typedef struct __enzyme_fp { + mpfr_t result; } __enzyme_fp; -static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { - return *((double *)(&p)); -} -static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { - return *((__enzyme_fp **)(&d)); -} - __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); - return mpfr_get_d(a->v, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + return mpfr_get_d(a->result, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); } __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); - mpfr_init2(a->v, significand); - mpfr_set_d(a->v, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + mpfr_init2(a->result, significand); + mpfr_set_d(a->result, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); return __enzyme_fprt_ptr_to_double(a); } __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_const(double _a, int64_t exponent, - int64_t significand, int64_t mode) { + int64_t significand, int64_t mode, + const char *loc) { // TODO This should really be called only once for an appearance in the code, // currently it is called every time a flop uses a constant. - return __enzyme_fprt_64_52_new(_a, exponent, significand, mode); + return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); } __ENZYME_MPFR_ATTRIBUTES __enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, + const char *loc) { __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); - mpfr_init2(a->v, significand); + mpfr_init2(a->result, significand); return a; } __ENZYME_MPFR_ATTRIBUTES void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { free(__enzyme_fprt_double_to_ptr(a)); } @@ -141,7 +120,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -154,9 +134,9 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, return c; \ } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ - mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, ROUNDING_MODE); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ abort(); \ @@ -170,7 +150,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -183,9 +164,9 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, return c; \ } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ - mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, b, ROUNDING_MODE); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, b, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ abort(); \ @@ -197,7 +178,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, MPFR_SET_ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc; \ mpfr_init2(ma, significand); \ @@ -214,9 +196,10 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ - mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, mb->v, ROUNDING_MODE); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, mb->result, \ + ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ abort(); \ @@ -228,7 +211,7 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ATTRIBUTES \ TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ - int64_t mode) { \ + int64_t mode, const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc, mmul, madd; \ mpfr_init2(ma, significand); \ @@ -254,9 +237,10 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __enzyme_fp *mc = __enzyme_fprt_double_to_ptr(c); \ double mmul = __enzyme_fprt_##FROM_TYPE##_binop_fmul( \ __enzyme_fprt_ptr_to_double(ma), __enzyme_fprt_ptr_to_double(mb), \ - exponent, significand, mode); \ + exponent, significand, mode, loc); \ double madd = __enzyme_fprt_##FROM_TYPE##_binop_fadd( \ - mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode); \ + mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode, \ + loc); \ return madd; \ } else { \ abort(); \ @@ -268,7 +252,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ - TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode) { \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb; \ mpfr_init2(ma, significand); \ @@ -282,7 +267,7 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ - int ret = mpfr_cmp(ma->v, mb->v); \ + int ret = mpfr_cmp(ma->result, mb->result); \ return ret CMP; \ } else { \ abort(); \ @@ -292,9 +277,11 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ORIGINAL_ATTRIBUTES bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests); -__ENZYME_MPFR_ATTRIBUTES bool -__enzyme_fprt_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests) { - return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(a, tests); +__ENZYME_MPFR_ATTRIBUTES bool __enzyme_fprt_64_52_intr_llvm_is_fpclass_f64( + double a, int32_t tests, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64( + __enzyme_fprt_64_52_get(a, exponent, significand, mode, loc), tests); } #include "flops.def" diff --git a/enzyme/test/Enzyme/ForwardMode/hypot.ll b/enzyme/test/Enzyme/ForwardMode/hypot.ll index 564bab7cd72d..623f4094f077 100644 --- a/enzyme/test/Enzyme/ForwardMode/hypot.ll +++ b/enzyme/test/Enzyme/ForwardMode/hypot.ll @@ -8,13 +8,26 @@ entry: ret double %call } +define double @tester2(double %x, double %y) { +entry: + %call = tail call double @__hypot_finite(double %x, double %y) + ret double %call +} + define double @test_derivative(double %x, double %y) { entry: %0 = tail call double (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.000000e+00, double %y, double 1.000000e+00) ret double %0 } +define double @test_derivative2(double %x, double %y) { +entry: + %0 = tail call double (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester2, double %x, double 1.000000e+00, double %y, double 1.000000e+00) + ret double %0 +} + declare double @hypot(double, double) +declare double @__hypot_finite(double, double) ; Function Attrs: nounwind declare double @__enzyme_fwddiff(...) diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index d33c40d7de11..15140bdb5f75 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -29,7 +29,7 @@ entry: } ; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { -; CHECK-NEXT: %res = call i1 @__enzyme_fprt_64_52_fcmp_olt(double %x, double %y, i64 8, i64 23, i64 1) +; CHECK-NEXT: %res = call i1 @__enzyme_fprt_64_52_fcmp_olt(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret i1 %res ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/const.ll b/enzyme/test/Enzyme/Truncate/const.ll index 25c5c5ee4c3b..b90b20615a93 100644 --- a/enzyme/test/Enzyme/Truncate/const.ll +++ b/enzyme/test/Enzyme/Truncate/const.ll @@ -23,12 +23,12 @@ entry: } ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x) { -; CHECK-NEXT: %1 = call double @__enzyme_fprt_64_52_const(double 1.000000e+00, i64 8, i64 23, i64 1) -; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double %1, i64 8, i64 23, i64 1) +; CHECK-NEXT: %1 = call double @__enzyme_fprt_64_52_const(double 1.000000e+00, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double %1, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %res ; CHECK-NEXT: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x) { -; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double 1.000000e+00, i64 3, i64 7, i64 2) +; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double 1.000000e+00, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %res ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 3e5fe36b3784..a5899e75c68c 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -42,28 +42,28 @@ entry: } ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 1) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 1) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 1) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 1) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 2) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 2) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 2) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 2) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 3, i64 7, i64 2) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7, i64 2) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7, i64 2) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 3, i64 7, i64 2) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 747e268ae381..cd94c87aba46 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -36,21 +36,21 @@ entry: ; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 1) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 2) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 3, i64 7, i64 2) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll index fa79e93440bb..1722b4fc1efc 100644 --- a/enzyme/test/Enzyme/Truncate/value.ll +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -18,10 +18,10 @@ entry: ; CHECK: define double @expand_tester(double %a, double* %c) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_get(double %a, i64 8, i64 23, i64 1) +; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_get(double %a, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %0 ; CHECK: define double @truncate_tester(double %a) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_new(double %a, i64 8, i64 23, i64 1) +; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_new(double %a, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %0 diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 635a2e3bc04c..dff19c9a1e45 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -5,6 +5,9 @@ // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -lm -lmpfr && %s.a.out ; fi // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -lm -lmpfr && %s.a.out ; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr-test.h -lm -lmpfr && %s.a.out ; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr-test.h -lm -lmpfr && %s.a.out ; fi + #include #include "../test_utils.h" @@ -27,21 +30,34 @@ double intrinsics(double a, double b) { double constt(double a, double b) { return 2; } +void const_store(double *a) { + *a = 2.0; +} +double phinode(double a, double b, int n) { + double sum = 0; + for (int i = 0; i < n; i++) { + sum += (exp(a + b) - exp(a)) / b; + b /= 10; + } + return sum; +} double compute(double *A, double *B, double *C, int n) { for (int i = 0; i < n; i++) { C[i] = A[i] * 2 + B[i] * sqrt(A[i]); } return C[0]; } +double intcast(int a) { + double d = (double) a; + return d / 3.14; +} typedef double (*fty)(double *, double *, double *, int); typedef double (*fty2)(double, double); -extern fty __enzyme_truncate_mem_func_2(...); -extern fty2 __enzyme_truncate_mem_func(...); -extern fty __enzyme_truncate_op_func_2(...); -extern fty2 __enzyme_truncate_op_func(...); +template fty *__enzyme_truncate_mem_func(fty *, int, int); +template fty *__enzyme_truncate_op_func(fty *, int, int); extern double __enzyme_truncate_mem_value(...); extern double __enzyme_expand_mem_value(...); @@ -89,16 +105,36 @@ int main() { double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } + { + double a = 2; + double b = 3; + double truth = constt(a, b); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + { + double a = 2; + double b = 3; + double truth = phinode(a, b, 10); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(phinode, FROM, TO)(a, b, 10), FROM, TO); + APPROX_EQ(trunc, truth, 20.0); + } + { + double truth = 0; + const_store(&truth); + double a = 0; + __enzyme_truncate_mem_func(const_store, FROM, TO)(&a); + a = __enzyme_expand_mem_value(a, FROM, TO); + APPROX_EQ(a, truth, 1e-5); + } + { + __enzyme_truncate_mem_func(intcast, FROM, TO)(64); + } #endif - // { - // double a = 2; - // double b = 3; - // double truth = intrinsics(a, b); - // a = __enzyme_truncate_mem_value(a, FROM, TO); - // b = __enzyme_truncate_mem_value(b, FROM, TO); - // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); - // APPROX_EQ(trunc, truth, 1e-5); - // } #ifdef TRUNC_OP { @@ -120,7 +156,7 @@ int main() { // B[i] = __enzyme_truncate_mem_value(B[i], 64, 32); // } - __enzyme_truncate_op_func_2(compute, 64, 32)(A, B, C, N); + __enzyme_truncate_op_func(compute, 64, 32)(A, B, C, N); // for (int i = 0; i < N; i++) { // C[i] = __enzyme_expand_mem_value(C[i], 64, 32); diff --git a/enzyme/test/Integration/Truncate/truncate-all-header.h b/enzyme/test/Integration/Truncate/truncate-all-header.h new file mode 100644 index 000000000000..3fd9f0780365 --- /dev/null +++ b/enzyme/test/Integration/Truncate/truncate-all-header.h @@ -0,0 +1,15 @@ +#ifndef TRUNCATE_ALL_HEADER_H_ +#define TRUNCATE_ALL_HEADER_H_ + +#include + +#define N 6 + +#define floatty double + +__attribute__((noinline)) static +floatty intrinsics2(floatty a, floatty b) { + return sin(a) * cos(b); +} + +#endif // TRUNCATE_ALL_HEADER_H_ diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index d5038d4750cb..818c2c603cac 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -16,16 +16,27 @@ // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -include enzyme/fprt/mpfr.h -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr -lm && %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi // TO_3_7: 897581056.000000 -#include - +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -include enzyme/fprt/mpfr-test.h -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr -lm && %s.a.out | FileCheck --check-prefix CHECK-LOCS %s; fi +// CHECK-LOCS: 0x[[op1:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op1loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op2:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op2loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op3:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op3loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op4:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op4loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op5:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op5loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op6:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op6loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op7:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op7loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op1]], {{.*}}truncate-all.cpp:[[op1loc]] +// CHECK-LOCS-NEXT: 0x[[op2]], {{.*}}truncate-all.cpp:[[op2loc]] +// CHECK-LOCS-NEXT: 0x[[op3]], {{.*}}truncate-all.cpp:[[op3loc]] +// CHECK-LOCS-NEXT: 0x[[op4]], {{.*}}truncate-all-header.h:[[op4loc]] +// CHECK-LOCS-NEXT: 0x[[op5]], {{.*}}truncate-all-header.h:[[op5loc]] +// CHECK-LOCS-NEXT: 0x[[op6]], {{.*}}truncate-all-header.h:[[op6loc]] +// CHECK-LOCS-NEXT: 0x[[op7]], {{.*}}truncate-all.cpp:[[op7loc]] + + +#include "truncate-all-header.h" #include "../test_utils.h" -#define N 10 - -#define floatty double - - __attribute__((noinline)) floatty simple_add(floatty a, floatty b) { return a + b; @@ -35,6 +46,13 @@ floatty intrinsics(floatty a, floatty b) { return sqrt(a) * pow(b, 2); } __attribute__((noinline)) +floatty compute2(floatty *A, floatty *B, floatty *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] / 2 + intrinsics2(A[i], simple_add(B[i] * 10000, 0.000001)); + } + return C[0]; +} +__attribute__((noinline)) floatty compute(floatty *A, floatty *B, floatty *C, int n) { for (int i = 0; i < n; i++) { C[i] = A[i] / 2 + intrinsics(A[i], simple_add(B[i] * 10000, 0.000001)); @@ -52,6 +70,9 @@ int main() { B[i] = 1 + i % 3; } + compute2(A, B, C, N); + for (int i = 0; i < N; i++) + C[i] = 0; compute(A, B, C, N); printf("%f\n", C[5]); } diff --git a/enzyme/test/Integration/Truncate/warnings.cpp b/enzyme/test/Integration/Truncate/warnings.cpp new file mode 100644 index 000000000000..b63636440d3d --- /dev/null +++ b/enzyme/test/Integration/Truncate/warnings.cpp @@ -0,0 +1,62 @@ +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi + +#include +#include + +#define FROM 64 +#define TO 32 + +double bithack(double a) { + return *((int64_t *)&a) + 1; // expected-remark {{Will not follow FP through this cast.}}, expected-remark {{Will not follow FP through this cast.}} +} +__attribute__((noinline)) +void print_d(double a) { + printf("%f\n", a); // expected-remark {{Will not follow FP through this function call as the definition is not available.}} +} +__attribute__((noinline)) +float truncf(double a) { + return (float)a; // expected-remark {{Will not follow FP through this cast.}} +} + +double intrinsics(double a, double b) { + return bithack(a) * truncf(b); // expected-remark {{Will not follow FP through this cast.}} +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +template fty *__enzyme_truncate_mem_func(fty *, int, int); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); + + +int main() { + #ifdef TRUNC_MEM + { + double a = 2; + double b = 3; + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); + } + { + double a = 2; + a = __enzyme_truncate_mem_value(a, FROM, TO); + __enzyme_truncate_mem_func(print_d, FROM, TO)(a); + } + #endif + #ifdef TRUNC_OP + { + double a = 2; + double b = 3; + double trunc = __enzyme_truncate_op_func(intrinsics, FROM, TO)(a, b); + } + #endif + +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 2e5ee993f988..26619edcdcd4 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1787,10 +1787,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, bool prev = false; for (auto *nameI : *cast(pattern->getValueAsListInit("names"))) { - if (prev) - os << " ||\n "; - os << "funcName == " << cast(nameI)->getAsString() << ""; - prev = true; + auto nameIStr = cast(nameI)->getAsString(); + auto nameIStrFinite = "\"__" + + std::string(std::next(nameIStr.begin()), + std::prev(nameIStr.end())) + + "_finite\""; + for (auto nameIStrAll : {nameIStr, nameIStrFinite}) { + if (prev) + os << " ||\n "; + os << "funcName == " << nameIStrAll << ""; + prev = true; + } } origName = "call"; #if LLVM_VERSION_MAJOR >= 14