Skip to content

Add callable parameter based interface #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ version = "1.6.1"

[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
@@ -17,6 +18,7 @@ Aqua = "0.8"
ComponentArrays = "0.15.11"
DifferentiationInterface = "0.6"
ForwardDiff = "0.10.36"
IntervalSets = "0.7.10"
JET = "0.8, 0.9"
Lux = "1"
LuxCore = "1"
@@ -31,7 +33,7 @@ SciMLSensitivity = "7.72"
SciMLStructures = "1.1.0"
StableRNGs = "1"
SymbolicIndexingInterface = "0.3.15"
Symbolics = "6.22"
Symbolics = "6.36"
Test = "1.10"
Zygote = "0.6.73"
julia = "1.10"
78 changes: 75 additions & 3 deletions src/ModelingToolkitNeuralNets.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
module ModelingToolkitNeuralNets

using ModelingToolkit: @parameters, @named, ODESystem, t_nounits
using IntervalSets: var".."
using ModelingToolkitStandardLibrary.Blocks: RealInputArray, RealOutputArray
using Symbolics: Symbolics, @register_array_symbolic, @wrapped
using LuxCore: stateless_apply
using Lux: Lux
using Random: Xoshiro
using ComponentArrays: ComponentArray

export NeuralNetworkBlock, multi_layer_feed_forward
export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network

include("utils.jl")

@@ -32,16 +33,17 @@ function NeuralNetworkBlock(; n_input = 1, n_output = 1,

@parameters p[1:length(ca)] = Vector(ca)
@parameters T::typeof(typeof(ca))=typeof(ca) [tunable = false]
@parameters lux_model::typeof(chain) = chain

@named input = RealInputArray(nin = n_input)
@named output = RealOutputArray(nout = n_output)

out = stateless_apply(chain, input.u, lazyconvert(T, p))
out = stateless_apply(lux_model, input.u, lazyconvert(T, p))

eqs = [output.u ~ out]

ude_comp = ODESystem(
eqs, t_nounits, [], [p, T]; systems = [input, output], name)
eqs, t_nounits, [], [lux_model, p, T]; systems = [input, output], name)
return ude_comp
end

@@ -55,4 +57,74 @@ function lazyconvert(T, x::Symbolics.Arr)
Symbolics.array_term(convert, T, x, size = size(x))
end

"""
SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
chain = multi_layer_feed_forward(n_input, n_output),
rng = Xoshiro(0),
init_params = Lux.initialparameters(rng, chain),
nn_name = :NN,
nn_p_name = :p,
eltype = Float64)

Create a symbolic parameter representing a neural network and another one for its parameters.
Example:

```
chain = multi_layer_feed_forward(2, 2)
NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
```

The NN and p are symbolic parameters that can be used later as part of a system.
To change the name of the symbolic variables, use `nn_name` and `nn_p_name`.
To get the predictions of the neural network, use

```
pred ~ NN(input, p)
```

where `pred` and `input` are symbolic vector variables with lengths `n_output` and `n_input`, respectively.

To use this outside of an equation, you can get the default values for the symbols and make a similar call

```
defaults(sys)[sys.NN](input, nn_p)
```

where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of `n_input` length and
`nn_p` is a vector representing parameter values for the neural network.

To get the underlying Lux model you can use `get_network(defaults(sys)[sys.NN])`, or call `get_network` on any `StatelessApplyWrapper` instance.
"""
function SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
        chain = multi_layer_feed_forward(n_input, n_output),
        rng = Xoshiro(0),
        init_params = Lux.initialparameters(rng, chain),
        nn_name = :NN,
        nn_p_name = :p,
        eltype = Float64)
    # Flatten the initial parameters into a typed ComponentArray so the
    # symbolic parameter vector gets a concrete, fully-specified default.
    default_ps = ComponentArray{eltype}(init_params)
    # Wrap the Lux chain in a callable that remembers the ComponentArray type,
    # so a flat parameter vector can be converted back before `stateless_apply`.
    callable = StatelessApplyWrapper(chain, typeof(default_ps))

    # Symbolic parameter vector holding the (tunable) network weights.
    ps = @parameters $(nn_p_name)[1:length(default_ps)] = Vector(default_ps)
    # Callable symbolic parameter standing in for the network itself; its
    # output is an `n_output`-length symbolic array.
    nn = @parameters ($(nn_name)::typeof(callable))(..)[1:n_output] = callable

    return only(nn), only(ps)
end

