Skip to content

Commit 0d20bca

Browse files
test: fix initialization, account for Initial parameters in test
1 parent: 420299c · commit: 0d20bca

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

test/lotka_volterra.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using SymbolicIndexingInterface
88
using Optimization
99
using OptimizationOptimisers: Adam
1010
using SciMLStructures
11-
using SciMLStructures: Tunable
11+
using SciMLStructures: Tunable, canonicalize
12+
using PreallocationTools
1213
using ForwardDiff
1314
using StableRNGs
1415

@@ -51,7 +52,7 @@ eqs = [connect(model.nn_in, nn.output)
5152

5253
ude_sys = complete(ODESystem(
5354
eqs, ModelingToolkit.t_nounits, systems = [model, nn],
54-
name = :ude_sys, defaults = [nn.input.u => [0.0, 0.0]]))
55+
name = :ude_sys))
5556

5657
sys = structural_simplify(ude_sys)
5758

@@ -61,13 +62,18 @@ model_true = structural_simplify(lotka_true())
6162
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
6263
sol_ref = solve(prob_true, Rodas4())
6364

64-
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))
65+
x0 = default_values(sys)[nn.p]
6566

6667
get_vars = getu(sys, [sys.lotka.x, sys.lotka.y])
6768
get_refs = getu(model_true, [model_true.x, model_true.y])
68-
69-
function loss(x, (prob, sol_ref, get_vars, get_refs))
70-
new_p = SciMLStructures.replace(Tunable(), prob.p, x)
69+
set_x = setu(sys, nn.p)
70+
diffcache = DiffCache(canonicalize(Tunable(), parameter_values(prob))[1])
71+
72+
function loss(x, (prob, sol_ref, get_vars, get_refs, set_x, diffcache))
73+
tunables = get_tmp(diffcache, x)
74+
copyto!(tunables, canonicalize(Tunable(), prob.p)[1])
75+
new_p = SciMLStructures.replace(Tunable(), prob.p, tunables)
76+
set_x(new_p, x)
7177
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
7278
ts = sol_ref.t
7379
new_sol = solve(new_prob, Rodas4(), saveat = ts)
@@ -87,14 +93,14 @@ end
8793

8894
of = OptimizationFunction{true}(loss, AutoForwardDiff())
8995

90-
ps = (prob, sol_ref, get_vars, get_refs);
96+
ps = (prob, sol_ref, get_vars, get_refs, set_x, diffcache);
9197

9298
@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9399
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
94100

95101
@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
96102

97-
op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
103+
op = OptimizationProblem(of, x0, ps)
98104

99105
# using Plots
100106

@@ -114,7 +120,8 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114120

115121
@test res.objective < 1
116122

117-
res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)
123+
res_p = copy(prob.p)
124+
set_x(res_p, res.u)
118125
res_prob = remake(prob, p = res_p)
119126
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
120127

0 commit comments

Comments (0)