Skip to content

Commit 8177399

Browse files
committed
Fix more DynamicPPL stuff
1 parent 73e7b68 commit 8177399

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

perf/p0.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,17 @@ using BenchmarkTools
1212
y ~ Normal(μ, sqrt(σ))
1313
end
1414

15+
function make_sampling_model_and_args(model, rng, sampler)
16+
ctx = DynamicPPL.SamplingContext(rng, sampler, model.context)
17+
spl_model = DynamicPPL.contextualize(model, ctx)
18+
fargs, _ = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo)
19+
return (spl_model, fargs)
20+
end
21+
1522
# Case 1: Sample from the prior.
16-
rng = MersenneTwister()
17-
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
23+
spl, rng = SampleFromPrior(), MersenneTwister()
24+
spl_model, fargs = make_sampling_model_and_args(gdemo(1.5, 2.0), rng, spl)
25+
m = Turing.Inference.TracedModel(spl_model, spl, VarInfo(), fargs)
1826
f = m.evaluator[1];
1927
args = m.evaluator[2:end];
2028

@@ -27,7 +35,9 @@ println("Run a tape...")
2735
@btime t.tf(args...)
2836

2937
# Case 2: SMC sampler
30-
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
38+
spl, rng = SMC(50), MersenneTwister()
39+
spl_model, fargs = make_sampling_model_and_args(gdemo(1.5, 2.0), rng, spl)
40+
m = Turing.Inference.TracedModel(spl_model, spl, VarInfo(), fargs)
3141
f = m.evaluator[1];
3242
args = m.evaluator[2:end];
3343

perf/p2.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ Random.seed!(rng, 2)
5252
iterations = 500
5353
model_fun = infiniteGMM(data)
5454

55-
m = Turing.Inference.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
55+
spl = SMC(50)
56+
spl_model, fargs = make_sampling_model_and_args(model_fun, rng, spl)
57+
m = Turing.Inference.TracedModel(spl_model, spl, VarInfo(), fargs)
5658
f = m.evaluator[1]
5759
args = m.evaluator[2:end]
5860

0 commit comments

Comments
 (0)