"""
    StatelessApplyWrapper(lux_model, T)

Callable wrapper pairing a Lux model with the `ComponentArray` type `T` of its
parameters, so a flat parameter vector can be converted back to the structured
form before calling `LuxCore.stateless_apply`.
"""
struct StatelessApplyWrapper{NN, CT}
    # The wrapped Lux model (retrieved via `get_network`).
    lux_model::NN
    # Parameter type used by the functor to `convert` flat vectors. Stored as
    # `Type{CT}` (a type parameter) instead of the abstract `DataType` so the
    # field is concretely typed and the conversion in the functor is type-stable.
    T::Type{CT}
end

# Make the wrapper callable: restore the flat weight vector to its structured
# ComponentArray form, then run the wrapped model statelessly on `input`.
function (wrapper::StatelessApplyWrapper)(input::AbstractArray, nn_p::AbstractVector)
    structured_ps = convert(wrapper.T, nn_p)
    return stateless_apply(wrapper.lux_model, input, structured_ps)
end

# Multi-line REPL display: a dimmed header followed by the wrapped model's
# own `text/plain` representation.
function Base.show(io::IO, mime::MIME"text/plain", wrapper::StatelessApplyWrapper)
    printstyled(io, "LuxCore.stateless_apply wrapper for:\n", color = :gray)
    return show(io, mime, wrapper.lux_model)
end

get_network(wrapper::StatelessApplyWrapper) = wrapper.lux_model

end
35 changes: 32 additions & 3 deletions test/lotka_volterra.jl
Original file line number Diff line number Diff line change
@@ -51,12 +51,12 @@ chain = multi_layer_feed_forward(2, 2)

eqs = [connect(model.nn_in, nn.output)
connect(model.nn_out, nn.input)]

eqs = [model.nn_in.u ~ nn.output.u, model.nn_out.u ~ nn.input.u]
ude_sys = complete(ODESystem(
eqs, ModelingToolkit.t_nounits, systems = [model, nn],
name = :ude_sys))

sys = structural_simplify(ude_sys)
sys = structural_simplify(ude_sys, allow_symbolic = true)

prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])

@@ -103,7 +103,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
@test all(.!isnan.(∇l1))
@test !iszero(∇l1)

@test ∇l1≈∇l2 rtol=1e-2
@test ∇l1≈∇l2 rtol=1e-3
@test ∇l1≈∇l3 rtol=1e-5

op = OptimizationProblem(of, x0, ps)
@@ -135,3 +135,32 @@ res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
# plot!(res_sol, idxs = [sys.lotka.x, sys.lotka.y])

@test SciMLBase.successful_retcode(res_sol)

# Lotka-Volterra UDE built via the callable-parameter interface: the neural
# network enters the equations directly as `NN([x, y], p)` rather than as a
# connected sub-block.
function lotka_ude2()
    # NOTE(review): `x`, `y`, `pred` are declared as functions of a freshly
    # created local `t`, while the system below is constructed with
    # `ModelingToolkit.t_nounits` as the independent variable — confirm MTK
    # unifies these; otherwise declare the variables in terms of `t_nounits`.
    @variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
    # Known physical parameters; marked non-tunable so only the NN weights
    # participate in optimization.
    @parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
    chain = multi_layer_feed_forward(2, 2)
    # Symbolic callable `NN` plus its weight vector `p`; RNG seeded for
    # reproducible initial weights.
    NN, p = SymbolicNeuralNetwork(; chain, n_input = 2, n_output = 2, rng = StableRNG(42))
    Dt = ModelingToolkit.D_nounits

    # The NN output `pred` stands in for the unknown interaction terms.
    eqs = [pred ~ NN([x, y], p)
           Dt(x) ~ α * x + pred[1]
           Dt(y) ~ -δ * y + pred[2]]
    return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka)
end

# Simplify the callable-interface system and check it solves and trains to the
# same optimum as the block-based formulation exercised earlier in this file.
sys2 = structural_simplify(lotka_ude2())

prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys2, [], (0, 1.0), [])

sol = solve(prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8)

@test SciMLBase.successful_retcode(sol)

# Out-of-place setter for the NN weight vector; reuses the loss pieces
# (`sol_ref`, `get_vars`, `get_refs`, `of`, `x0`) defined earlier in the file.
set_x2 = setp_oop(sys2, sys2.p)
ps2 = (prob, sol_ref, get_vars, get_refs, set_x2);
op2 = OptimizationProblem(of, x0, ps2)

res2 = solve(op2, Adam(), maxiters = 10000)

# Both interfaces should converge to the same weights (`res` comes from the
# earlier block-based optimization run).
@test res.u ≈ res2.u