Skip to content

Update to the [email protected] interface #2506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ed6946c
update to match the [email protected] interface
Red-Portal Mar 14, 2025
a94269d
run formatter
Red-Portal Mar 14, 2025
a4711a9
run formatter
Red-Portal Mar 14, 2025
3f8068b
run formatter
Red-Portal Mar 14, 2025
222a638
run formatter
Red-Portal Mar 14, 2025
57097f5
run formatter
Red-Portal Mar 14, 2025
a42eea8
run formatter
Red-Portal Mar 14, 2025
798f319
run formatter
Red-Portal Mar 14, 2025
69a4972
run formatter
Red-Portal Mar 14, 2025
cbcb8b5
run formatter
Red-Portal Mar 14, 2025
081d6ff
remove plotting
Red-Portal Mar 14, 2025
a32a673
Merge branch 'update_advancedvi' of github.com:TuringLang/Turing.jl i…
Red-Portal Mar 14, 2025
1bcec3e
fix formatting
Red-Portal Mar 14, 2025
b142832
fix formatting
Red-Portal Mar 14, 2025
061ec35
fix formatting
Red-Portal Mar 14, 2025
736bd3e
remove unused dependency
Red-Portal Mar 14, 2025
fd434d8
Merge branch 'update_advancedvi' of github.com:TuringLang/Turing.jl i…
Red-Portal Mar 14, 2025
57108ee
Merge branch 'main' into update_advancedvi
yebai Mar 18, 2025
8dc8067
Merge branch 'main' into update_advancedvi
yebai Mar 20, 2025
297c32a
Update Project.toml
yebai Mar 20, 2025
3010b5e
Merge branch 'main' of github.com:TuringLang/Turing.jl into update_ad…
Red-Portal Mar 25, 2025
0c04434
fix make some arugments of vi initializer to be optional kwargs
Red-Portal Mar 25, 2025
17a8290
Merge branch 'update_advancedvi' of github.com:TuringLang/Turing.jl i…
Red-Portal Mar 25, 2025
0e496c4
Merge branch 'main' into update_advancedvi
yebai Apr 18, 2025
c1533a8
Update src/variational/bijectors.jl
yebai Apr 18, 2025
231d6e2
Update Turing.jl
yebai Apr 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7"
AdvancedMH = "0.8"
AdvancedPS = "0.6.0"
AdvancedVI = "0.2"
AdvancedVI = "0.3.1"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Compat = "4.15.0"
Expand Down
2 changes: 0 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ function setprogress!(progress::Bool)
@info "[Turing]: progress logging is $(progress ? "enabled" : "disabled") globally"
PROGRESS[] = progress
AbstractMCMC.setprogress!(progress; silent=true)
# TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3
AdvancedVI.turnprogress(progress)
return progress
end

Expand Down
180 changes: 150 additions & 30 deletions src/variational/VariationalInference.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,170 @@

module Variational

using DistributionsAD: DistributionsAD
using DynamicPPL: DynamicPPL
using StatsBase: StatsBase
using StatsFuns: StatsFuns
using LogDensityProblems: LogDensityProblems
using DynamicPPL
using ADTypes
using Distributions
using LinearAlgebra
using LogDensityProblems
using Random

using Random: Random
import ..Turing: DEFAULT_ADTYPE, PROGRESS

import AdvancedVI
import Bijectors

# Reexports
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad

"""
make_logjoint(model::Model; weight = 1.0)
Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z).
The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to
use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`.
## Notes
- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`.
"""
function make_logjoint(model::DynamicPPL.Model; weight=1.0)
# setup
using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG
export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG

export meanfield_gaussian, fullrank_gaussian

include("bijectors.jl")

function make_logdensity(model::DynamicPPL.Model)
weight = 1.0

Check warning on line 25 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight)
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
return Base.Fix1(LogDensityProblems.logdensity, f)
return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)

Check warning on line 27 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L27

Added line #L27 was not covered by tests
end

# objectives
function (elbo::ELBO)(
function initialize_gaussian_scale(

Check warning on line 30 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L30

Added line #L30 was not covered by tests
rng::Random.AbstractRNG,
alg::AdvancedVI.VariationalInference,
q,
model::DynamicPPL.Model,
num_samples;
weight=1.0,
location::AbstractVector,
scale::AbstractMatrix;
num_samples::Int=10,
num_max_trials::Int=10,
reduce_factor=one(eltype(scale)) / 2,
)
prob = make_logdensity(model)
ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
varinfo = DynamicPPL.VarInfo(model)

Check warning on line 41 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L39-L41

Added lines #L39 - L41 were not covered by tests

n_trial = 0
while true
q = AdvancedVI.MvLocationScale(location, scale, Normal())
b = Bijectors.bijector(model; varinfo=varinfo)
q_trans = Bijectors.transformed(q, Bijectors.inverse(b))
energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples)))

Check warning on line 48 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L43-L48

Added lines #L43 - L48 were not covered by tests

if isfinite(energy)
return scale
elseif n_trial == num_max_trials
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elseif n_trial == num_max_trials
else

error("Could not find an initial")

Check warning on line 53 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L50-L53

Added lines #L50 - L53 were not covered by tests
end

scale = reduce_factor * scale
n_trial += 1
end

