
Commit eb4ec67

manyfeatures committed: adds lagrangian nn and simple example
1 parent 9e1182e commit eb4ec67

File tree

4 files changed: +75 -1 lines

Project.toml

+2 lines

@@ -26,6 +26,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
+GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"

 [compat]
 Adapt = "3"
@@ -48,6 +49,7 @@ TerminalLoggers = "0.1"
 Zygote = "0.5, 0.6"
 ZygoteRules = "0.2"
 julia = "1.5"
+GenericLinearAlgebra = "0.2.5"

 [extras]
 DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"

docs/src/examples/lagrangian_nn.jl

+28 lines

@@ -0,0 +1,28 @@
+# One point test
+using Flux, ReverseDiff, LagrangianNN
+
+m, k, b = 1, 1, 1
+
+X = rand(2,1)
+Y = -k.*X[1]/m
+
+g = Chain(Dense(2, 10, σ), Dense(10, 1))
+model = LagrangianNN(g)
+params = model.params
+re = model.re
+
+# some toy loss function
+function loss(x, y, p)
+    nn = x -> model(x, p)
+    out = sum((y .- nn(x)).^2)
+    out
+end
+opt = ADAM(0.01)
+epochs = 100
+
+for epoch in 1:epochs
+    x, y = X, Y
+    gs = ReverseDiff.gradient(p -> loss(x, y, p), params)
+    Flux.Optimise.update!(opt, params, gs)
+    @show loss(x, y, params)
+end
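
A minimal usage sketch, assuming the training loop above has converged: the trained network can be queried directly, and for this single-point spring-mass toy problem (target acceleration Y = -k*q/m; the damping coefficient b is defined but unused) the prediction should approach Y.

# Hedged sketch: evaluate the trained LagrangianNN on the training state.
# model(X, params) returns the predicted acceleration for X = [q; q_dot].
pred = model(X, params)
@show pred Y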

src/DiffEqFlux.jl

+3 -1 lines

@@ -2,7 +2,7 @@ module DiffEqFlux

 using GalacticOptim, DataInterpolations, DiffEqBase, DiffResults, DiffEqSensitivity,
       Distributions, ForwardDiff, Flux, Requires, Adapt, LinearAlgebra, RecursiveArrayTools,
-      StaticArrays, Base.Iterators, Printf, Zygote
+      StaticArrays, Base.Iterators, Printf, Zygote, GenericLinearAlgebra

 using DistributionsAD
 import ProgressLogging, ZygoteRules
@@ -82,11 +82,13 @@ include("tensor_product_basis.jl")
 include("tensor_product_layer.jl")
 include("collocation.jl")
 include("hnn.jl")
+include("lnn.jl")
 include("multiple_shooting.jl")

 export diffeq_fd, diffeq_rd, diffeq_adjoint
 export DeterministicCNF, FFJORD, NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, NeuralODEMM, TensorLayer, AugmentedNDELayer, SplineLayer, NeuralHamiltonianDE
 export HamiltonianNN
+export LagrangianNN
 export ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis
 export neural_ode, neural_ode_rd
 export neural_dmsde

src/lnn.jl

+42 lines

@@ -0,0 +1,42 @@
+"""
+Constructs a Lagrangian Neural Network [1].
+
+References:
+[1] Miles Cranmer, Sam Greydanus, Stephan Hoyer, Peter Battaglia, David Spergel, and Shirley Ho. Lagrangian Neural Networks.
+In ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations, 2020.
+"""
+
+struct LagrangianNN
+    model
+    re
+    params
+
+    # Inner constructor: destructure the Flux model into a flat parameter vector
+    function LagrangianNN(model; p = nothing)
+        _p, re = Flux.destructure(model)
+        if p === nothing
+            p = _p
+        end
+        return new(model, re, p)
+    end
+end
+
+function (nn::LagrangianNN)(x, p = nn.params)
+    @assert size(x,1) % 2 === 0 # velocity degrees of freedom must equal coordinate degrees of freedom
+    M = div(size(x,1), 2) # number of velocity degrees of freedom
+    re = nn.re
+    hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x)), x) # full Hessian of the Lagrangian
+    hess = hess(x)[M+1:end, M+1:end] # keep only the velocity-velocity block
+    inv_hess = GenericLinearAlgebra.pinv(hess)
+
+    _grad_q = x -> Zygote.gradient(x->sum(re(p)(x)), x)[end]
+    _grad_q = _grad_q(x)[1:M,:] # keep only the coordinate derivatives
+    out1 = _grad_q
+
+    # Second term
+    _grad_qv = x -> Zygote.gradient(x->sum(re(p)(x)), x)[end]
+    _jac_qv = x -> Zygote.jacobian(x->_grad_qv(x), x)[end]
+    out2 = _jac_qv(x)[1:M,M+1:end] * x[M+1:end] # mixed coordinate/velocity block times the velocities
+
+    return inv_hess * (out1 .- out2) # Euler-Lagrange: q̈ = H⁻¹(∂L/∂q - (∂²L/∂q∂q̇) q̇)
+end
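
For reference, the forward pass above solves the Euler-Lagrange equation for the accelerations, as in Cranmer et al. [1]. Writing x = (q, q̇) and L for the scalar output of the network, the identity being evaluated is, in LaTeX:

\[
\ddot{q} \;=\; \left(\nabla_{\dot{q}}\nabla_{\dot{q}}^{\top} L\right)^{-1}
\left[\nabla_{q} L \;-\; \left(\nabla_{q}\nabla_{\dot{q}}^{\top} L\right)\dot{q}\right]
\]

The inverse is taken as a pseudo-inverse (GenericLinearAlgebra.pinv) of the velocity-velocity block of the Hessian, which is what inv_hess holds above.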
