@@ -12,9 +12,17 @@ using BenchmarkTools
12
12
y ~ Normal (μ, sqrt (σ))
13
13
end
14
14
15
+ function make_sampling_model_and_args (model, rng, sampler)
16
+ ctx = DynamicPPL. SamplingContext (rng, sampler, model. context)
17
+ spl_model = DynamicPPL. contextualize (model, ctx)
18
+ fargs, _ = DynamicPPL. make_evaluate_args_and_kwargs (spl_model, varinfo)
19
+ return (spl_model, fargs)
20
+ end
21
+
15
22
# Case 1: Sample from the prior.
16
- rng = MersenneTwister ()
17
- m = Turing. Inference. TracedModel (gdemo (1.5 , 2.0 ), SampleFromPrior (), VarInfo (), rng)
23
+ spl, rng = SampleFromPrior (), MersenneTwister ()
24
+ spl_model, fargs = make_sampling_model_and_args (gdemo (1.5 , 2.0 ), rng, spl)
25
+ m = Turing. Inference. TracedModel (spl_model, spl, VarInfo (), fargs)
18
26
f = m. evaluator[1 ];
19
27
args = m. evaluator[2 : end ];
20
28
@@ -27,7 +35,9 @@ println("Run a tape...")
27
35
@btime t. tf (args... )
28
36
29
37
# Case 2: SMC sampler
30
- m = Turing. Inference. TracedModel (gdemo (1.5 , 2.0 ), Sampler (SMC (50 )), VarInfo (), rng)
38
+ spl, rng = SMC (50 ), MersenneTwister ()
39
+ spl_model, fargs = make_sampling_model_and_args (gdemo (1.5 , 2.0 ), rng, spl)
40
+ m = Turing. Inference. TracedModel (spl_model, spl, VarInfo (), fargs)
31
41
f = m. evaluator[1 ];
32
42
args = m. evaluator[2 : end ];
33
43
0 commit comments