-
Notifications
You must be signed in to change notification settings - Fork 226
Update to the AdvancedVI@0.4 interface #2506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ed6946c
a94269d
a4711a9
3f8068b
222a638
57097f5
a42eea8
798f319
69a4972
cbcb8b5
081d6ff
a32a673
1bcec3e
b142832
061ec35
736bd3e
fd434d8
57108ee
8dc8067
297c32a
3010b5e
0c04434
17a8290
0e496c4
c1533a8
231d6e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,50 +1,170 @@ | ||||||
|
||||||
module Variational | ||||||
|
||||||
using DistributionsAD: DistributionsAD | ||||||
using DynamicPPL: DynamicPPL | ||||||
using StatsBase: StatsBase | ||||||
using StatsFuns: StatsFuns | ||||||
using LogDensityProblems: LogDensityProblems | ||||||
using DynamicPPL | ||||||
using ADTypes | ||||||
using Distributions | ||||||
using LinearAlgebra | ||||||
using LogDensityProblems | ||||||
using Random | ||||||
|
||||||
using Random: Random | ||||||
import ..Turing: DEFAULT_ADTYPE, PROGRESS | ||||||
|
||||||
import AdvancedVI | ||||||
import Bijectors | ||||||
|
||||||
# Reexports | ||||||
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad | ||||||
export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad | ||||||
|
||||||
""" | ||||||
make_logjoint(model::Model; weight = 1.0) | ||||||
Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z). | ||||||
The `weight` is used to scale the likelihood; e.g. when doing stochastic gradient descent, one needs to | ||||||
use the `DynamicPPL.MiniBatch` context to run the `Model` with a weight of `num_total_obs / batch_size`. | ||||||
## Notes | ||||||
- For the sake of efficiency, the returned function closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. | ||||||
""" | ||||||
function make_logjoint(model::DynamicPPL.Model; weight=1.0) | ||||||
# setup | ||||||
using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG | ||||||
export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG | ||||||
|
||||||
export meanfield_gaussian, fullrank_gaussian | ||||||
|
||||||
include("bijectors.jl") | ||||||
|
||||||
"""
    make_logdensity(model::DynamicPPL.Model)

Wrap `model` in a `DynamicPPL.LogDensityFunction` built from a fresh
`DynamicPPL.VarInfo(model)` and a `DynamicPPL.MiniBatchContext` whose weight is `1.0`.

The `LogDensityFunction` itself is returned rather than a closure around
`LogDensityProblems.logdensity`: the callers in this module either wrap the result in
`Base.Fix1(LogDensityProblems.logdensity, prob)` themselves or pass it straight to
`AdvancedVI.optimize`.
"""
function make_logdensity(model::DynamicPPL.Model)
    weight = 1.0
    ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight)
    return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
end
|
||||||
# objectives | ||||||
function (elbo::ELBO)( | ||||||
"""
    initialize_gaussian_scale(
        rng, model, location, scale;
        num_samples=10, num_max_trials=10, reduce_factor=one(eltype(scale)) / 2,
    )

Shrink `scale` until a location-scale distribution with a standard normal base yields a
finite average log density for `model`, and return the scale that achieved it.

On each trial, `num_samples` draws are taken from the transformed distribution and the
log density from `make_logdensity(model)` is averaged over them. If the average is not
finite, `scale` is multiplied by `reduce_factor` and the trial is repeated.

# Arguments
- `rng::Random.AbstractRNG`: random number generator used for the draws.
- `model::DynamicPPL.Model`: the target model.
- `location::AbstractVector`: location of the candidate distribution (never modified).
- `scale::AbstractMatrix`: initial scale; the input object itself is not mutated.

# Keywords
- `num_samples::Int=10`: number of draws per trial.
- `num_max_trials::Int=10`: number of scale reductions allowed before giving up.
- `reduce_factor`: multiplicative factor applied to `scale` after a failed trial.

# Throws
- `ErrorException` if no finite average is found after `num_max_trials` reductions.
"""
function initialize_gaussian_scale(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    location::AbstractVector,
    scale::AbstractMatrix;
    num_samples::Int=10,
    num_max_trials::Int=10,
    reduce_factor=one(eltype(scale)) / 2,
)
    prob = make_logdensity(model)
    ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
    varinfo = DynamicPPL.VarInfo(model)
    # `model` and `varinfo` never change inside the loop, so the inverse bijector is
    # built once instead of on every trial.
    binv = Bijectors.inverse(Bijectors.bijector(model; varinfo=varinfo))

    n_trial = 0
    while true
        q = AdvancedVI.MvLocationScale(location, scale, Normal())
        q_trans = Bijectors.transformed(q, binv)
        energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples)))

        if isfinite(energy)
            return scale
        elseif n_trial >= num_max_trials
            # `>=` (not `==`) so that a negative `num_max_trials` cannot loop forever.
            error(
                "Could not find an initial scale with a finite average log density " *
                "after $(num_max_trials) scale reductions.",
            )
        end

        scale = reduce_factor * scale
        n_trial += 1
    end
