Skip to content

Commit f75b675

Browse files
authored
Merge pull request #46 from luiarthur/issue41-attempt2
Skipping the Hastings computation for symmetric proposals
2 parents 35d310e + 8a89c4f commit f75b675

File tree

4 files changed

+63
-4
lines changed

4 files changed

+63
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.5"
3+
version = "0.5.6"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/mh-core.jl

+31-2
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,31 @@ function AbstractMCMC.step(
191191
return transition, transition
192192
end
193193

194+
"""
195+
is_symmetric_proposal(proposal)::Bool
196+
197+
Implementing this for a custom proposal will allow `AbstractMCMC.step` to avoid
198+
computing the "Hastings" part of the Metropolis-Hasting log acceptance
199+
probability (if the proposal is indeed symmetric). By default,
200+
`is_symmetric_proposal(proposal)` returns `false`. The user is responsible for
201+
determining whether a custom proposal distribution is indeed symmetric. As
202+
noted in `MetropolisHastings`, `proposal` is a `Proposal`, `NamedTuple` of
203+
`Proposal`, or `Array{Proposal}` in the shape of your data.
204+
"""
205+
is_symmetric_proposal(proposal) = false
206+
207+
# The following univariate random walk proposals are symmetric.
208+
is_symmetric_proposal(::RandomWalkProposal{<:Normal}) = true
209+
is_symmetric_proposal(::RandomWalkProposal{<:MvNormal}) = true
210+
is_symmetric_proposal(::RandomWalkProposal{<:TDist}) = true
211+
is_symmetric_proposal(::RandomWalkProposal{<:Cauchy}) = true
212+
213+
# The following multivariate random walk proposals are symmetric.
214+
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Normal}}) = true
215+
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:MvNormal}}) = true
216+
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:TDist}}) = true
217+
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Cauchy}}) = true
218+
194219
# Define the other sampling steps.
195220
# Return a 2-tuple consisting of the next sample and the the next state.
196221
# In this case they are identical, and either a new proposal (if accepted)
@@ -206,8 +231,12 @@ function AbstractMCMC.step(
206231
params = propose(rng, spl, model, params_prev)
207232

208233
# Calculate the log acceptance probability.
209-
logα = logdensity(model, params) - logdensity(model, params_prev) +
210-
q(spl, params_prev, params) - q(spl, params, params_prev)
234+
logα = logdensity(model, params) - logdensity(model, params_prev)
235+
236+
# Compute Hastings portion of ratio if proposal is not symmetric.
237+
if !is_symmetric_proposal(spl.proposal)
238+
logα += q(spl, params_prev, params) - q(spl, params, params_prev)
239+
end
211240

212241
# Decide whether to return the previous params or the new one.
213242
if -Random.randexp(rng) < logα

test/runtests.jl

+27-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using Test
88
using DiffResults
99
using ForwardDiff
1010

11+
include("util.jl")
12+
1113
@testset "AdvancedMH" begin
1214
# Set a seed
1315
Random.seed!(1234)
@@ -104,7 +106,31 @@ using ForwardDiff
104106

105107
@test chain1[1].params == val
106108
end
107-
109+
110+
@testset "is_symmetric_proposal" begin
111+
# True distributions
112+
d1 = Normal(5, .7)
113+
114+
# Model definition.
115+
m1 = DensityModel(x -> logpdf(d1, x))
116+
117+
# Set up the proposal (StandardNormal is a custom distribution in "util.jl").
118+
p1 = RandomWalkProposal(StandardNormal())
119+
120+
# Implement `is_symmetric_proposal` for StandardNormal random walk proposal.
121+
AdvancedMH.is_symmetric_proposal(::RandomWalkProposal{<:StandardNormal}) = true
122+
123+
# Make sure `is_symmetric_proposal` behaves correctly.
124+
@test AdvancedMH.is_symmetric_proposal(p1)
125+
126+
# Sample from the posterior with initial parameters.
127+
chain1 = sample(m1, MetropolisHastings(p1), 100000;
128+
chain_type=StructArray, param_names=["x"])
129+
130+
@test mean(chain1.x) mean(d1) atol=0.05
131+
@test std(chain1.x) std(d1) atol=0.05
132+
end
133+
108134
@testset "MALA" begin
109135

110136
# Set up the sampler.

test/util.jl

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Define a (custom) Standard Normal distribution, for illustrative puspose.
2+
struct StandardNormal <: Distributions.ContinuousUnivariateDistribution end
3+
Distributions.logpdf(::StandardNormal, x::Real) = -(x ^ 2 + log(2 * pi)) / 2
4+
Distributions.rand(rng::AbstractRNG, ::StandardNormal) = randn(Random.GLOBAL_RNG)

0 commit comments

Comments
 (0)