1
1
module ModelingToolkitNeuralNets
2
2
3
3
using ModelingToolkit: @parameters , @named , ODESystem, t_nounits
4
+ using IntervalSets: var".."
4
5
using ModelingToolkitStandardLibrary. Blocks: RealInputArray, RealOutputArray
5
6
using Symbolics: Symbolics, @register_array_symbolic , @wrapped
6
7
using LuxCore: stateless_apply
7
8
using Lux: Lux
8
9
using Random: Xoshiro
9
10
using ComponentArrays: ComponentArray
10
11
11
- export NeuralNetworkBlock, multi_layer_feed_forward
12
+ export NeuralNetworkBlock, SymbolicNeuralNetwork, multi_layer_feed_forward, get_network
12
13
13
14
include (" utils.jl" )
14
15
@@ -32,16 +33,17 @@ function NeuralNetworkBlock(; n_input = 1, n_output = 1,
32
33
33
34
@parameters p[1 : length (ca)] = Vector (ca)
34
35
@parameters T:: typeof (typeof (ca))= typeof (ca) [tunable = false ]
36
+ @parameters lux_model:: typeof (chain) = chain
35
37
36
38
@named input = RealInputArray (nin = n_input)
37
39
@named output = RealOutputArray (nout = n_output)
38
40
39
- out = stateless_apply (chain , input. u, lazyconvert (T, p))
41
+ out = stateless_apply (lux_model , input. u, lazyconvert (T, p))
40
42
41
43
eqs = [output. u ~ out]
42
44
43
45
ude_comp = ODESystem (
44
- eqs, t_nounits, [], [p, T]; systems = [input, output], name)
46
+ eqs, t_nounits, [], [lux_model, p, T]; systems = [input, output], name)
45
47
return ude_comp
46
48
end
47
49
@@ -55,4 +57,74 @@ function lazyconvert(T, x::Symbolics.Arr)
55
57
Symbolics. array_term (convert, T, x, size = size (x))
56
58
end
57
59
60
+ """
61
+ SymbolicNeuralNetwork(; n_input = 1, n_output = 1,
62
+ chain = multi_layer_feed_forward(n_input, n_output),
63
+ rng = Xoshiro(0),
64
+ init_params = Lux.initialparameters(rng, chain),
65
+ nn_name = :NN,
66
+ nn_p_name = :p,
67
+ eltype = Float64)
68
+
69
+ Create symbolic parameter for a neural network and one for its parameters.
70
+ Example:
71
+
72
+ ```
73
+ chain = multi_layer_feed_forward(2, 2)
74
+ NN, p = SymbolicNeuralNetwork(; chain, n_input=2, n_output=2, rng = StableRNG(42))
75
+ ```
76
+
77
+ The NN and p are symbolic parameters that can be used later as part of a system.
78
+ To change the name of the symbolic variables, use `nn_name` and `nn_p_name`.
79
+ To get the predictions of the neural network, use
80
+
81
+ ```
82
+ pred ~ NN(input, p)
83
+ ```
84
+
85
+ where `pred` and `input` are a symbolic vector variable with the lengths `n_output` and `n_input`.
86
+
87
+ To use this outside of an equation, you can get the default values for the symbols and make a similar call
88
+
89
+ ```
90
+ defaults(sys)[sys.NN](input, nn_p)
91
+ ```
92
+
93
+ where `sys` is a system (e.g. `ODESystem`) that contains `NN`, `input` is a vector of `n_input` length and
94
+ `nn_p` is a vector representing parameter values for the neural network.
95
+
96
+ To get the underlying Lux model you can use `get_network(defaults(sys)[sys.NN])` or
97
+ """
98
+ function SymbolicNeuralNetwork (; n_input = 1 , n_output = 1 ,
99
+ chain = multi_layer_feed_forward (n_input, n_output),
100
+ rng = Xoshiro (0 ),
101
+ init_params = Lux. initialparameters (rng, chain),
102
+ nn_name = :NN ,
103
+ nn_p_name = :p ,
104
+ eltype = Float64)
105
+ ca = ComponentArray {eltype} (init_params)
106
+ wrapper = StatelessApplyWrapper (chain, typeof (ca))
107
+
108
+ p = @parameters $ (nn_p_name)[1 : length (ca)] = Vector (ca)
109
+ NN = @parameters ($ (nn_name):: typeof (wrapper))(.. )[1 : n_output] = wrapper
110
+
111
+ return only (NN), only (p)
112
+ end
113
+
114
+ struct StatelessApplyWrapper{NN}
115
+ lux_model:: NN
116
+ T:: DataType
117
+ end
118
+
119
+ function (wrapper:: StatelessApplyWrapper )(input:: AbstractArray , nn_p:: AbstractVector )
120
+ stateless_apply (get_network (wrapper), input, convert (wrapper. T, nn_p))
121
+ end
122
+
123
+ function Base. show (io:: IO , m:: MIME"text/plain" , wrapper:: StatelessApplyWrapper )
124
+ printstyled (io, " LuxCore.stateless_apply wrapper for:\n " , color = :gray )
125
+ show (io, m, get_network (wrapper))
126
+ end
127
+
128
+ get_network (wrapper:: StatelessApplyWrapper ) = wrapper. lux_model
129
+
58
130
end
0 commit comments