end
|
||||||
"""
    meanfield_gaussian(rng, model; location=nothing, scale=nothing, kwargs...)

Build an `AdvancedVI.MeanFieldGaussian` for `model` and return it wrapped in the inverse
of the model's bijector (a `Bijectors.transformed` distribution).

# Arguments
- `rng::Random.AbstractRNG`: used only when the scale has to be initialised.
- `model::DynamicPPL.Model`: the target model.

# Keywords
- `location`: location vector; defaults to `zeros(num_params)` when `nothing`.
- `scale`: `Diagonal` scale matrix; when `nothing`, it is obtained from
  `initialize_gaussian_scale` starting at the identity.
- `kwargs...`: forwarded to `initialize_gaussian_scale` (only used when `scale` is
  `nothing`).

# Throws
- `AssertionError` if a provided `location` or `scale` does not match the number of
  parameters of the linked `VarInfo`.
"""
function meanfield_gaussian(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector}=nothing,
    scale::Union{Nothing,<:Diagonal}=nothing,
    kwargs...,
)
    varinfo = DynamicPPL.VarInfo(model)
    # Use linked `varinfo` to determine the correct number of parameters.
    # TODO: Replace with `length` once this is implemented for `VarInfo`.
    varinfo_linked = DynamicPPL.link(varinfo, model)
    num_params = length(varinfo_linked[:])

    μ = if isnothing(location)
        zeros(num_params)
    else
        @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
        location
    end

    L = if isnothing(scale)
        initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...)
    else
        @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
        scale
    end

    q = AdvancedVI.MeanFieldGaussian(μ, L)
    # The bijector is built from the unlinked `varinfo`, not `varinfo_linked`.
    b = Bijectors.bijector(model; varinfo=varinfo)
    return Bijectors.transformed(q, Bijectors.inverse(b))
end
|
||||||
"""
    meanfield_gaussian(model; location=nothing, scale=nothing, kwargs...)

Convenience method: forward to the `rng`-taking method of `meanfield_gaussian` using
`Random.default_rng()`, passing `location`, `scale` and `kwargs...` through unchanged.
"""
function meanfield_gaussian(
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector}=nothing,
    scale::Union{Nothing,<:Diagonal}=nothing,
    kwargs...,
)
    return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...)
end
|
||||||
"""
    fullrank_gaussian(rng, model; location=nothing, scale=nothing, kwargs...)

Build an `AdvancedVI.FullRankGaussian` for `model` and return it wrapped in the inverse
of the model's bijector (a `Bijectors.transformed` distribution).

# Arguments
- `rng::Random.AbstractRNG`: used only when the scale has to be initialised.
- `model::DynamicPPL.Model`: the target model.

# Keywords
- `location`: location vector; defaults to `zeros(num_params)` when `nothing`.
- `scale`: `LowerTriangular` scale matrix; when `nothing`, it is obtained from
  `initialize_gaussian_scale` starting at a `Float64` lower-triangular identity.
- `kwargs...`: forwarded to `initialize_gaussian_scale` (only used when `scale` is
  `nothing`).

# Throws
- `AssertionError` if a provided `location` or `scale` does not match the number of
  parameters of the linked `VarInfo`.
"""
function fullrank_gaussian(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector}=nothing,
    scale::Union{Nothing,<:LowerTriangular}=nothing,
    kwargs...,
)
    varinfo = DynamicPPL.VarInfo(model)
    # Use linked `varinfo` to determine the correct number of parameters.
    # TODO: Replace with `length` once this is implemented for `VarInfo`.
    # NOTE(review): a reviewer asked whether the dimensionality can be obtained another
    # way; the remark is truncated in this capture, so the suggestion itself is unknown.
    varinfo_linked = DynamicPPL.link(varinfo, model)
    num_params = length(varinfo_linked[:])

    μ = if isnothing(location)
        zeros(num_params)
    else
        @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
        location
    end

    L = if isnothing(scale)
        L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params))
        initialize_gaussian_scale(rng, model, μ, L0; kwargs...)
    else
        @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
        scale
    end

    q = AdvancedVI.FullRankGaussian(μ, L)
    # The bijector is built from the unlinked `varinfo`, not `varinfo_linked`.
    b = Bijectors.bijector(model; varinfo=varinfo)
    return Bijectors.transformed(q, Bijectors.inverse(b))
end
|
||||||
# VI algorithms | ||||||
include("advi.jl") | ||||||
"""
    fullrank_gaussian(model; location=nothing, scale=nothing, kwargs...)

Convenience method: forward to the `rng`-taking method of `fullrank_gaussian` using
`Random.default_rng()`, passing `location`, `scale` and `kwargs...` through unchanged.
"""
function fullrank_gaussian(
    model::DynamicPPL.Model;
    location::Union{Nothing,<:AbstractVector}=nothing,
    scale::Union{Nothing,<:LowerTriangular}=nothing,
    kwargs...,
)
    return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...)
end
|
||||||
"""
    vi(model, q, n_iterations; objective, show_progress, optimizer, averager, operator,
       adtype, kwargs...)

Run `AdvancedVI.optimize` on the log density of `model`, starting from the variational
approximation `q`, for `n_iterations` iterations, and return whatever
`AdvancedVI.optimize` returns.

# Keywords
- `objective`: defaults to `RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient())`.
- `show_progress::Bool`: defaults to `PROGRESS[]`.
- `optimizer`: defaults to `AdvancedVI.DoWG()`.
- `averager`: defaults to `AdvancedVI.PolynomialAveraging()`.
- `operator`: defaults to `AdvancedVI.ProximalLocationScaleEntropy()`.
- `adtype::ADTypes.AbstractADType`: defaults to `DEFAULT_ADTYPE`.
- `kwargs...`: passed through to `AdvancedVI.optimize`.
"""
function vi(
    model::DynamicPPL.Model,
    q::Bijectors.TransformedDistribution,
    n_iterations::Int;
    objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
    show_progress::Bool=PROGRESS[],
    optimizer=AdvancedVI.DoWG(),
    averager=AdvancedVI.PolynomialAveraging(),
    operator=AdvancedVI.ProximalLocationScaleEntropy(),
    adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
    kwargs...,
)
    # Target log density handed to the optimiser.
    target = make_logdensity(model)
    return AdvancedVI.optimize(
        target,
        objective,
        q,
        n_iterations;
        show_progress,
        adtype,
        optimizer,
        averager,
        operator,
        kwargs...,
    )
end
|
||||||
end |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.