123 changes: 117 additions & 6 deletions src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- CHLOAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
+//===- EnzymeXLAAutoDiffOpInterfaceImpl.cpp - Interface external model ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,16 +7,21 @@
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
-// differentiation op interfaces for the upstream MLIR arithmetic dialect.
+// differentiation op interfaces for the EnzymeXLA dialect.
//
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"
#include "src/enzyme_ad/jax/Implementations/SHLOGenericBatchOpInterface.h"

#include "Dialect/Ops.h"
@@ -69,12 +74,92 @@ struct GPUWrapperOpEnzymeOpsRemover
if (gradients.empty() && pushedCaches.empty())
return success();

-if (gradients.size())
-  return failure();
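// Collect every enzyme.push in the wrapper body; pushes of the same SSA
// value are merged into a single CacheInfo so the value is only cached once.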
llvm::MapVector<Value, CacheInfo> cachesMap;
for (auto &it : *wrapOp.getBody()) {
Operation *op = &it;
if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
CacheInfo info(pushOp.getCache());
if (cachesMap.contains(pushOp.getValue()))
info = info.merge(cachesMap.lookup(pushOp.getValue()), rewriter);
cachesMap[pushOp.getValue()] = info;
}
}
SmallVector<CacheInfo> caches =
llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });

if (caches.empty())
return success();

-if (pushedCaches.size())
-  return failure();
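// Seed the visited set with values defined above the wrapper region: they
// stay visible once the pushes are hoisted, so the backward walk stops there.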
SetVector<Value> visited;
getUsedValuesDefinedAbove(wrapOp.getBodyRegion(), visited);
SmallVector<Value> frontier = llvm::map_to_vector(
caches, [](CacheInfo info) { return info.pushedValue(); });
SetVector<Operation *> opsToMove;
// Traverse backward from pushed values to find operations that the pushed
// value depends on
while (!frontier.empty()) {
Value v = frontier.back();
Operation *definingOp = v.getDefiningOp();
frontier.pop_back();

if (!definingOp)
continue;

// Assume allocations and frees are legal to move
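// Anything that reads or writes memory is rejected, since hoisting it out
// of the kernel body could reorder it with other effects.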
if (hasEffect<MemoryEffects::Read>(definingOp) ||
hasEffect<MemoryEffects::Write>(definingOp)) {
definingOp->emitError() << "cannot move op with side effects";
return failure();
}
opsToMove.insert(definingOp);

for (Value operand : definingOp->getOperands()) {
if (visited.contains(operand))
continue;

frontier.push_back(operand);
visited.insert(operand);
}
}

// Move the push and dependent values outside of the wrapper
OpBuilder::InsertionGuard guard(rewriter);
IRMapping map;
rewriter.setInsertionPoint(wrapOp);
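// opsToMove was filled use-before-def by the backward walk, so clone in
// reverse to materialize each producer before its consumers.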
for (Operation *toMove : llvm::reverse(opsToMove)) {
Operation *cloned = rewriter.clone(*toMove, map);
toMove->replaceAllUsesWith(cloned->getResults());

if (auto allocOp = dyn_cast<memref::AllocOp>(cloned)) {
// Assume GPU allocations need to be in address space 1
auto gpuAlloc = gpu::AllocOp::create(
rewriter, allocOp.getLoc(),
*allocOp.getType().clonePtrWith(rewriter.getI64IntegerAttr(1),
std::nullopt),
/*asyncDependencies=*/ValueRange(), allocOp.getDynamicSizes(),
/*symbolOperands=*/ValueRange());
allocOp.replaceAllUsesWith(gpuAlloc.getResult(0));
rewriter.eraseOp(allocOp);
}
}

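// Hoist each push above the forward wrapper and its matching pop above the
// reverse wrapper. The cached buffer is now a device allocation, so stale
// memref.dealloc users are erased and a gpu.dealloc is emitted instead.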
for (auto &info : caches) {
rewriter.moveOpBefore(info.pushOp, wrapOp);
auto revWrapper = info.popOp->getParentOfType<enzymexla::GPUWrapperOp>();
assert(revWrapper && "failed to find reverse gpu_wrapper");
rewriter.moveOpBefore(info.popOp, revWrapper);

for (auto user : info.popOp.getResult().getUsers()) {
if (isa<memref::DeallocOp>(user)) {
rewriter.eraseOp(user);
}
}
rewriter.setInsertionPointAfter(revWrapper);
gpu::DeallocOp::create(rewriter, wrapOp.getLoc(), TypeRange(),
info.popOp.getResult());
}

return success();
// TODO need to convert to gpu allocations and conversion/copy

/*
@@ -193,6 +278,31 @@ struct GPUWrapperOpInterfaceReverse
MGradientUtilsReverse *gutils) const {}
};

class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
Pointer2MemrefRev, enzymexla::Pointer2MemrefOp> {
public:
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
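// pointer2memref is a pure view change, so no adjoint arithmetic is
// needed; gradients flow through the shadow value created below.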
return success();
}

SmallVector<Value> cacheValues(Operation *orig,
MGradientUtilsReverse *gutils) const {
return SmallVector<Value>();
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {
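// The shadow of a pointer2memref is the same view applied to the shadow
// pointer, so forward and shadow memrefs track their respective buffers.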
auto p2m = cast<enzymexla::Pointer2MemrefOp>(op);
if (!gutils->isConstantValue(p2m)) {
Value dres = gutils->invertPointerM(p2m.getSource(), builder);
Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
p2m.getLoc(), p2m.getType(), dres);
gutils->setInvertedPointer(p2m, shadow);
}
}
};
} // namespace

void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
@@ -201,6 +311,7 @@ void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
registerInterfaces(context);
GPUWrapperOp::attachInterface<GPUWrapperOpInterfaceReverse>(*context);
GPUWrapperOp::attachInterface<GPUWrapperOpEnzymeOpsRemover>(*context);
enzymexla::Pointer2MemrefOp::attachInterface<Pointer2MemrefRev>(*context);

// Register batching interfaces
JITCallOp::attachInterface<SHLOGenericBatchOpInterface<JITCallOp>>(
9 changes: 5 additions & 4 deletions src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
@@ -2717,20 +2717,23 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
auto i32 = rewriter.getIntegerType(32);
auto moduleOp = deallocOp->getParentOfType<ModuleOp>();

-auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());

if (backend == "cuda") {
auto one = LLVM::ConstantOp::create(rewriter, loc, i64,
rewriter.getI64IntegerAttr(1));

-Type tys[] = {ptr1ty};
+Type tys[] = {ptrty};
auto cudaFreeFn =
LLVM::lookupOrCreateFn(rewriter, moduleOp, "cudaFree", tys, i32);
if (failed(cudaFreeFn)) {
llvm::errs() << " cudafree already exists with different types\n";
return failure();
}

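// cudaFree is declared to take a generic (address space 0) pointer, so
// cast away any non-default address space before the call.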
+if (cast<LLVM::LLVMPointerType>(ptr.getType()).getAddressSpace() != 0)
+  ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, ptrty, ptr);

Value args[] = {
ptr,
};
@@ -2750,8 +2753,6 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
};
LLVM::CallOp::create(rewriter, loc, freeFunc.value(), args);
} else if (backend.starts_with("xla")) {
-auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());

// handle, ptr
Type tys[] = {ptrty, ptrty};
