Skip to content

Commit db083fc

Browse files
add acceptance rate info to Transition struct (#88)
* add acceptance rate info to Transition struct * move acceptance statistics to MHSampler * Revert "move acceptance statistics to MHSampler" This reverts commit fcfafaf. * incorporate feedback - use only accepted boolean in Transition * Update src/mh-core.jl --------- Co-authored-by: Cameron Pfiffer <[email protected]>
1 parent 3749df0 commit db083fc

File tree

4 files changed

+50
-38
lines changed

4 files changed

+50
-38
lines changed

src/AdvancedMH.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ export
2727
# Reexports
2828
export sample, MCMCThreads, MCMCDistributed, MCMCSerial
2929

30-
# Abstract type for MH-style samplers. Needs better name?
30+
# Abstract type for MH-style samplers. Needs better name?
3131
abstract type MHSampler <: AbstractMCMC.AbstractSampler end
3232

3333
# Abstract type for MH-style transitions.
3434
abstract type AbstractTransition end
3535

36-
# Define a model type. Stores the log density function and the data to
36+
# Define a model type. Stores the log density function and the data to
3737
# evaluate the log density on.
3838
"""
3939
DensityModel{F} <: AbstractModel
@@ -53,17 +53,19 @@ end
5353

5454
const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel}
5555

56-
# Create a very basic Transition type, only stores the
57-
# parameter draws and the log probability of the draw.
56+
# Create a very basic Transition type, stores the
57+
# parameter draws, the log probability of the draw,
58+
# and the draw information until this point
5859
struct Transition{T,L<:Real} <: AbstractTransition
5960
params :: T
6061
lp :: L
62+
accepted :: Bool
6163
end
6264

63-
# Store the new draw and its log density.
64-
Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params))
65-
function Transition(model::AbstractMCMC.LogDensityModel, params)
66-
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params))
65+
# Store the new draw, its log density, and draw information
66+
Transition(model::DensityModelOrLogDensityModel, params, accepted) = Transition(params, logdensity(model, params), accepted)
67+
function Transition(model::AbstractMCMC.LogDensityModel, params, accepted)
68+
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params), accepted)
6769
end
6870

6971
# Calculate the density of the model given some parameterization.
@@ -128,7 +130,7 @@ function __init__()
128130
if exc.f === logdensity_and_gradient && length(arg_types) == 2 && first(arg_types) <: DensityModel && isempty(kwargs)
129131
print(io, "\\nDid you forget to load ForwardDiff?")
130132
end
131-
end
133+
end
132134
@static if !isdefined(Base, :get_extension)
133135
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("../ext/AdvancedMHMCMCChainsExt.jl")
134136
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("../ext/AdvancedMHStructArraysExt.jl")

src/MALA.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
1515
params::T
1616
lp::L
1717
gradient::G
18+
accepted::Bool
1819
end
1920

2021
logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp
2122

2223
propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
23-
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params)
24-
return GradientTransition(params, logdensity_and_gradient(model, params)...)
24+
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted)
25+
return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted)
2526
end
2627

2728
check_capabilities(model::DensityModelOrLogDensityModel) = nothing
@@ -44,7 +45,7 @@ function AbstractMCMC.step(
4445
kwargs...
4546
)
4647
check_capabilities(model)
47-
48+
4849
# Extract value and gradient of the log density of the current state.
4950
state = transition_prev.params
5051
logdensity_state = transition_prev.lp
@@ -69,9 +70,12 @@ function AbstractMCMC.step(
6970

7071
# Decide whether to return the previous params or the new one.
7172
transition = if -Random.randexp(rng) < logα
72-
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate)
73+
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true)
7374
else
74-
transition_prev
75+
candidate = transition_prev.params
76+
lp = transition_prev.lp
77+
gradient = transition_prev.gradient
78+
GradientTransition(candidate, lp, gradient, false)
7579
end
7680

7781
return transition, transition

src/emcee.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ struct Ensemble{D} <: MHSampler
33
proposal::D
44
end
55

6-
function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params)
7-
return [Transition(model, x) for x in params]
6+
function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params, accepted)
7+
return [Transition(model, x, accepted) for x in params]
88
end
99

1010
# Define the other sampling steps.
@@ -68,7 +68,7 @@ end
6868
StretchProposal(p) = StretchProposal(p, 2.0)
6969

7070
function move(
71-
rng::Random.AbstractRNG,
71+
rng::Random.AbstractRNG,
7272
spl::Ensemble{<:StretchProposal},
7373
model::DensityModelOrLogDensityModel,
7474
walker::Transition,
@@ -84,16 +84,20 @@ function move(
8484
# Make new parameters
8585
y = @. other_walker.params + z * (walker.params - other_walker.params)
8686

87-
# Construct a new walker
88-
new_walker = Transition(model, y)
87+
# Construct a temporary new walker
88+
temp_walker = Transition(model, y, true)
8989

9090
# Calculate accept/reject value.
91-
alpha = alphamult + new_walker.lp - walker.lp
91+
alpha = alphamult + temp_walker.lp - walker.lp
9292

9393
if -Random.randexp(rng) <= alpha
94+
new_walker = Transition(model, y, true)
9495
return new_walker
9596
else
96-
return walker
97+
params = walker.params
98+
lp = walker.lp
99+
old_walker = Transition(params, lp, false)
100+
return old_walker
97101
end
98102
end
99103

@@ -124,7 +128,7 @@ end
124128

125129
# theta_min = theta - 2.0*π
126130
# theta_max = theta
127-
131+
128132
# f = walker.params
129133
# while true
130134
# stheta, ctheta = sincos(theta)
@@ -136,15 +140,15 @@ end
136140
# if new_walker.lp > y
137141
# return new_walker
138142
# else
139-
# if theta < 0
143+
# if theta < 0
140144
# theta_min = theta
141145
# else
142146
# theta_max = theta
143147
# end
144148

145149
# theta = theta_min + (theta_max - theta_min) * rand()
146150
# end
147-
# end
151+
# end
148152
# end
149153

150154
#####################
@@ -180,15 +184,15 @@ end
180184

181185
# theta_min = theta - 2.0*π
182186
# theta_max = theta
183-
187+
184188
# f = walker.params
185189

186190
# i = 0
187191
# while true
188192
# i += 1
189-
193+
190194
# stheta, ctheta = sincos(theta)
191-
195+
192196
# f_prime = f .* ctheta + nu .* stheta
193197

194198
# new_walker = Transition(model, f_prime)
@@ -198,13 +202,13 @@ end
198202
# if new_walker.lp > y
199203
# return new_walker
200204
# else
201-
# if theta < 0
205+
# if theta < 0
202206
# theta_min = theta
203207
# else
204208
# theta_max = theta
205209
# end
206210

207211
# theta = theta_min + (theta_max - theta_min) * rand()
208212
# end
209-
# end
213+
# end
210214
# end

src/mh-core.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
MetropolisHastings{D}
33
4-
`MetropolisHastings` has one field, `proposal`.
4+
`MetropolisHastings` has one field, `proposal`.
55
`proposal` is a `Proposal`, `NamedTuple` of `Proposal`, or `Array{Proposal}` in the shape of your data.
66
For example, if you wanted the sampler to return a `NamedTuple` with shape
77
@@ -38,7 +38,7 @@ none is given, the initial parameters will be drawn from the sampler's proposals
3838
- `param_names` is a vector of strings to be assigned to parameters. This is only
3939
used if `chain_type=Chains`.
4040
- `chain_type` is the type of chain you would like returned to you. Supported
41-
types are `chain_type=Chains` if `MCMCChains` is imported, or
41+
types are `chain_type=Chains` if `MCMCChains` is imported, or
4242
`chain_type=StructArray` if `StructArrays` is imported.
4343
"""
4444
struct MetropolisHastings{D} <: MHSampler
@@ -62,12 +62,12 @@ function propose(
6262
return propose(rng, sampler.proposal, model, transition_prev.params)
6363
end
6464

65-
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params)
65+
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, accepted)
6666
logdensity = AdvancedMH.logdensity(model, params)
67-
return transition(sampler, model, params, logdensity)
67+
return transition(sampler, model, params, logdensity, accepted)
6868
end
69-
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real)
70-
return Transition(params, logdensity)
69+
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real, accepted)
70+
return Transition(params, logdensity, accepted)
7171
end
7272

7373
# Define the first sampling step.
@@ -81,7 +81,7 @@ function AbstractMCMC.step(
8181
kwargs...
8282
)
8383
params = initial_params === nothing ? propose(rng, sampler, model) : initial_params
84-
transition = AdvancedMH.transition(sampler, model, params)
84+
transition = AdvancedMH.transition(sampler, model, params, false)
8585
return transition, transition
8686
end
8787

@@ -106,9 +106,11 @@ function AbstractMCMC.step(
106106

107107
# Decide whether to return the previous params or the new one.
108108
transition = if -Random.randexp(rng) < logα
109-
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate)
109+
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate, true)
110110
else
111-
transition_prev
111+
params = transition_prev.params
112+
lp = transition_prev.lp
113+
Transition(params, lp, false)
112114
end
113115

114116
return transition, transition

0 commit comments

Comments
 (0)