Check warning on line 58 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L56-L58

Added lines #L56 - L58 were not covered by tests
end

function meanfield_gaussian(

Check warning on line 61 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L61

Added line #L61 was not covered by tests
Copy link
Member

@yebai yebai Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function meanfield_gaussian(
function q_meanfield_gaussian(

rng::Random.AbstractRNG,
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector}=nothing,
scale::Union{Nothing,<:Diagonal}=nothing,
kwargs...,
)
varinfo = DynamicPPL.VarInfo(model)

Check warning on line 68 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L68

Added line #L68 was not covered by tests
# Use linked `varinfo` to determine the correct number of parameters.
# TODO: Replace with `length` once this is implemented for `VarInfo`.
varinfo_linked = DynamicPPL.link(varinfo, model)
num_params = length(varinfo_linked[:])

Check warning on line 72 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L71-L72

Added lines #L71 - L72 were not covered by tests

μ = if isnothing(location)
zeros(num_params)

Check warning on line 75 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L74-L75

Added lines #L74 - L75 were not covered by tests
else
@assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
location

Check warning on line 78 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
end

L = if isnothing(scale)
initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...)

Check warning on line 82 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L81-L82

Added lines #L81 - L82 were not covered by tests
else
@assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
L = scale

Check warning on line 85 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L84-L85

Added lines #L84 - L85 were not covered by tests
end

q = AdvancedVI.MeanFieldGaussian(μ, L)
b = Bijectors.bijector(model; varinfo=varinfo)
return Bijectors.transformed(q, Bijectors.inverse(b))

Check warning on line 90 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L88-L90

Added lines #L88 - L90 were not covered by tests
end

function meanfield_gaussian(

Check warning on line 93 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L93

Added line #L93 was not covered by tests
Copy link
Member

@yebai yebai Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function meanfield_gaussian(
function q_meanfield_gaussian(

model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector}=nothing,
scale::Union{Nothing,<:Diagonal}=nothing,
kwargs...,
)
return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...)
return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...)

Check warning on line 99 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L99

Added line #L99 was not covered by tests
end

function fullrank_gaussian(

Check warning on line 102 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L102

Added line #L102 was not covered by tests
Copy link
Member

@yebai yebai Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function fullrank_gaussian(
function q_fullrank_gaussian(

rng::Random.AbstractRNG,
model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector}=nothing,
scale::Union{Nothing,<:LowerTriangular}=nothing,
kwargs...,
)
varinfo = DynamicPPL.VarInfo(model)

Check warning on line 109 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L109

Added line #L109 was not covered by tests
# Use linked `varinfo` to determine the correct number of parameters.
# TODO: Replace with `length` once this is implemented for `VarInfo`.
varinfo_linked = DynamicPPL.link(varinfo, model)
num_params = length(varinfo_linked[:])

Check warning on line 113 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L112-L113

Added lines #L112 - L113 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get the dimentionality via num_params = length(varinfo_linked) instead of length(varinfo_linked[:])?

cc @mkarikom


μ = if isnothing(location)
zeros(num_params)

Check warning on line 116 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L115-L116

Added lines #L115 - L116 were not covered by tests
else
@assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
location

Check warning on line 119 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
end

L = if isnothing(scale)
L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params))
initialize_gaussian_scale(rng, model, μ, L0; kwargs...)

Check warning on line 124 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L122-L124

Added lines #L122 - L124 were not covered by tests
else
@assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
scale

Check warning on line 127 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L126-L127

Added lines #L126 - L127 were not covered by tests
end

q = AdvancedVI.FullRankGaussian(μ, L)
b = Bijectors.bijector(model; varinfo=varinfo)
return Bijectors.transformed(q, Bijectors.inverse(b))

Check warning on line 132 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L130-L132

Added lines #L130 - L132 were not covered by tests
end

# VI algorithms
include("advi.jl")
function fullrank_gaussian(

Check warning on line 135 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L135

Added line #L135 was not covered by tests
Copy link
Member

@yebai yebai Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function fullrank_gaussian(
function q_fullrank_gaussian(

model::DynamicPPL.Model;
location::Union{Nothing,<:AbstractVector}=nothing,
scale::Union{Nothing,<:LowerTriangular}=nothing,
kwargs...,
)
return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...)

Check warning on line 141 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L141

Added line #L141 was not covered by tests
end

function vi(

Check warning on line 144 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L144

Added line #L144 was not covered by tests
model::DynamicPPL.Model,
q::Bijectors.TransformedDistribution,
n_iterations::Int;
objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
show_progress::Bool=PROGRESS[],
optimizer=AdvancedVI.DoWG(),
averager=AdvancedVI.PolynomialAveraging(),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
kwargs...,
)
return AdvancedVI.optimize(

Check warning on line 156 in src/variational/VariationalInference.jl

View check run for this annotation

Codecov / codecov/patch

src/variational/VariationalInference.jl#L156

Added line #L156 was not covered by tests
make_logdensity(model),
objective,
q,
n_iterations;
show_progress=show_progress,
adtype,
optimizer,
averager,
operator,
kwargs...,
)
end

end
140 changes: 0 additions & 140 deletions src/variational/advi.jl

This file was deleted.

Loading
Loading