Replacing the EnsembleProblem with an equivalent simple for loop breaks the training #54
Here is a MWE in the context of LatentDiffEq.jl. Later I'll create a more general (minimal) MWE.

```julia
using LatentDiffEq
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Random
import LatentDiffEq.Decoder
import LatentDiffEq.diffeq_layer
Random.seed!(3)
struct Pendulum{P,S,T,K}
    prob::P
    solver::S
    sensealg::T
    kwargs::K

    function Pendulum(; kwargs...)
        # Parameters and initial conditions only
        # used to initialize the ODE problem
        u₀ = Float32[1.0, 1.0]
        p = Float32[1.]
        tspan = (0.f0, 1.f0)

        # Define differential equations
        function f!(du, u, p, t)
            x, y = u
            G = 10.0f0
            L = p[1]
            du[1] = y
            du[2] = -G/L*sin(x)
        end

        # Build ODE problem
        prob = ODEProblem(f!, u₀, tspan, p)

        # Choose a solver and sensitivity algorithm
        solver = Tsit5()
        sensealg = ForwardDiffSensitivity()
        # sensealg = BacksolveAdjoint()
        # sensealg = InterpolatingAdjoint()
        # sensealg = QuadratureAdjoint()

        P = typeof(prob)
        S = typeof(solver)
        T = typeof(sensealg)
        K = typeof(kwargs)
        new{P,S,T,K}(prob, solver, sensealg, kwargs)
    end
end
model_type = GOKU()
diffeq = Pendulum()
input_dim = 784
encoder_layers, decoder_layers = default_layers(model_type, input_dim, diffeq)
model = LatentDiffEqModel(model_type, encoder_layers, decoder_layers)
ps = Flux.params(model)
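# Fake latent variables for a batch of 64: ẑ₀ (2×64 initial states) and θ̂ (1×64 ODE parameters)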
l̂ = (rand(Float32, 2, 64), rand(Float32, 1, 64))
t = range(0.f0, step=0.05, length=50)
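# Dummy scalar loss over the whole solution array, used only to compare forward values and gradients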
loss(model, l̂, t) = sum(diffeq_layer(model.decoder, l̂, t) .- 1)
function evaluate(model, l̂, t)
    out = diffeq_layer(model.decoder, l̂, t)
    l_direct = loss(model, l̂, t)
    l, back = Flux.pullback(ps) do
        loss(model, l̂, t)
    end
    return out, l_direct, l, back
end
function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    # Function definition for ensemble problem
    prob_func(prob, i, repeat) = remake(prob, u0 = ẑ₀[:,i], p = θ̂[:,i])

    # Check if solve was successful; if not, return NaNs to avoid dimension mismatches
    output_func(sol, i) = sol.retcode == :Success ? (Array(sol), false) : (fill(NaN32, (size(ẑ₀, 1), length(t))), false)

    ## Adapt problem to given time span and create ensemble problem definition
    prob = remake(prob; tspan = (t[1], t[end]))
    ens_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func)

    ## Solve
    ẑ = solve(ens_prob, solver, EnsembleSerial(); sensealg = sensealg,
              trajectories = size(θ̂, 2), saveat = t, kwargs...)
    ẑ = permutedims(ẑ, [1, 3, 2])  # (state, time, trajectory) -> (state, trajectory, time)
    return ẑ
end
res1 = evaluate(model, l̂, t)
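# Now override diffeq_layer with an equivalent version that solves each
# trajectory in a simple for loop; redefining the method replaces the
# EnsembleProblem version above, so res2 uses the loop version.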
function diffeq_layer(decoder::Decoder{GOKU}, l̂, t)
    ẑ₀, θ̂ = l̂
    prob = decoder.diffeq.prob
    solver = decoder.diffeq.solver
    sensealg = decoder.diffeq.sensealg
    kwargs = decoder.diffeq.kwargs

    prob = remake(prob; tspan = (t[1], t[end]))

    sols = Array{Float32,2}[]
    for i in 1:size(ẑ₀, 2)
        prob = remake(prob, u0 = ẑ₀[:,i], p = θ̂[:,i])
        sol = solve(prob, solver; sensealg = sensealg, saveat = t, kwargs...)
        push!(sols, Array(sol))
    end
    ẑ = Flux.stack(sols, 2)  # (state, time) slices stacked into (state, trajectory, time)
    return ẑ
end
res2 = evaluate(model, l̂, t)
@show res1 .== res2
@show res1[2]
@show res1[3]
@show res2[2]
@show res2[3]
```

The output is either

```
res1 .== res2 = (true, true, true, false)
res1[2] = -6916.1006f0
res1[3] = -6916.1006f0
res2[2] = -6916.1006f0
res2[3] = -6916.1006f0
```

or, in other runs,

```
res1 .== res2 = (true, true, false, false)
res1[2] = -6916.1006f0
res1[3] = -6916.099838162956
res2[2] = -6916.1006f0
res2[3] = -6916.1006f0
```

So in both cases the pullbacks compare as different, even though the forward outputs and direct losses match.

Note that I'm using:

```
(issue#54) pkg> st
Status `~/Documents/GOKU_experiments/issue#54/Project.toml`
  [41bf760c] DiffEqSensitivity v6.58.0
  [587475ba] Flux v0.12.6
  [5e00f16f] LatentDiffEq v0.2.5 `https://github.com/gabrevaya/LatentDiffEq.jl.git#master`
  [1dea7af3] OrdinaryDiffEq v5.64.0
  [9a3f8284] Random
```
I reported the more general MWE in SciML/SciMLSensitivity.jl#611.
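For readers who don't want to depend on LatentDiffEq, a minimal self-contained sketch of the same comparison is below. This is my own reconstruction under assumptions, not the exact code from that issue: the pendulum dynamics, batch sizes, and loss are placeholders.

```julia
# Sketch (assumed, not the exact MWE from SciML/SciMLSensitivity.jl#611):
# compare gradients of the same batched solve done via EnsembleProblem
# versus a plain loop over trajectories.
using OrdinaryDiffEq, DiffEqSensitivity, Flux

# Pendulum dynamics, as in the MWE above
function pend!(du, u, p, t)
    du[1] = u[2]
    du[2] = -10.0f0 / p[1] * sin(u[1])
end

u0s = rand(Float32, 2, 4)            # initial conditions for 4 trajectories
θ0  = rand(Float32, 1, 4) .+ 0.5f0   # per-trajectory parameters, shifted away from 0
ts  = range(0.0f0, 1.0f0, length = 20)
prob = ODEProblem(pend!, u0s[:, 1], (ts[1], ts[end]), θ0[:, 1])

# Version 1: batch of solves via EnsembleProblem
function loss_ensemble(θ)
    prob_func(prob, i, repeat) = remake(prob, u0 = u0s[:, i], p = θ[:, i])
    ens = EnsembleProblem(prob, prob_func = prob_func)
    sol = solve(ens, Tsit5(), EnsembleSerial();
                sensealg = ForwardDiffSensitivity(),
                trajectories = size(θ, 2), saveat = ts)
    sum(Array(sol))
end

# Version 2: the "equivalent" simple loop over trajectories
function loss_loop(θ)
    sols = map(1:size(θ, 2)) do i
        Array(solve(remake(prob, u0 = u0s[:, i], p = θ[:, i]), Tsit5();
                    sensealg = ForwardDiffSensitivity(), saveat = ts))
    end
    sum(sum, sols)
end

g1 = Flux.gradient(loss_ensemble, θ0)[1]
g2 = Flux.gradient(loss_loop, θ0)[1]
@show maximum(abs.(g1 .- g2))  # a nonzero value here would reproduce the discrepancy
```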
If we change the differential equation solving inside the `diffeq_layer` function from the current `EnsembleProblem` to a simple for loop whose forward pass gives the exact same result, the whole training breaks. You can quickly try this by overriding the `diffeq_layer` function and running the pendulum example or tutorial; both the original (`EnsembleProblem`) version and the for-loop version of `diffeq_layer` are reproduced in the MWE above. The forward output is exactly the same; however, the training of the model malfunctions: after a few epochs, the loss function tends to greatly increase (>> 1e6 after 50 epochs in the default pendulum example).
I'll try to get a MWE and compare the gradients. This might be a hint of a sensitivity analysis, Zygote, or primitives issue, which could be the reason why the pendulum example in the Python implementation of GOKU nets converges faster and is more robust under different random seeds when using exactly the same architecture, hyperparameters, and initializations.
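In the meantime, here is a minimal sketch (not from the original report) of how the two pullbacks could be compared directly. It assumes `ps`, `res1`, and `res2` from the MWE above, whose fourth element is the pullback returned by `Flux.pullback`:

```julia
# Seed both pullbacks with 1 and compare the resulting gradients
# parameter by parameter; `ps`, `res1`, `res2` come from the MWE above.
back1, back2 = res1[4], res2[4]
grads1 = back1(1f0)   # Zygote.Grads for the EnsembleProblem version
grads2 = back2(1f0)   # Zygote.Grads for the for-loop version
for p in ps
    g1, g2 = grads1[p], grads2[p]
    if g1 === nothing || g2 === nothing
        g1 === g2 || println("gradient present in only one version for ", summary(p))
    else
        println(summary(p), ": max abs diff = ", maximum(abs.(g1 .- g2)))
    end
end
```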
Maybe I'm doing something wrong with DiffEqFlux in the simple for-loop version? @ChrisRackauckas, a priori, do you see anything wrong here?