@@ -8,7 +8,8 @@ using SymbolicIndexingInterface
8
8
using Optimization
9
9
using OptimizationOptimisers: Adam
10
10
using SciMLStructures
11
- using SciMLStructures: Tunable
11
+ using SciMLStructures: Tunable, canonicalize
12
+ using PreallocationTools
12
13
using ForwardDiff
13
14
using StableRNGs
14
15
@@ -51,7 +52,7 @@ eqs = [connect(model.nn_in, nn.output)
51
52
52
53
ude_sys = complete (ODESystem (
53
54
eqs, ModelingToolkit. t_nounits, systems = [model, nn],
54
- name = :ude_sys , defaults = [nn . input . u => [ 0.0 , 0.0 ]] ))
55
+ name = :ude_sys ))
55
56
56
57
sys = structural_simplify (ude_sys)
57
58
@@ -61,13 +62,18 @@ model_true = structural_simplify(lotka_true())
61
62
prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [])
62
63
sol_ref = solve (prob_true, Rodas4 ())
63
64
64
- x0 = reduce (vcat, getindex .(( default_values (sys),), tunable_parameters (sys)))
65
+ x0 = default_values (sys)[nn . p]
65
66
66
67
get_vars = getu (sys, [sys. lotka. x, sys. lotka. y])
67
68
get_refs = getu (model_true, [model_true. x, model_true. y])
68
-
69
- function loss (x, (prob, sol_ref, get_vars, get_refs))
70
- new_p = SciMLStructures. replace (Tunable (), prob. p, x)
69
+ set_x = setu (sys, nn. p)
70
+ diffcache = DiffCache (canonicalize (Tunable (), parameter_values (prob))[1 ])
71
+
72
+ function loss (x, (prob, sol_ref, get_vars, get_refs, set_x, diffcache))
73
+ tunables = get_tmp (diffcache, x)
74
+ copyto! (tunables, canonicalize (Tunable (), prob. p)[1 ])
75
+ new_p = SciMLStructures. replace (Tunable (), prob. p, tunables)
76
+ set_x (new_p, x)
71
77
new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
72
78
ts = sol_ref. t
73
79
new_sol = solve (new_prob, Rodas4 (), saveat = ts)
87
93
88
94
of = OptimizationFunction {true} (loss, AutoForwardDiff ())
89
95
90
- ps = (prob, sol_ref, get_vars, get_refs);
96
+ ps = (prob, sol_ref, get_vars, get_refs, set_x, diffcache );
91
97
92
98
@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
93
99
@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
94
100
95
101
@test all (.! isnan .(ForwardDiff. gradient (Base. Fix2 (of, ps), x0)))
96
102
97
- op = OptimizationProblem (of, x0, (prob, sol_ref, get_vars, get_refs) )
103
+ op = OptimizationProblem (of, x0, ps )
98
104
99
105
# using Plots
100
106
@@ -114,7 +120,8 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114
120
115
121
@test res. objective < 1
116
122
117
- res_p = SciMLStructures. replace (Tunable (), prob. p, res. u)
123
+ res_p = copy (prob. p)
124
+ set_x (res_p, res. u)
118
125
res_prob = remake (prob, p = res_p)
119
126
res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref. t)
120
127
0 commit comments