Skip to content

Commit c298f70

Browse files
wsmosesUbuntuvchuravy
authored
Fix shadow augmented rematerialization (rust-lang#557)
* Shadow promotable remat * Preserve forward shadow if non-float Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Valentin Churavy <[email protected]>
1 parent 798c15d commit c298f70

File tree

3 files changed

+104
-7
lines changed

3 files changed

+104
-7
lines changed

enzyme/Enzyme/AdjointGenerator.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -10443,10 +10443,12 @@ class AdjointGenerator
1044310443
#endif
1044410444
tapeIdx = 0;
1044510445

10446-
if (subretType == DIFFE_TYPE::DUP_ARG ||
10447-
subretType == DIFFE_TYPE::DUP_NONEED) {
10446+
if (!orig->getType()->isVoidTy()) {
1044810447
returnIdx = 1;
10449-
differetIdx = 2;
10448+
if (subretType == DIFFE_TYPE::DUP_ARG ||
10449+
subretType == DIFFE_TYPE::DUP_NONEED) {
10450+
differetIdx = 2;
10451+
}
1045010452
}
1045110453

1045210454
} else {

enzyme/Enzyme/GradientUtils.h

+23-4
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,28 @@ class GradientUtils : public CacheUtility {
611611
todo.push_back(std::make_pair(I, (Value *)cur));
612612
}
613613
} else if (auto load = dyn_cast<LoadInst>(cur)) {
614+
615+
// If loaded value is an int or pointer, may need
616+
// to preserve initialization within the primal.
617+
auto TT = TR.query(load)[{-1}];
618+
if (!TT.isFloat()) {
619+
// ok to duplicate in forward /
620+
// reverse if it is a stack or GC allocation.
621+
// Said memory will still be shadow initialized.
622+
StringRef funcName = "";
623+
if (auto CI = dyn_cast<CallInst>(V))
624+
if (Function *originCall = getFunctionFromCall(CI))
625+
funcName = originCall->getName();
626+
if (isa<AllocaInst>(V) || hasMetadata(V, "enzyme_fromstack") ||
627+
funcName == "jl_alloc_array_1d" ||
628+
funcName == "jl_alloc_array_2d" ||
629+
funcName == "jl_alloc_array_3d" || funcName == "jl_array_copy" ||
630+
funcName == "julia.gc_alloc_obj") {
631+
primalInitializationOfShadow = true;
632+
} else {
633+
shadowpromotable = false;
634+
}
635+
}
614636
loads.insert(load);
615637
} else if (auto store = dyn_cast<StoreInst>(cur)) {
616638
// TODO only add store to shadow iff non float type
@@ -770,10 +792,7 @@ class GradientUtils : public CacheUtility {
770792
outer = getAncestor(outer, OrigLI.getLoopFor(S->getParent()));
771793
}
772794

773-
if (!shadowpromotable)
774-
return;
775-
776-
if (!isConstantValue(V)) {
795+
if (shadowpromotable && !isConstantValue(V)) {
777796
backwardsOnlyShadows[V] = ShadowRematerializer(
778797
stores, frees, primalInitializationOfShadow, outer);
779798
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s
2+
3+
; ModuleID = 'ode-unopt.ll'
4+
source_filename = "ode.cpp"
5+
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-unknown-linux-gnu"
7+
8+
define double @inner(double* %data) {
9+
entry:
10+
%d = load double, double* %data, align 8
11+
%r = fadd double %d, %d
12+
ret double %r
13+
}
14+
15+
define double @inner2(double* %data) {
16+
entry:
17+
%d = load double, double* %data, align 8
18+
%r = fsub double %d, %d
19+
ret double %r
20+
}
21+
22+
define double @sub(double* %in, i64 %v) {
23+
entry:
24+
%a = alloca double (double*)*, i64 2, align 8
25+
store double (double*)* @inner, double (double*)** %a, align 8
26+
%a1 = getelementptr double (double*)*, double (double*)** %a, i64 1
27+
store double (double*)* @inner2, double (double*)** %a1, align 8
28+
%fa = getelementptr double (double*)*, double (double*)** %a, i64 %v
29+
%f = load double (double*)*, double (double*)** %fa, align 8
30+
%r = call double %f(double* %in)
31+
ret double %r
32+
}
33+
34+
define void @outer(double* %in, i64 %v) {
35+
entry:
36+
%r = call double @sub(double* %in, i64 %v)
37+
store double %r, double* %in, align 8
38+
ret void
39+
}
40+
41+
define void @caller(double* %in, double* %d_in) {
42+
entry:
43+
call void (...) @__enzyme_autodiff(void (double*, i64)* nonnull @outer, double* %in, double* %d_in, i64 0)
44+
ret void
45+
}
46+
47+
declare void @__enzyme_autodiff(...)
48+
49+
; CHECK: define internal { i8*, double } @augmented_sub(double* %in, double* %"in'", i64 %v)
50+
; CHECK-NEXT: entry:
51+
; CHECK-NEXT: %0 = alloca { i8*, double }
52+
; CHECK-NEXT: %1 = getelementptr inbounds { i8*, double }, { i8*, double }* %0, i32 0, i32 0
53+
; CHECK-NEXT: %2 = alloca i8, i64 16, align 8
54+
; CHECK-NEXT: %"malloccall'mi" = alloca i8, i64 16, align 8
55+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(16) dereferenceable_or_null(16) %"malloccall'mi", i8 0, i64 16, i1 false)
56+
; CHECK-NEXT: %"a'ipc" = bitcast i8* %"malloccall'mi" to double (double*)*
57+
; CHECK-NEXT: %a = bitcast i8* %2 to double (double*)**
58+
; CHECK-NEXT: store double (double*)* bitcast ({ { i8*, double } (double*, double*)*, void (double*, double*, double, i8*)* }* @"_enzyme_reverse_inner'" to double (double*)*), double (double*)** %"a'ipc", align 8
59+
; CHECK-NEXT: store double (double*)* @inner, double (double*)** %a, align 8
60+
; CHECK-NEXT: %"a1'ipg" = getelementptr double (double*)*, double (double*)** %"a'ipc", i64 1
61+
; CHECK-NEXT: %a1 = getelementptr double (double*)*, double (double*)** %a, i64 1
62+
; CHECK-NEXT: store double (double*)* bitcast ({ { i8*, double } (double*, double*)*, void (double*, double*, double, i8*)* }* @"_enzyme_reverse_inner2'" to double (double*)*), double (double*)** %"a1'ipg", align 8
63+
; CHECK-NEXT: store double (double*)* @inner2, double (double*)** %a1, align 8
64+
; CHECK-NEXT: %"fa'ipg" = getelementptr double (double*)*, double (double*)** %"a'ipc", i64 %v
65+
; CHECK-NEXT: %"f'ipl" = load double (double*)*, double (double*)** %"fa'ipg", align 8
66+
; CHECK-NEXT: %3 = bitcast double (double*)* %"f'ipl" to { i8*, double } (double*, double*)**
67+
; CHECK-NEXT: %4 = load { i8*, double } (double*, double*)*, { i8*, double } (double*, double*)** %3
68+
; CHECK-NEXT: %r_augmented = call { i8*, double } %4(double* %in, double* %"in'")
69+
; CHECK-NEXT: %subcache = extractvalue { i8*, double } %r_augmented, 0
70+
; CHECK-NEXT: store i8* %subcache, i8** %1
71+
; CHECK-NEXT: %r = extractvalue { i8*, double } %r_augmented, 1
72+
; CHECK-NEXT: %5 = getelementptr inbounds { i8*, double }, { i8*, double }* %0, i32 0, i32 1
73+
; CHECK-NEXT: store double %r, double* %5
74+
; CHECK-NEXT: %6 = load { i8*, double }, { i8*, double }* %0
75+
; CHECK-NEXT: ret { i8*, double } %6
76+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)