make realnvp and nsf layers as part of the pkg #53

Open
wants to merge 14 commits into base: main
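
For context, here is a minimal sketch of the user-facing API this PR introduces, mirroring the updated example/demo_RealNVP.jl (the hidden sizes and layer count are the demo's values; sampling and log-density evaluation go through the usual Bijectors.jl transformed-distribution interface):

using Distributions, LinearAlgebra
using NormalizingFlows

T = Float32
q0 = MvNormal(zeros(T, 2), I)              # base distribution

# RealNVP flow from the new package-level constructor:
# hidden sizes [16, 16] and 3 coupling layers, as in the updated demo
flow = realnvp(q0, [16, 16], 3; paramtype=T)

xs = rand(flow, 100)                       # draw 100 samples from the flow
lp = logpdf(flow, xs[:, 1])                # evaluate the flow's log-density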
4 changes: 4 additions & 0 deletions Project.toml
@@ -8,6 +8,8 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -27,6 +29,8 @@ CUDA = "5"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25"
DocStringExtensions = "0.9"
Flux = "0.16"
Functors = "0.5.2"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.0.0"
StatsBase = "0.33, 0.34"
1 change: 1 addition & 0 deletions example/Project.toml
@@ -17,6 +17,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
123 changes: 8 additions & 115 deletions example/demo_RealNVP.jl
@@ -11,114 +11,6 @@ using NormalizingFlows
include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define affine coupling layer using Bijectors.jl interface
#################################
struct AffineCoupling <: Bijectors.Bijector
dim::Int
mask::Bijectors.PartitionMask
s::Flux.Chain
t::Flux.Chain
end

# let params track field s and t
@functor AffineCoupling (s, t)

function AffineCoupling(
dim::Int, # dimension of input
hdims::Int, # dimension of hidden units for s and t
mask_idx::AbstractVector, # indices of the dimensions the transformation is applied to
)
cdims = length(mask_idx) # dimension of parts used to construct coupling law
s = mlp3(cdims, hdims, cdims)
t = mlp3(cdims, hdims, cdims)
mask = PartitionMask(dim, mask_idx)
return AffineCoupling(dim, mask, s, t)
end

function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
# partition vector using `af.mask::PartitionMask`
x₁, x₂, x₃ = partition(af.mask, x)
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
return combine(af.mask, y₁, x₂, x₃)
end

function (af::AffineCoupling)(x::AbstractArray)
return transform(af, x)
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
logjac = sum(log ∘ abs, af.s(x_2)) # this is a scalar
return combine(af.mask, y_1, x_2, x_3), logjac
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
logjac = sum(log ∘ abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
return combine(af.mask, y_1, x_2, x_3), vec(logjac)
end


function Bijectors.with_logabsdet_jacobian(
iaf::Inverse{<:AffineCoupling}, y::AbstractVector
)
af = iaf.orig
# partition vector using `af.mask::PartitionMask`
y_1, y_2, y_3 = partition(af.mask, y)
# inverse transformation
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
logjac = -sum(log ∘ abs, af.s(y_2))
return combine(af.mask, x_1, y_2, y_3), logjac
end

function Bijectors.with_logabsdet_jacobian(
iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
)
af = iaf.orig
# partition vector using `af.mask::PartitionMask`
y_1, y_2, y_3 = partition(af.mask, y)
# inverse transformation
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
logjac = -sum(log ∘ abs, af.s(y_2); dims = 1)
return combine(af.mask, x_1, y_2, y_3), vec(logjac)
end
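
# The log-Jacobian terms above follow directly from the coupling structure: the forward
# map y₁ = s(x₂) .* x₁ .+ t(x₂) leaves x₂ and x₃ unchanged, so the Jacobian is
# block-triangular with diagonal blocks Diagonal(s(x₂)) and the identity, giving
# log|det J| = sum(log ∘ abs, s(x₂)). The inverse contributes the negated sum, and the
# matrix methods accumulate the same quantity column-wise (dims = 1).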

###################
# an equivalent definition of AffineCoupling using Bijectors.Coupling
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
###################

# struct AffineCoupling <: Bijectors.Bijector
# dim::Int
# mask::Bijectors.PartitionMask
# s::Flux.Chain
# t::Flux.Chain
# end

# # let params track field s and t
# @functor AffineCoupling (s, t)

# function AffineCoupling(dim, mask, s, t)
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
# end

# function AffineCoupling(
# dim::Int, # dimension of input
# hdims::Int, # dimension of hidden units for s and t
# mask_idx::AbstractVector, # indices of the dimensions the transformation is applied to
# )
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
# s = mlp3(cdims, hdims, cdims)
# t = mlp3(cdims, hdims, cdims)
# mask = PartitionMask(dim, mask_idx)
# return AffineCoupling(dim, mask, s, t)
# end



##################################
# start demo
#################################
@@ -132,29 +24,30 @@ T = Float32
target = Banana(2, 1.0f0, 100.0f0)
logp = Base.Fix1(logpdf, target)


######################################
# learn the target using Affine coupling flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), ones(T, 2))
q0 = MvNormal(zeros(T, 2), I)

d = 2
hdims = 32

# alternating the coupling layers
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]
hdims = [16, 16]
nlayers = 3

flow = create_flow(Ls, q0)
# use NormalizingFlows.realnvp to create a RealNVP flow
flow = realnvp(q0, hdims, nlayers; paramtype=T)
flow_untrained = deepcopy(flow)


######################################
# start training
######################################
sample_per_iter = 64
sample_per_iter = 16

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
# TODO: the example currently breaks with AutoMooncake but works with AutoZygote; needs debugging
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
99 changes: 1 addition & 98 deletions example/demo_neural_spline_flow.jl
@@ -11,104 +11,6 @@ using NormalizingFlows
include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define neural spline layer using Bijectors.jl interface
#################################
"""
Neural Rational quadratic Spline layer

# References
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
"""
struct NeuralSplineLayer{T,A<:Flux.Chain} <: Bijectors.Bijector
dim::Int # dimension of input
K::Int # number of knots
n_dims_transferred::Int # number of dimensions that are transformed
nn::A # network that parameterizes the knots and derivatives
B::T # bound of the knots
mask::Bijectors.PartitionMask
end

function NeuralSplineLayer(
dim::T1, # dimension of input
hdims::T1, # dimension of hidden units for the parameterizing network
K::T1, # number of knots
B::T2, # bound of the knots
mask_idx::AbstractVector{<:Int}, # indices of the dimensions the transformation is applied to
) where {T1<:Int,T2<:Real}
num_of_transformed_dims = length(mask_idx)
input_dims = dim - num_of_transformed_dims

# output dim of the NN
output_dims = (3K - 1)*num_of_transformed_dims
# one big mlp that outputs all the knots and derivatives for all the transformed dimensions
nn = mlp3(input_dims, hdims, output_dims)

mask = Bijectors.PartitionMask(dim, mask_idx)
return NeuralSplineLayer(dim, K, num_of_transformed_dims, nn, B, mask)
end

@functor NeuralSplineLayer (nn,)

# define forward and inverse transformation
"""
Build a rational quadratic spline from the nn output
Bijectors.jl has implemented the inverse and logabsdetjac for rational quadratic spline

we just need to map the nn output to the knots and derivatives of the RQS
"""
function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
K, B = nsl.K, nsl.B
nnoutput = reshape(nsl.nn(x), nsl.n_dims_transferred, :)
ws = @view nnoutput[:, 1:K]
hs = @view nnoutput[:, (K + 1):(2K)]
ds = @view nnoutput[:, (2K + 1):(3K - 1)]
return Bijectors.RationalQuadraticSpline(ws, hs, ds, B)
end
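
# Layout of the nn output: (3K - 1) values per transformed dimension. The first K are the
# (unconstrained) bin widths ws, the next K the bin heights hs, and the last K - 1 the
# derivatives ds at the interior knots; Bijectors.RationalQuadraticSpline maps these to a
# monotone spline on [-B, B] that acts as the identity outside the boundary.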

function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector)
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
# instantiate rqs knots and derivatives
rqs = instantiate_rqs(nsl, x_2)
y_1 = Bijectors.transform(rqs, x_1)
return Bijectors.combine(nsl.mask, y_1, x_2, x_3)
end

function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
nsl = insl.orig
y1, y2, y3 = partition(nsl.mask, y)
rqs = instantiate_rqs(nsl, y2)
x1 = Bijectors.transform(Inverse(rqs), y1)
return Bijectors.combine(nsl.mask, x1, y2, y3)
end

function (nsl::NeuralSplineLayer)(x::AbstractVector)
return Bijectors.transform(nsl, x)
end

# define logabsdetjac
function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector)
x_1, x_2, _ = Bijectors.partition(nsl.mask, x)
rqs = instantiate_rqs(nsl, x_2)
logjac = logabsdetjac(rqs, x_1)
return logjac
end

function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
nsl = insl.orig
y1, y2, _ = partition(nsl.mask, y)
rqs = instantiate_rqs(nsl, y2)
logjac = logabsdetjac(Inverse(rqs), y1)
return logjac
end

function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
x_1, x_2, x_3 = Bijectors.partition(nsl.mask, x)
rqs = instantiate_rqs(nsl, x_2)
y_1, logjac = with_logabsdet_jacobian(rqs, x_1)
return Bijectors.combine(nsl.mask, y_1, x_2, x_3), logjac
end
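
# For reference, composing these layers into a flow mirrors the RealNVP demo. A sketch,
# using the constructor signature of the local definition above (the packaged
# NeuralSplineLayer exported by this PR may differ, and the sizes here are hypothetical):
#
# d, hdims, K, B = 2, 64, 10, 30
# Ls = [
#     NeuralSplineLayer(d, hdims, K, B, [1]) ∘ NeuralSplineLayer(d, hdims, K, B, [2])
#     for _ in 1:3
# ]
# flow = create_flow(Ls, q0)    # q0: base distribution, e.g. MvNormal(zeros(T, d), I)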

##################################
# start demo
#################################
@@ -148,6 +50,7 @@ sample_per_iter = 64

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
# TODO: the example currently breaks with AutoMooncake but works with AutoZygote; needs debugging
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
5 changes: 0 additions & 5 deletions example/utils.jl
@@ -13,11 +13,6 @@ function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux
)
end

function create_flow(Ls, q₀)
ts = reduce(∘, Ls)
return transformed(q₀, ts)
end

function compare_trained_and_untrained_flow(
flow_trained::Bijectors.MultivariateTransformed,
flow_untrained::Bijectors.MultivariateTransformed,
15 changes: 14 additions & 1 deletion src/NormalizingFlows.jl
@@ -1,13 +1,15 @@
module NormalizingFlows

using ADTypes
using Bijectors
using Distributions
using LinearAlgebra
using Optimisers
using ProgressMeter
using Random
using StatsBase
using Bijectors
using Bijectors: PartitionMask, Inverse, combine, partition
using Functors
import DifferentiationInterface as DI

using DocStringExtensions
@@ -123,4 +125,15 @@ function _device_specific_rand(
return Random.rand(rng, td, n)
end


# interface for constructing common flow layers
include("flows/utils.jl")
include("flows/realnvp.jl")
include("flows/neuralspline.jl")

export create_flow
export RealNVP_layer, realnvp, AffineCoupling
export NeuralSplineLayer


end