@@ -409,52 +409,74 @@ function DynamicPPL.initialstep(
409
409
rng:: Random.AbstractRNG ,
410
410
model:: DynamicPPL.Model ,
411
411
spl:: DynamicPPL.Sampler{<:Gibbs} ,
412
- vi_base :: DynamicPPL.AbstractVarInfo ;
412
+ vi :: DynamicPPL.AbstractVarInfo ;
413
413
initial_params= nothing ,
414
414
kwargs... ,
415
415
)
416
416
alg = spl. alg
417
417
varnames = alg. varnames
418
418
samplers = alg. samplers
419
419
420
- # Run the model once to get the varnames present + initial values to condition on.
421
- vi = DynamicPPL . VarInfo ( rng, model)
422
- if initial_params != = nothing
423
- vi = DynamicPPL . unflatten (vi, initial_params )
424
- end
420
+ vi, states = gibbs_initialstep_recursive (
421
+ rng, model, varnames, samplers, vi; initial_params = initial_params, kwargs ...
422
+ )
423
+ return Transition (model, vi), GibbsState (vi, states )
424
+ end
425
425
426
- # Initialise each component sampler in turn, collect all their states.
427
- states = []
428
- for (varnames_local, sampler_local) in zip (varnames, samplers)
429
- # Get the initial values for this component sampler.
430
- initial_params_local = if initial_params === nothing
431
- nothing
432
- else
433
- DynamicPPL. subset (vi, varnames_local)[:]
434
- end
426
+ """
427
+ Take the first step of MCMC for the first component sampler, and call the same function
428
+ recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
429
+ and a tuple of initial states for all component samplers.
430
+ """
431
+ function gibbs_initialstep_recursive (
432
+ rng, model, varname_vecs, samplers, vi, states= (); initial_params= nothing , kwargs...
433
+ )
434
+ # End recursion
435
+ if isempty (varname_vecs) && isempty (samplers)
436
+ return vi, states
437
+ end
435
438
436
- # Construct the conditioned model.
437
- model_local, context_local = make_conditional (model, varnames_local, vi)
439
+ varnames, varname_vecs_tail ... = varname_vecs
440
+ sampler, samplers_tail ... = samplers
438
441
439
- # Take initial step.
440
- _, new_state_local = AbstractMCMC. step (
441
- rng,
442
- model_local,
443
- sampler_local;
444
- # FIXME : This will cause issues if the sampler expects initial params in unconstrained space.
445
- # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
446
- initial_params= initial_params_local,
447
- kwargs... ,
448
- )
449
- new_vi_local = varinfo (new_state_local)
450
- # Merge in any new variables that were introduced during the step, but that
451
- # were not in the domain of the current sampler.
452
- vi = merge (vi, get_global_varinfo (context_local))
453
- # Merge the new values for all the variables sampled by the current sampler.
454
- vi = merge (vi, new_vi_local)
455
- push! (states, new_state_local)
442
+ # Get the initial values for this component sampler.
443
+ initial_params_local = if initial_params === nothing
444
+ nothing
445
+ else
446
+ DynamicPPL. subset (vi, varnames)[:]
456
447
end
457
- return Transition (model, vi), GibbsState (vi, states)
448
+
449
+ # Construct the conditioned model.
450
+ conditioned_model, context = make_conditional (model, varnames, vi)
451
+
452
+ # Take initial step with the current sampler.
453
+ _, new_state = AbstractMCMC. step (
454
+ rng,
455
+ conditioned_model,
456
+ sampler;
457
+ # FIXME : This will cause issues if the sampler expects initial params in unconstrained space.
458
+ # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
459
+ initial_params= initial_params_local,
460
+ kwargs... ,
461
+ )
462
+ new_vi_local = varinfo (new_state)
463
+ # Merge in any new variables that were introduced during the step, but that
464
+ # were not in the domain of the current sampler.
465
+ vi = merge (vi, get_global_varinfo (context))
466
+ # Merge the new values for all the variables sampled by the current sampler.
467
+ vi = merge (vi, new_vi_local)
468
+
469
+ states = (states... , new_state)
470
+ return gibbs_initialstep_recursive (
471
+ rng,
472
+ model,
473
+ varname_vecs_tail,
474
+ samplers_tail,
475
+ vi,
476
+ states;
477
+ initial_params= initial_params,
478
+ kwargs... ,
479
+ )
458
480
end
459
481
460
482
function AbstractMCMC. step (
@@ -471,17 +493,7 @@ function AbstractMCMC.step(
471
493
states = state. states
472
494
@assert length (samplers) == length (state. states)
473
495
474
- # TODO : move this into a recursive function so we can unroll when reasonable?
475
- for index in 1 : length (samplers)
476
- # Take the inner step.
477
- sampler_local = samplers[index]
478
- state_local = states[index]
479
- varnames_local = varnames[index]
480
- vi, new_state_local = gibbs_step_inner (
481
- rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
482
- )
483
- states = Accessors. setindex (states, new_state_local, index)
484
- end
496
+ vi, states = gibbs_step_recursive (rng, model, varnames, samplers, states, vi; kwargs... )
485
497
return Transition (model, vi), GibbsState (vi, states)
486
498
end
487
499
@@ -605,19 +617,33 @@ function match_linking!!(varinfo_local, prev_state_local, model)
605
617
return varinfo_local
606
618
end
607
619
608
- function gibbs_step_inner (
620
+ """
621
+ Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
622
+ function on the tail, until there are no more samplers left.
623
+ """
624
+ function gibbs_step_recursive (
609
625
rng:: Random.AbstractRNG ,
610
626
model:: DynamicPPL.Model ,
611
- varnames_local,
612
- sampler_local,
613
- state_local,
614
- global_vi;
627
+ varname_vecs,
628
+ samplers,
629
+ states,
630
+ global_vi,
631
+ new_states= ();
615
632
kwargs... ,
616
633
)
634
+ # End recursion.
635
+ if isempty (varname_vecs) && isempty (samplers) && isempty (states)
636
+ return global_vi, new_states
637
+ end
638
+
639
+ varnames, varname_vecs_tail... = varname_vecs
640
+ sampler, samplers_tail... = samplers
641
+ state, states_tail... = states
642
+
617
643
# Construct the conditional model and the varinfo that this sampler should use.
618
- model_local, context_local = make_conditional (model, varnames_local , global_vi)
619
- varinfo_local = subset (global_vi, varnames_local )
620
- varinfo_local = match_linking!! (varinfo_local, state_local , model)
644
+ conditioned_model, context = make_conditional (model, varnames , global_vi)
645
+ vi = subset (global_vi, varnames )
646
+ vi = match_linking!! (vi, state , model)
621
647
622
648
# TODO (mhauru) The below may be overkill. If the varnames for this sampler are not
623
649
# sampled by other samplers, we don't need to `setparams`, but could rather simply
@@ -628,18 +654,25 @@ function gibbs_step_inner(
628
654
# going to be a significant expense anyway.
629
655
# Set the state of the current sampler, accounting for any changes made by other
630
656
# samplers.
631
- state_local = setparams_varinfo!! (
632
- model_local, sampler_local, state_local, varinfo_local
633
- )
657
+ state = setparams_varinfo!! (conditioned_model, sampler, state, vi)
634
658
635
659
# Take a step with the local sampler.
636
- new_state_local = last (
637
- AbstractMCMC. step (rng, model_local, sampler_local, state_local; kwargs... )
638
- )
660
+ new_state = last (AbstractMCMC. step (rng, conditioned_model, sampler, state; kwargs... ))
639
661
640
- new_vi_local = varinfo (new_state_local )
662
+ new_vi_local = varinfo (new_state )
641
663
# Merge the latest values for all the variables in the current sampler.
642
- new_global_vi = merge (get_global_varinfo (context_local ), new_vi_local)
664
+ new_global_vi = merge (get_global_varinfo (context ), new_vi_local)
643
665
new_global_vi = setlogp!! (new_global_vi, getlogp (new_vi_local))
644
- return new_global_vi, new_state_local
666
+
667
+ new_states = (new_states... , new_state)
668
+ return gibbs_step_recursive (
669
+ rng,
670
+ model,
671
+ varname_vecs_tail,
672
+ samplers_tail,
673
+ states_tail,
674
+ new_global_vi,
675
+ new_states;
676
+ kwargs... ,
677
+ )
645
678
end
0 commit comments