Skip to content

Commit 7d6f8ed

Browse files
mhaurupenelopeysm
andauthored
Rework Gibbs constructors (#2456)
* Rework Gibbs constructors, and remove the dead test/experimental/gibbs.jl * Update HISTORY.md * Clarify docstring * Remove unnecessary _maybecollect in gibbs.jl * Fix a bug * Fix more Gibbs constructors in tests * Improve HISTORY.md note Co-authored-by: Penelope Yong <[email protected]> * Apply proposals from code review * Add type bounds to Gibbs type parameters * Style improvements to gibbs.jl * Fix method ambiguity * Modify type signature of Gibbs --------- Co-authored-by: Penelope Yong <[email protected]>
1 parent 0c3d3d0 commit 7d6f8ed

File tree

10 files changed

+92
-355
lines changed

10 files changed

+92
-355
lines changed

HISTORY.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44

55
0.36.0 introduces a new Gibbs sampler. It's been included in several previous releases as `Turing.Experimental.Gibbs`, but now takes over the old Gibbs sampler, which gets removed completely.
66

7-
The new Gibbs sampler supports the same user-facing interface as the old one. However, given
8-
that the internals of it having been completely rewritten in a very different manner, there
9-
may be accidental breakage that we haven't anticipated. Please report any you find.
7+
The new Gibbs sampler currently supports the same user-facing interface as the old one, but the old constructors have been deprecated, and will be removed in the future. Also, given that the internals have been completely rewritten in a very different manner, there may be accidental breakage that we haven't anticipated. Please report any you find.
108

119
`GibbsConditional` has also been removed. It was never very user-facing, but it was exported, so technically this is breaking.
1210

13-
The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.
11+
The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by mapping symbols, `VarName`s, or iterables thereof to samplers, e.g. `Gibbs(x=>HMC(), y=>MH())`, `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`, `Gibbs((:x, :y) => NUTS(), :z => MH())`. This allows more granular specification of which sampler to use for which variable.
1412

1513
Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(0.01, 4, :x), 2), (MH(:y), 1))` has been deprecated. The new way to do this is to use `RepeatSampler`, also introduced at this version: `Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())`.
1614

src/mcmc/gibbs.jl

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -292,15 +292,40 @@ function set_selector(x::RepeatSampler)
292292
end
293293
set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0))
294294

295+
to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)]
296+
# Any other value is assumed to be an iterable of VarNames and Symbols.
297+
to_varname_list(t) = collect(map(VarName, t))
298+
295299
"""
296300
Gibbs
297301
298302
A type representing a Gibbs sampler.
299303
304+
# Constructors
305+
306+
`Gibbs` needs to be given a set of pairs of variable names and samplers. Instead of a single
307+
variable name per sampler, one can also give an iterable of variables, all of which are
308+
sampled by the same component sampler.
309+
310+
Each variable name can be given as either a `Symbol` or a `VarName`.
311+
312+
Some examples of valid constructors are:
313+
```julia
314+
Gibbs(:x => NUTS(), :y => MH())
315+
Gibbs(@varname(x) => NUTS(), @varname(y) => MH())
316+
Gibbs((@varname(x), :y) => NUTS(), :z => MH())
317+
```
318+
319+
Currently only variable names without indexing are supported, so for instance
320+
`Gibbs(@varname(x[1]) => NUTS())` does not work. This will hopefully change in the future.
321+
300322
# Fields
301323
$(TYPEDFIELDS)
302324
"""
303-
struct Gibbs{V,A} <: InferenceAlgorithm
325+
struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
326+
InferenceAlgorithm
327+
# TODO(mhauru) Revisit whether A should have a fixed element type once
328+
# InferenceAlgorithm/Sampler types have been cleaned up.
304329
"varnames representing variables for each sampler"
305330
varnames::V
306331
"samplers for each entry in `varnames`"
@@ -310,40 +335,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm
310335
if length(varnames) != length(samplers)
311336
throw(ArgumentError("Number of varnames and samplers must match."))
312337
end
338+
313339
for spl in samplers
314340
if !isgibbscomponent(spl)
315341
msg = "All samplers must be valid Gibbs components, $(spl) is not."
316342
throw(ArgumentError(msg))
317343
end
318344
end
319-
return new{typeof(varnames),typeof(samplers)}(varnames, samplers)
320-
end
321-
end
322-
323-
to_varname(vn::VarName) = vn
324-
to_varname(s::Symbol) = VarName{s}()
325-
# Any other value is assumed to be an iterable.
326-
to_varname(t) = map(to_varname, collect(t))
327345

328-
# NamedTuple
329-
Gibbs(; algs...) = Gibbs(NamedTuple(algs))
330-
function Gibbs(algs::NamedTuple)
331-
return Gibbs(map(to_varname, keys(algs)), map(set_selector drop_space, values(algs)))
346+
# Ensure that samplers have the same selector, and that varnames are lists of
347+
# VarNames.
348+
samplers = tuple(map(set_selector drop_space, samplers)...)
349+
varnames = tuple(map(to_varname_list, varnames)...)
350+
return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
351+
end
332352
end
333353

334-
# AbstractDict
335-
function Gibbs(algs::AbstractDict)
336-
return Gibbs(
337-
map(to_varname, collect(keys(algs))), map(set_selector drop_space, values(algs))
338-
)
339-
end
340354
function Gibbs(algs::Pair...)
341-
return Gibbs(map(to_varname first, algs), map(set_selector drop_space last, algs))
355+
return Gibbs(map(first, algs), map(last, algs))
342356
end
343357

344358
# The below two constructors only provide backwards compatibility with the constructor of
345359
# the old Gibbs sampler. They are deprecated and will be removed in the future.
346-
function Gibbs(algs::InferenceAlgorithm...)
360+
function Gibbs(alg1::InferenceAlgorithm, other_algs::InferenceAlgorithm...)
361+
algs = [alg1, other_algs...]
347362
varnames = map(algs) do alg
348363
space = getspace(alg)
349364
if (space isa VarName)
@@ -365,7 +380,11 @@ function Gibbs(algs::InferenceAlgorithm...)
365380
return Gibbs(varnames, map(set_selector drop_space, algs))
366381
end
367382

368-
function Gibbs(algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...)
383+
function Gibbs(
384+
alg_with_iters1::Tuple{<:InferenceAlgorithm,Int},
385+
other_algs_with_iters::Tuple{<:InferenceAlgorithm,Int}...,
386+
)
387+
algs_with_iters = [alg_with_iters1, other_algs_with_iters...]
369388
algs = Iterators.map(first, algs_with_iters)
370389
iters = Iterators.map(last, algs_with_iters)
371390
algs_duplicated = Iterators.flatten((
@@ -384,11 +403,6 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
384403
states::S
385404
end
386405

387-
_maybevec(x) = vec(x) # assume it's iterable
388-
_maybevec(x::Tuple) = [x...]
389-
_maybevec(x::VarName) = [x]
390-
_maybevec(x::Symbol) = [x]
391-
392406
varinfo(state::GibbsState) = state.vi
393407

394408
function DynamicPPL.initialstep(
@@ -412,7 +426,6 @@ function DynamicPPL.initialstep(
412426
# Initialise each component sampler in turn, collect all their states.
413427
states = []
414428
for (varnames_local, sampler_local) in zip(varnames, samplers)
415-
varnames_local = _maybevec(varnames_local)
416429
# Get the initial values for this component sampler.
417430
initial_params_local = if initial_params === nothing
418431
nothing
@@ -463,7 +476,7 @@ function AbstractMCMC.step(
463476
# Take the inner step.
464477
sampler_local = samplers[index]
465478
state_local = states[index]
466-
varnames_local = _maybevec(varnames[index])
479+
varnames_local = varnames[index]
467480
vi, new_state_local = gibbs_step_inner(
468481
rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
469482
)

test/dynamicppl/compiler.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ const gdemo_default = gdemo_d()
5454

5555
smc = SMC()
5656
pg = PG(10)
57-
gibbs = Gibbs(; p=HMC(0.2, 3), x=PG(10))
57+
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))
5858

5959
chn_s = sample(testbb(obs), smc, 1000)
6060
chn_p = sample(testbb(obs), pg, 2000)
@@ -81,7 +81,7 @@ const gdemo_default = gdemo_d()
8181
return s, m
8282
end
8383

84-
gibbs = Gibbs(; s=PG(10), m=HMC(0.4, 8))
84+
gibbs = Gibbs(:s => PG(10), :m => HMC(0.4, 8))
8585
chain = sample(fggibbstest(xs), gibbs, 2)
8686
end
8787
@testset "new grammar" begin
@@ -177,7 +177,7 @@ const gdemo_default = gdemo_d()
177177
end
178178

179179
@testset "sample" begin
180-
alg = Gibbs(; m=HMC(0.2, 3), s=PG(10))
180+
alg = Gibbs(:m => HMC(0.2, 3), :s => PG(10))
181181
chn = sample(gdemo_default, alg, 1000)
182182
end
183183

0 commit comments

Comments
 (0)