Skip to content

Commit 8bf98e1

Browse files
authored
Update for DynamicPPL 0.33 and 0.34 (#2459)
* Update for DynamicPPL 0.33 * Don't remove import/export * 0.34 too * Update test compat too * Remove upstream tests for `predict`
1 parent 24d5556 commit 8bf98e1

File tree

4 files changed

+4
-246
lines changed

4 files changed

+4
-246
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.36.0"
3+
version = "0.36.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
6363
DistributionsAD = "0.6"
6464
DocStringExtensions = "0.8, 0.9"
6565
DynamicHMC = "3.4"
66-
DynamicPPL = "0.32"
66+
DynamicPPL = "0.33, 0.34"
6767
EllipticalSliceSampling = "0.5, 1, 2"
6868
ForwardDiff = "0.10.3"
6969
Libtask = "0.8.8"

src/mcmc/Inference.jl

+1-107
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
396396
# this means that the code below will work both of linked and invlinked `vi`.
397397
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
398398
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
399-
vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))
399+
vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
400400

401401
# Obtain an iterator over the flattened parameter names and values.
402402
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
@@ -612,112 +612,6 @@ end
612612
DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
613613
DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
614614

615-
"""
616-
617-
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
618-
619-
Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
620-
621-
If `include_all` is `false`, the returned `Chains` will contain only those variables
622-
sampled/not present in `chain`.
623-
624-
# Details
625-
Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
626-
and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
627-
628-
# Example
629-
```jldoctest
630-
julia> using Turing; Turing.setprogress!(false);
631-
[ Info: [Turing]: progress logging is disabled globally
632-
633-
julia> @model function linear_reg(x, y, σ = 0.1)
634-
β ~ Normal(0, 1)
635-
636-
for i ∈ eachindex(y)
637-
y[i] ~ Normal(β * x[i], σ)
638-
end
639-
end;
640-
641-
julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
642-
643-
julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
644-
645-
julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
646-
647-
julia> m_train = linear_reg(xs_train, ys_train, σ);
648-
649-
julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
650-
┌ Info: Found initial step size
651-
└ ϵ = 0.003125
652-
653-
julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
654-
655-
julia> predictions = predict(m_test, chain_lin_reg)
656-
Object of type Chains, with data of type 100×2×1 Array{Float64,3}
657-
658-
Iterations = 1:100
659-
Thinning interval = 1
660-
Chains = 1
661-
Samples per chain = 100
662-
parameters = y[1], y[2]
663-
664-
2-element Array{ChainDataFrame,1}
665-
666-
Summary Statistics
667-
parameters mean std naive_se mcse ess r_hat
668-
────────── ─────── ────── ──────── ─────── ──────── ──────
669-
y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
670-
y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
671-
672-
Quantiles
673-
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
674-
────────── ─────── ─────── ─────── ─────── ───────
675-
y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
676-
y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
677-
678-
679-
julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
680-
681-
julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
682-
true
683-
```
684-
"""
685-
function predict(model::Model, chain::MCMCChains.Chains; kwargs...)
686-
return predict(Random.default_rng(), model, chain; kwargs...)
687-
end
688-
function predict(
689-
rng::AbstractRNG, model::Model, chain::MCMCChains.Chains; include_all=false
690-
)
691-
# Don't need all the diagnostics
692-
chain_parameters = MCMCChains.get_sections(chain, :parameters)
693-
694-
spl = DynamicPPL.SampleFromPrior()
695-
696-
# Sample transitions using `spl` conditioned on values in `chain`
697-
transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl)
698-
699-
# Let the Turing internals handle everything else for you
700-
chain_result = reduce(
701-
MCMCChains.chainscat,
702-
[
703-
AbstractMCMC.bundle_samples(
704-
transitions[:, chain_idx], model, spl, nothing, MCMCChains.Chains
705-
) for chain_idx in 1:size(transitions, 2)
706-
],
707-
)
708-
709-
parameter_names = if include_all
710-
names(chain_result, :parameters)
711-
else
712-
filter(
713-
k -> (k, names(chain_parameters, :parameters)),
714-
names(chain_result, :parameters),
715-
)
716-
end
717-
718-
return chain_result[parameter_names]
719-
end
720-
721615
"""
722616
723617
transitions_from_chain(

test/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Combinatorics = "1"
5151
Distributions = "0.25"
5252
DistributionsAD = "0.6.3"
5353
DynamicHMC = "2.1.6, 3.0"
54-
DynamicPPL = "0.32.2"
54+
DynamicPPL = "0.33, 0.34"
5555
FiniteDifferences = "0.10.8, 0.11, 0.12"
5656
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
5757
HypothesisTests = "0.11"

test/mcmc/utilities.jl

-136
Original file line numberDiff line numberDiff line change
@@ -1,145 +1,9 @@
11
module MCMCUtilitiesTests
22

33
using ..Models: gdemo_default
4-
using Distributions: Normal, sample, truncated
5-
using LinearAlgebra: I, vec
6-
using Random: Random
7-
using Random: MersenneTwister
84
using Test: @test, @testset
95
using Turing
106

11-
@testset "predict" begin
12-
Random.seed!(100)
13-
14-
@model function linear_reg(x, y, σ=0.1)
15-
β ~ Normal(0, 1)
16-
17-
for i in eachindex(y)
18-
y[i] ~ Normal* x[i], σ)
19-
end
20-
end
21-
22-
@model function linear_reg_vec(x, y, σ=0.1)
23-
β ~ Normal(0, 1)
24-
return y ~ MvNormal.* x, σ^2 * I)
25-
end
26-
27-
f(x) = 2 * x + 0.1 * randn()
28-
29-
Δ = 0.1
30-
xs_train = 0:Δ:10
31-
ys_train = f.(xs_train)
32-
xs_test = [10 + Δ, 10 + 2 * Δ]
33-
ys_test = f.(xs_test)
34-
35-
# Infer
36-
m_lin_reg = linear_reg(xs_train, ys_train)
37-
chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), 200)
38-
39-
# Predict on two last indices
40-
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
41-
predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
42-
43-
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
44-
45-
@test sum(abs2, ys_test - ys_pred) 0.1
46-
47-
# Ensure that `rng` is respected
48-
predictions1 = let rng = MersenneTwister(42)
49-
predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
50-
end
51-
predictions2 = let rng = MersenneTwister(42)
52-
predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
53-
end
54-
@test all(Array(predictions1) .== Array(predictions2))
55-
56-
# Predict on two last indices for vectorized
57-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
58-
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
59-
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))
60-
61-
@test sum(abs2, ys_test - ys_pred_vec) 0.1
62-
63-
# Multiple chains
64-
chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), MCMCThreads(), 200, 2)
65-
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
66-
predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
67-
68-
@test size(chain_lin_reg, 3) == size(predictions, 3)
69-
70-
for chain_idx in MCMCChains.chains(chain_lin_reg)
71-
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
72-
@test sum(abs2, ys_test - ys_pred) 0.1
73-
end
74-
75-
# Predict on two last indices for vectorized
76-
m_lin_reg_test = linear_reg_vec(xs_test, missing)
77-
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
78-
79-
for chain_idx in MCMCChains.chains(chain_lin_reg)
80-
ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
81-
@test sum(abs2, ys_test - ys_pred_vec) 0.1
82-
end
83-
84-
# https://github.com/TuringLang/Turing.jl/issues/1352
85-
@model function simple_linear1(x, y)
86-
intercept ~ Normal(0, 1)
87-
coef ~ MvNormal(zeros(2), I)
88-
coef = reshape(coef, 1, size(x, 1))
89-
90-
mu = vec(intercept .+ coef * x)
91-
error ~ truncated(Normal(0, 1), 0, Inf)
92-
return y ~ MvNormal(mu, error^2 * I)
93-
end
94-
95-
@model function simple_linear2(x, y)
96-
intercept ~ Normal(0, 1)
97-
coef ~ filldist(Normal(0, 1), 2)
98-
coef = reshape(coef, 1, size(x, 1))
99-
100-
mu = vec(intercept .+ coef * x)
101-
error ~ truncated(Normal(0, 1), 0, Inf)
102-
return y ~ MvNormal(mu, error^2 * I)
103-
end
104-
105-
@model function simple_linear3(x, y)
106-
intercept ~ Normal(0, 1)
107-
coef = Vector(undef, 2)
108-
for i in axes(coef, 1)
109-
coef[i] ~ Normal(0, 1)
110-
end
111-
coef = reshape(coef, 1, size(x, 1))
112-
113-
mu = vec(intercept .+ coef * x)
114-
error ~ truncated(Normal(0, 1), 0, Inf)
115-
return y ~ MvNormal(mu, error^2 * I)
116-
end
117-
118-
@model function simple_linear4(x, y)
119-
intercept ~ Normal(0, 1)
120-
coef1 ~ Normal(0, 1)
121-
coef2 ~ Normal(0, 1)
122-
coef = [coef1, coef2]
123-
coef = reshape(coef, 1, size(x, 1))
124-
125-
mu = vec(intercept .+ coef * x)
126-
error ~ truncated(Normal(0, 1), 0, Inf)
127-
return y ~ MvNormal(mu, error^2 * I)
128-
end
129-
130-
# Some data
131-
x = randn(2, 100)
132-
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]
133-
134-
for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
135-
m = model(x, y)
136-
chain = sample(m, NUTS(), 100)
137-
chain_predict = predict(model(x, missing), chain)
138-
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
139-
@test mean(abs2, mean_prediction - y) 1e-3
140-
end
141-
end
142-
1437
@testset "Timer" begin
1448
chain = sample(gdemo_default, MH(), 1000)
1459

0 commit comments

Comments
 (0)