Skip to content

Commit ddea446

Browse files
Layered ForwardDiff Update
Should fix #141
1 parent 9767b29 commit ddea446

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/solve/flux.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,22 +71,24 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
7171
# here should be build_solution to create the output message
7272
end
7373

74-
function Flux.update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
75-
x .-=
74+
function Flux.update!(opt, xs::Flux.Zygote.Params, gs)
75+
update!(opt, xs[1], gs)
7676
end
7777

78-
function Flux.update!(x::AbstractArray, x̄)
79-
x .-= getindex.(ForwardDiff.partials.(x̄),1)
80-
end
78+
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
79+
function Flux.update!(x::AbstractArray, x̄::AbstractArray{<:ForwardDiff.Dual})
80+
x .-=
81+
end
8182

82-
function Flux.update!(opt, x, x̄)
83-
x .-= Flux.Optimise.apply!(opt, x, x̄)
84-
end
83+
function Flux.update!(x::AbstractArray, x̄)
84+
x .-= getindex.(ForwardDiff.partials.(x̄),1)
85+
end
8586

86-
function Flux.update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
87-
x .-= Flux.Optimise.apply!(opt, x, getindex.(ForwardDiff.partials.(x̄),1))
88-
end
87+
function Flux.update!(opt, x, x̄)
88+
x .-= Flux.Optimise.apply!(opt, x, )
89+
end
8990

90-
function Flux.update!(opt, xs::Flux.Zygote.Params, gs)
91-
update!(opt, xs[1], gs)
91+
function Flux.update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
92+
x .-= Flux.Optimise.apply!(opt, x, getindex.(ForwardDiff.partials.(x̄),1))
93+
end
9294
end

0 commit comments

Comments
 (0)