|
1 | 1 | module MCMCUtilitiesTests
|
2 | 2 |
|
3 | 3 | using ..Models: gdemo_default
|
4 |
| -using Distributions: Normal, sample, truncated |
5 |
| -using LinearAlgebra: I, vec |
6 |
| -using Random: Random |
7 |
| -using Random: MersenneTwister |
8 | 4 | using Test: @test, @testset
|
9 | 5 | using Turing
|
10 | 6 |
|
11 |
| -@testset "predict" begin |
12 |
| - Random.seed!(100) |
13 |
| - |
14 |
| - @model function linear_reg(x, y, σ=0.1) |
15 |
| - β ~ Normal(0, 1) |
16 |
| - |
17 |
| - for i in eachindex(y) |
18 |
| - y[i] ~ Normal(β * x[i], σ) |
19 |
| - end |
20 |
| - end |
21 |
| - |
22 |
| - @model function linear_reg_vec(x, y, σ=0.1) |
23 |
| - β ~ Normal(0, 1) |
24 |
| - return y ~ MvNormal(β .* x, σ^2 * I) |
25 |
| - end |
26 |
| - |
27 |
| - f(x) = 2 * x + 0.1 * randn() |
28 |
| - |
29 |
| - Δ = 0.1 |
30 |
| - xs_train = 0:Δ:10 |
31 |
| - ys_train = f.(xs_train) |
32 |
| - xs_test = [10 + Δ, 10 + 2 * Δ] |
33 |
| - ys_test = f.(xs_test) |
34 |
| - |
35 |
| - # Infer |
36 |
| - m_lin_reg = linear_reg(xs_train, ys_train) |
37 |
| - chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), 200) |
38 |
| - |
39 |
| - # Predict on two last indices |
40 |
| - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) |
41 |
| - predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
42 |
| - |
43 |
| - ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) |
44 |
| - |
45 |
| - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 |
46 |
| - |
47 |
| - # Ensure that `rng` is respected |
48 |
| - predictions1 = let rng = MersenneTwister(42) |
49 |
| - predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) |
50 |
| - end |
51 |
| - predictions2 = let rng = MersenneTwister(42) |
52 |
| - predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) |
53 |
| - end |
54 |
| - @test all(Array(predictions1) .== Array(predictions2)) |
55 |
| - |
56 |
| - # Predict on two last indices for vectorized |
57 |
| - m_lin_reg_test = linear_reg_vec(xs_test, missing) |
58 |
| - predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
59 |
| - ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) |
60 |
| - |
61 |
| - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 |
62 |
| - |
63 |
| - # Multiple chains |
64 |
| - chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), MCMCThreads(), 200, 2) |
65 |
| - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) |
66 |
| - predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
67 |
| - |
68 |
| - @test size(chain_lin_reg, 3) == size(predictions, 3) |
69 |
| - |
70 |
| - for chain_idx in MCMCChains.chains(chain_lin_reg) |
71 |
| - ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)) |
72 |
| - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 |
73 |
| - end |
74 |
| - |
75 |
| - # Predict on two last indices for vectorized |
76 |
| - m_lin_reg_test = linear_reg_vec(xs_test, missing) |
77 |
| - predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
78 |
| - |
79 |
| - for chain_idx in MCMCChains.chains(chain_lin_reg) |
80 |
| - ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)) |
81 |
| - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 |
82 |
| - end |
83 |
| - |
84 |
| - # https://github.com/TuringLang/Turing.jl/issues/1352 |
85 |
| - @model function simple_linear1(x, y) |
86 |
| - intercept ~ Normal(0, 1) |
87 |
| - coef ~ MvNormal(zeros(2), I) |
88 |
| - coef = reshape(coef, 1, size(x, 1)) |
89 |
| - |
90 |
| - mu = vec(intercept .+ coef * x) |
91 |
| - error ~ truncated(Normal(0, 1), 0, Inf) |
92 |
| - return y ~ MvNormal(mu, error^2 * I) |
93 |
| - end |
94 |
| - |
95 |
| - @model function simple_linear2(x, y) |
96 |
| - intercept ~ Normal(0, 1) |
97 |
| - coef ~ filldist(Normal(0, 1), 2) |
98 |
| - coef = reshape(coef, 1, size(x, 1)) |
99 |
| - |
100 |
| - mu = vec(intercept .+ coef * x) |
101 |
| - error ~ truncated(Normal(0, 1), 0, Inf) |
102 |
| - return y ~ MvNormal(mu, error^2 * I) |
103 |
| - end |
104 |
| - |
105 |
| - @model function simple_linear3(x, y) |
106 |
| - intercept ~ Normal(0, 1) |
107 |
| - coef = Vector(undef, 2) |
108 |
| - for i in axes(coef, 1) |
109 |
| - coef[i] ~ Normal(0, 1) |
110 |
| - end |
111 |
| - coef = reshape(coef, 1, size(x, 1)) |
112 |
| - |
113 |
| - mu = vec(intercept .+ coef * x) |
114 |
| - error ~ truncated(Normal(0, 1), 0, Inf) |
115 |
| - return y ~ MvNormal(mu, error^2 * I) |
116 |
| - end |
117 |
| - |
118 |
| - @model function simple_linear4(x, y) |
119 |
| - intercept ~ Normal(0, 1) |
120 |
| - coef1 ~ Normal(0, 1) |
121 |
| - coef2 ~ Normal(0, 1) |
122 |
| - coef = [coef1, coef2] |
123 |
| - coef = reshape(coef, 1, size(x, 1)) |
124 |
| - |
125 |
| - mu = vec(intercept .+ coef * x) |
126 |
| - error ~ truncated(Normal(0, 1), 0, Inf) |
127 |
| - return y ~ MvNormal(mu, error^2 * I) |
128 |
| - end |
129 |
| - |
130 |
| - # Some data |
131 |
| - x = randn(2, 100) |
132 |
| - y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)] |
133 |
| - |
134 |
| - for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4] |
135 |
| - m = model(x, y) |
136 |
| - chain = sample(m, NUTS(), 100) |
137 |
| - chain_predict = predict(model(x, missing), chain) |
138 |
| - mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)] |
139 |
| - @test mean(abs2, mean_prediction - y) ≤ 1e-3 |
140 |
| - end |
141 |
| -end |
142 |
| - |
143 | 7 | @testset "Timer" begin
|
144 | 8 | chain = sample(gdemo_default, MH(), 1000)
|
145 | 9 |
|
|
0 commit comments