Skip to content

Commit 11944c9

Browse files
mhaurupenelopeysmgithub-actions[bot]
authored
Replace Gibbs inner loop with recursion (#2464)
* Use recursion in gibbs_inner_step * Remove unnecessary initialisation in Gibbs We were doing work that was already done by the caller of initialstep. * variable naming / destructuring (#2465) * Variable naming, destructuring * Tuple -> Vec * Reviewdog code style suggestions Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Penelope Yong <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 7d6f8ed commit 11944c9

File tree

1 file changed

+96
-63
lines changed

1 file changed

+96
-63
lines changed

src/mcmc/gibbs.jl

Lines changed: 96 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -409,52 +409,74 @@ function DynamicPPL.initialstep(
409409
rng::Random.AbstractRNG,
410410
model::DynamicPPL.Model,
411411
spl::DynamicPPL.Sampler{<:Gibbs},
412-
vi_base::DynamicPPL.AbstractVarInfo;
412+
vi::DynamicPPL.AbstractVarInfo;
413413
initial_params=nothing,
414414
kwargs...,
415415
)
416416
alg = spl.alg
417417
varnames = alg.varnames
418418
samplers = alg.samplers
419419

420-
# Run the model once to get the varnames present + initial values to condition on.
421-
vi = DynamicPPL.VarInfo(rng, model)
422-
if initial_params !== nothing
423-
vi = DynamicPPL.unflatten(vi, initial_params)
424-
end
420+
vi, states = gibbs_initialstep_recursive(
421+
rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
422+
)
423+
return Transition(model, vi), GibbsState(vi, states)
424+
end
425425

426-
# Initialise each component sampler in turn, collect all their states.
427-
states = []
428-
for (varnames_local, sampler_local) in zip(varnames, samplers)
429-
# Get the initial values for this component sampler.
430-
initial_params_local = if initial_params === nothing
431-
nothing
432-
else
433-
DynamicPPL.subset(vi, varnames_local)[:]
434-
end
426+
"""
427+
Take the first step of MCMC for the first component sampler, and call the same function
428+
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
429+
and a tuple of initial states for all component samplers.
430+
"""
431+
function gibbs_initialstep_recursive(
432+
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
433+
)
434+
# End recursion
435+
if isempty(varname_vecs) && isempty(samplers)
436+
return vi, states
437+
end
435438

436-
# Construct the conditioned model.
437-
model_local, context_local = make_conditional(model, varnames_local, vi)
439+
varnames, varname_vecs_tail... = varname_vecs
440+
sampler, samplers_tail... = samplers
438441

439-
# Take initial step.
440-
_, new_state_local = AbstractMCMC.step(
441-
rng,
442-
model_local,
443-
sampler_local;
444-
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
445-
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
446-
initial_params=initial_params_local,
447-
kwargs...,
448-
)
449-
new_vi_local = varinfo(new_state_local)
450-
# Merge in any new variables that were introduced during the step, but that
451-
# were not in the domain of the current sampler.
452-
vi = merge(vi, get_global_varinfo(context_local))
453-
# Merge the new values for all the variables sampled by the current sampler.
454-
vi = merge(vi, new_vi_local)
455-
push!(states, new_state_local)
442+
# Get the initial values for this component sampler.
443+
initial_params_local = if initial_params === nothing
444+
nothing
445+
else
446+
DynamicPPL.subset(vi, varnames)[:]
456447
end
457-
return Transition(model, vi), GibbsState(vi, states)
448+
449+
# Construct the conditioned model.
450+
conditioned_model, context = make_conditional(model, varnames, vi)
451+
452+
# Take initial step with the current sampler.
453+
_, new_state = AbstractMCMC.step(
454+
rng,
455+
conditioned_model,
456+
sampler;
457+
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
458+
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
459+
initial_params=initial_params_local,
460+
kwargs...,
461+
)
462+
new_vi_local = varinfo(new_state)
463+
# Merge in any new variables that were introduced during the step, but that
464+
# were not in the domain of the current sampler.
465+
vi = merge(vi, get_global_varinfo(context))
466+
# Merge the new values for all the variables sampled by the current sampler.
467+
vi = merge(vi, new_vi_local)
468+
469+
states = (states..., new_state)
470+
return gibbs_initialstep_recursive(
471+
rng,
472+
model,
473+
varname_vecs_tail,
474+
samplers_tail,
475+
vi,
476+
states;
477+
initial_params=initial_params,
478+
kwargs...,
479+
)
458480
end
459481

460482
function AbstractMCMC.step(
@@ -471,17 +493,7 @@ function AbstractMCMC.step(
471493
states = state.states
472494
@assert length(samplers) == length(state.states)
473495

474-
# TODO: move this into a recursive function so we can unroll when reasonable?
475-
for index in 1:length(samplers)
476-
# Take the inner step.
477-
sampler_local = samplers[index]
478-
state_local = states[index]
479-
varnames_local = varnames[index]
480-
vi, new_state_local = gibbs_step_inner(
481-
rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
482-
)
483-
states = Accessors.setindex(states, new_state_local, index)
484-
end
496+
vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
485497
return Transition(model, vi), GibbsState(vi, states)
486498
end
487499

@@ -605,19 +617,33 @@ function match_linking!!(varinfo_local, prev_state_local, model)
605617
return varinfo_local
606618
end
607619

608-
function gibbs_step_inner(
620+
"""
621+
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
622+
function on the tail, until there are no more samplers left.
623+
"""
624+
function gibbs_step_recursive(
609625
rng::Random.AbstractRNG,
610626
model::DynamicPPL.Model,
611-
varnames_local,
612-
sampler_local,
613-
state_local,
614-
global_vi;
627+
varname_vecs,
628+
samplers,
629+
states,
630+
global_vi,
631+
new_states=();
615632
kwargs...,
616633
)
634+
# End recursion.
635+
if isempty(varname_vecs) && isempty(samplers) && isempty(states)
636+
return global_vi, new_states
637+
end
638+
639+
varnames, varname_vecs_tail... = varname_vecs
640+
sampler, samplers_tail... = samplers
641+
state, states_tail... = states
642+
617643
# Construct the conditional model and the varinfo that this sampler should use.
618-
model_local, context_local = make_conditional(model, varnames_local, global_vi)
619-
varinfo_local = subset(global_vi, varnames_local)
620-
varinfo_local = match_linking!!(varinfo_local, state_local, model)
644+
conditioned_model, context = make_conditional(model, varnames, global_vi)
645+
vi = subset(global_vi, varnames)
646+
vi = match_linking!!(vi, state, model)
621647

622648
# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
623649
# sampled by other samplers, we don't need to `setparams`, but could rather simply
@@ -628,18 +654,25 @@ function gibbs_step_inner(
628654
# going to be a significant expense anyway.
629655
# Set the state of the current sampler, accounting for any changes made by other
630656
# samplers.
631-
state_local = setparams_varinfo!!(
632-
model_local, sampler_local, state_local, varinfo_local
633-
)
657+
state = setparams_varinfo!!(conditioned_model, sampler, state, vi)
634658

635659
# Take a step with the local sampler.
636-
new_state_local = last(
637-
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
638-
)
660+
new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))
639661

640-
new_vi_local = varinfo(new_state_local)
662+
new_vi_local = varinfo(new_state)
641663
# Merge the latest values for all the variables in the current sampler.
642-
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
664+
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
643665
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
644-
return new_global_vi, new_state_local
666+
667+
new_states = (new_states..., new_state)
668+
return gibbs_step_recursive(
669+
rng,
670+
model,
671+
varname_vecs_tail,
672+
samplers_tail,
673+
states_tail,
674+
new_global_vi,
675+
new_states;
676+
kwargs...,
677+
)
645678
end

0 commit comments

Comments
 (0)