Commit c847652

Merge pull request #56 from SciML/smc/refactor
Add callable parameter based interface
2 parents: 7466069 + cc3b075

File tree: 3 files changed, +110 −7 lines changed
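The new interface makes the neural network itself a callable symbolic parameter of the system, rather than a separately connected block. A minimal sketch of the intended usage, assembled from the added docstring and test below (the variable names and the toy equations are illustrative, not part of the commit):

```julia
using ModelingToolkit, ModelingToolkitNeuralNets
using ModelingToolkit: t_nounits as t, D_nounits as D
using StableRNGs: StableRNG

# `NN` is a callable symbolic parameter wrapping the Lux chain,
# `p` is a symbolic vector holding its tunable weights.
chain = multi_layer_feed_forward(2, 2)
NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))

@variables x(t)=1.0 y(t)=1.0 pred(t)[1:2]

# The network is applied directly inside the equations.
eqs = [pred ~ NN([x, y], p)
       D(x) ~ x + pred[1]
       D(y) ~ -y + pred[2]]
@named example_sys = ODESystem(eqs, t)
```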

Project.toml (+3 −1)
```diff
@@ -5,6 +5,7 @@ version = "1.6.1"
 
 [deps]
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
@@ -17,6 +18,7 @@ Aqua = "0.8"
 ComponentArrays = "0.15.11"
 DifferentiationInterface = "0.6"
 ForwardDiff = "0.10.36"
+IntervalSets = "0.7.10"
 JET = "0.8, 0.9"
 Lux = "1"
 LuxCore = "1"
@@ -31,7 +33,7 @@ SciMLSensitivity = "7.72"
 SciMLStructures = "1.1.0"
 StableRNGs = "1"
 SymbolicIndexingInterface = "0.3.15"
-Symbolics = "6.22"
+Symbolics = "6.36"
 Test = "1.10"
 Zygote = "0.6.73"
 julia = "1.10"
```
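The new IntervalSets dependency is only used for its `..` object (imported as `var".."` in the source change below), which is the open-argument placeholder for declaring a callable symbolic parameter. A rough illustration of that pattern, using a hypothetical plain-Julia callable in place of the package's network wrapper:

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t
using IntervalSets: var".."  # provides `..`, the open-argument placeholder

# Hypothetical callable; in the package this role is played by StatelessApplyWrapper.
double(v) = 2 .* v

# `(..)` leaves the call signature open, `[1:2]` declares the output size,
# and the default value is the callable itself.
@parameters (f::typeof(double))(..)[1:2] = double

@variables u(t)[1:2] w(t)[1:2]
eq = w ~ f(u)  # the parameter can now be applied to symbolic arguments
```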

src/ModelingToolkitNeuralNets.jl (+75 −3)
```diff
@@ -1,14 +1,15 @@
 module ModelingToolkitNeuralNets
 
 using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
+using IntervalSets: var".."
 using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
 using Symbolics: Symbolics, @register_array_symbolic, @wrapped
 using LuxCore: stateless_apply
 using Lux: Lux
 using Random: Xoshiro
 using ComponentArrays: ComponentArray
 
-export NeuralNetworkBlock, multi_layer_feed_forward
+export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network
 
 include("utils.jl")
 
@@ -32,16 +33,17 @@ function NeuralNetworkBlock(; n_input = 1, n_output = 1,
 
     @parameters p[1:length(ca)] = Vector(ca)
     @parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]
+    @parameters lux_model::typeof(chain) = chain
 
     @named input = RealInputArray(nin = n_input)
     @named output = RealOutputArray(nout = n_output)
 
-    out = stateless_apply(chain, input.u, lazyconvert(T, p))
+    out = stateless_apply(lux_model, input.u, lazyconvert(T, p))
 
     eqs = [output.u ~ out]
 
     ude_comp = ODESystem(
-        eqs, t_nounits, [], [p, T]; systems = [input, output], name)
+        eqs, t_nounits, [], [lux_model, p, T]; systems = [input, output], name)
     return ude_comp
 end
 
@@ -55,4 +57,74 @@ function lazyconvert(T, x::Symbolics.Arr)
     Symbolics.array_term(convert, T, x, size = size(x))
 end
 
+"""
+    SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
+        chain = multi_layer_feed_forward(n_input, n_output),
+        rng = Xoshiro(0),
+        init_params = Lux.initialparameters(rng, chain),
+        nn_name = :NN,
+        nn_p_name = :p,
+        eltype = Float64)
+
+Create a symbolic parameter for a neural network and one for its parameters.
+Example:
+
+```
+chain = multi_layer_feed_forward(2, 2)
+NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
+```
+
+`NN` and `p` are symbolic parameters that can be used later as part of a system.
+To change the names of the symbolic variables, use `nn_name` and `nn_p_name`.
+To get the predictions of the neural network, use
+
+```
+pred ~ NN(input, p)
+```
+
+where `pred` and `input` are symbolic vector variables of lengths `n_output` and `n_input`, respectively.
+
+To use this outside of an equation, you can get the default values for the symbols and make a similar call
+
+```
+defaults(sys)[sys.NN](input, nn_p)
+```
+
+where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of length `n_input` and
+`nn_p` is a vector representing parameter values for the neural network.
+
+To get the underlying Lux model, you can use `get_network(defaults(sys)[sys.NN])`.
+"""
+function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
+        chain = multi_layer_feed_forward(n_input, n_output),
+        rng = Xoshiro(0),
+        init_params = Lux.initialparameters(rng, chain),
+        nn_name = :NN,
+        nn_p_name = :p,
+        eltype = Float64)
+    ca = ComponentArray{eltype}(init_params)
+    wrapper = StatelessApplyWrapper(chain, typeof(ca))
+
+    p = @parameters $(nn_p_name)[1:length(ca)] = Vector(ca)
+    NN = @parameters ($(nn_name)::typeof(wrapper))(..)[1:n_output] = wrapper
+
+    return only(NN), only(p)
+end
+
+struct StatelessApplyWrapper{NN}
+    lux_model::NN
+    T::DataType
+end
+
+function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVector)
+    stateless_apply(get_network(wrapper), input, convert(wrapper.T, nn_p))
+end
+
+function Base.show(io::IO, m::MIME"text/plain", wrapper::StatelessApplyWrapper)
+    printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray)
+    show(io, m, get_network(wrapper))
+end
+
+get_network(wrapper::StatelessApplyWrapper) = wrapper.lux_model
+
 end
```
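For working with the network outside of a system, the docstring above points at `defaults(sys)[sys.NN]` and `get_network`. The sketch below instead constructs the wrapper by hand, mirroring what `SymbolicNeuralNetwork` does internally, just to show the call convention (the flat parameter vector is converted back to the `ComponentArray` layout before `stateless_apply`):

```julia
using Lux, ComponentArrays
using StableRNGs: StableRNG
using ModelingToolkitNeuralNets: StatelessApplyWrapper, get_network, multi_layer_feed_forward

# Same construction as inside SymbolicNeuralNetwork.
chain = multi_layer_feed_forward(2, 2)
ca = ComponentArray{Float64}(Lux.initialparameters(StableRNG(42), chain))
wrapper = StatelessApplyWrapper(chain, typeof(ca))

# Calling the wrapper is a stateless forward pass of the Lux chain;
# `Vector(ca)` stands in for a flat (possibly optimized) parameter vector.
pred = wrapper([1.0, 2.0], Vector(ca))

lux_model = get_network(wrapper)  # recover the underlying Lux chain
```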

test/lotka_volterra.jl (+32 −3)
```diff
@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)
 
 eqs = [connect(model.nn_in, nn.output)
        connect(model.nn_out, nn.input)]
-
+eqs = [model.nn_in.u ~ nn.output.u, model.nn_out.u ~ nn.input.u]
 ude_sys = complete(ODESystem(
     eqs, ModelingToolkit.t_nounits, systems = [model, nn],
     name = :ude_sys))
 
-sys = structural_simplify(ude_sys)
+sys = structural_simplify(ude_sys, allow_symbolic = true)
 
 prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
 
@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
 @test all(.!isnan.(∇l1))
 @test !iszero(∇l1)
 
-@test ∇l1 ≈ ∇l2 rtol=1e-2
+@test ∇l1 ≈ ∇l2 rtol=1e-3
 @test ∇l1 ≈ ∇l3 rtol=1e-5
 
 op = OptimizationProblem(of, x0, ps)
@@ -135,3 +135,32 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
 # plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])
 
 @test SciMLBase.successful_retcode(res_sol)
+
+function lotka_ude2()
+    @variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
+    @parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
+    chain = multi_layer_feed_forward(2, 2)
+    NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
+    Dt = ModelingToolkit.D_nounits
+
+    eqs = [pred ~ NN([x, y], p)
+           Dt(x) ~ α * x + pred[1]
+           Dt(y) ~ -δ * y + pred[2]]
+    return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka)
+end
+
+sys2 = structural_simplify(lotka_ude2())
+
+prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0), [])
+
+sol = solve(prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8)
+
+@test SciMLBase.successful_retcode(sol)
+
+set_x2 = setp_oop(sys2, sys2.p)
+ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
+op2 = OptimizationProblem(of, x0, ps2)
+
+res2 = solve(op2, Adam(), maxiters = 10000)
+
+@test res.u ≈ res2.u
```
