Skip to content

Commit 11b2935

Browse files
committed
response to reviewer's comment
1 parent 12356c1 commit 11b2935

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

test/runtests.jl

+17-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using Test
88
using DiffResults
99
using ForwardDiff
1010

11+
include("util.jl")
12+
1113
@testset "AdvancedMH" begin
1214
# Set a seed
1315
Random.seed!(1234)
@@ -106,21 +108,27 @@ using ForwardDiff
106108
end
107109

108110
@testset "is_symmetric_proposal" begin
111+
# True distributions
112+
d1 = Normal(5, .7)
113+
109114
# Model definition.
110-
m1 = DensityModel(s -> logpdf(Normal(), s.x) + logpdf(Normal(5,.7), s.y))
115+
m1 = DensityModel(x -> logpdf(d1, x))
116+
117+
# Set up the proposal (StandardNormal is a custom distribution in "util.jl").
118+
p1 = RandomWalkProposal(StandardNormal())
119+
120+
# Implement `is_symmetric_proposal` for StandardNormal random walk proposal.
121+
AdvancedMH.is_symmetric_proposal(::RandomWalkProposal{<:StandardNormal}) = true
111122

112-
# Set up the proposal.
113-
p1 = (x=RandomWalkProposal(Normal(0,.5)), y=RandomWalkProposal(Normal(0,.5)))
114-
AdvancedMH.is_symmetric_proposal(proposal::typeof(p1)) = true
123+
# Make sure `is_symmetric_proposal` behaves correctly.
115124
@test AdvancedMH.is_symmetric_proposal(p1)
116125

117126
# Sample from the posterior with initial parameters.
118-
chain1 = sample(m1, MetropolisHastings(p1), 100000; chain_type=Vector{NamedTuple})
127+
chain1 = sample(m1, MetropolisHastings(p1), 100000;
128+
chain_type=StructArray, param_names=["x"])
119129

120-
@test mean(getindex.(chain1, :x)) 0 atol=0.05
121-
@test mean(getindex.(chain1, :y)) 5 atol=0.05
122-
@test std(getindex.(chain1, :x)) 1 atol=0.05
123-
@test std(getindex.(chain1, :y)) .7 atol=0.05
130+
@test mean(chain1.x) mean(d1) atol=0.05
131+
@test std(chain1.x) std(d1) atol=0.05
124132
end
125133

126134
@testset "MALA" begin

test/util.jl

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Define a (custom) Standard Normal distribution, for illustrative puspose.
2+
struct StandardNormal <: Distributions.ContinuousUnivariateDistribution end
3+
Distributions.logpdf(::StandardNormal, x::Real) = -(x ^ 2 + log(2 * pi)) / 2
4+
Distributions.rand(rng::AbstractRNG, ::StandardNormal) = randn(Random.GLOBAL_RNG)

0 commit comments

Comments
 (0)