Skip to content

Commit 124bbb6

Browse files
committed
better fix
1 parent 88c7527 commit 124bbb6

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/norecompile.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ function wrapfun_iip(ff,
5858
dualT1 = ArrayInterface.promote_eltype(T1, dualT)
5959
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
6060
dualT4 = dualgen(T4)
61+
dualT4_T = promote_dual(dualT4, dualT)
6162

62-
iip_arglists = (Tuple{T1, T2, T3, T4}, # primal
63-
Tuple{dualT1, dualT2, T3, T4}, # vjp
64-
Tuple{dualT1, T2, T3, dualT4}, # tgrad
63+
iip_arglists = (Tuple{T1, T2, T3, T4}, # primal
64+
Tuple{dualT1, dualT2, T3, T4}, # vjp
65+
Tuple{dualT1, T2, T3, dualT4}, # tgrad
66+
Tuple{dualT1, T2, T3, dualT4_T}, # tgrad inside gradient wrt initial conditions
6567
)
6668

6769
iip_returnlists = ntuple(x -> Nothing, length(iip_arglists))

0 commit comments

Comments
 (0)