Skip to content

Commit f557d43

Browse files
authored
Merge pull request #29 from briandepasquale/master
2 parents 2bad97c + 0ef7b9d commit f557d43

File tree

6 files changed

+131
-10
lines changed

6 files changed

+131
-10
lines changed

Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.4"
3+
version = "0.5.5"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -18,6 +18,8 @@ julia = "1"
1818
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1919
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2123

2224
[targets]
23-
test = ["StructArrays", "MCMCChains", "Test"]
25+
test = ["StructArrays", "MCMCChains", "Test", "ForwardDiff", "DiffResults"]

README.md

+34
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,37 @@ in parallel for free:
118118
# Sample 4 chains from the posterior.
119119
chain = psample(model, RWMH(init_params), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
120120
```
121+
122+
## Metropolis-adjusted Langevin algorithm (MALA)
123+
124+
AdvancedMH.jl also offers an implementation of [MALA](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm) if the `ForwardDiff` and `DiffResults` packages are available.
125+
126+
A `MALA` sampler can be constructed by `MALA(proposal)` where `proposal` is a function that
127+
takes the gradient computed at the current sample. It is required to specify an initial sample `init_params` when calling `sample`.
128+
129+
```julia
130+
# Import the package.
131+
using AdvancedMH
132+
using Distributions
133+
using MCMCChains
134+
using DiffResults
135+
using ForwardDiff
136+
137+
# Generate a set of data from the posterior we want to estimate.
138+
data = rand(Normal(0, 1), 30)
139+
140+
# Define the components of a basic model.
141+
insupport(θ) = θ[2] >= 0
142+
dist(θ) = Normal(θ[1], θ[2])
143+
density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf
144+
145+
# Construct a DensityModel.
146+
model = DensityModel(density)
147+
148+
# Set up the sampler with a multivariate Gaussian proposal.
149+
sigma = 1e-1
150+
spl = MALA(x -> MvNormal((sigma^2 / 2) .* x, sigma))
151+
152+
# Sample from the posterior.
153+
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
154+
```

src/AdvancedMH.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@ import Random
99

1010
# Exports
1111
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal,
12-
RandomWalkProposal, Ensemble, StretchProposal
12+
RandomWalkProposal, Ensemble, StretchProposal, MALA
1313

1414
# Reexports
1515
export sample, MCMCThreads, MCMCDistributed
1616

1717
# Abstract type for MH-style samplers. Needs better name?
1818
abstract type MHSampler <: AbstractMCMC.AbstractSampler end
1919

20+
# Abstract type for MH-style transitions.
21+
abstract type AbstractTransition end
22+
2023
# Define a model type. Stores the log density function and the data to
2124
# evaluate the log density on.
2225
"""
@@ -37,7 +40,7 @@ end
3740

3841
# Create a very basic Transition type, only stores the
3942
# parameter draws and the log probability of the draw.
40-
struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real}
43+
struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real} <: AbstractTransition
4144
params :: T
4245
lp :: L
4346
end
@@ -51,7 +54,7 @@ logdensity(model::DensityModel, t::Transition) = t.lp
5154

5255
# A basic chains constructor that works with the Transition struct we defined.
5356
function AbstractMCMC.bundle_samples(
54-
ts::Vector{<:Transition},
57+
ts::Vector{<:AbstractTransition},
5558
model::DensityModel,
5659
sampler::MHSampler,
5760
state,
@@ -98,6 +101,9 @@ end
98101
function __init__()
99102
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("mcmcchains-connect.jl")
100103
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("structarray-connect.jl")
104+
@require DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" begin
105+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("MALA.jl")
106+
end
101107
end
102108

103109
# Include inference methods.

src/MALA.jl

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using .ForwardDiff: gradient!
2+
using .DiffResults: GradientResult, value, gradient
3+
4+
struct MALA{D} <: MHSampler
5+
proposal::D
6+
end
7+
8+
9+
# Create a RandomWalkProposal if we weren't given one already.
10+
MALA(d) = MALA(RandomWalkProposal(d))
11+
12+
# If we were given a RandomWalkProposal, just use that instead.
13+
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d)
14+
15+
16+
struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{Vector, Real, NamedTuple}} <: AbstractTransition
17+
params::T
18+
lp::L
19+
gradient::G
20+
end
21+
22+
transition(::MALA, model, params) = GradientTransition(model, params)
23+
24+
# Store the new draw, its log density and its gradient
25+
GradientTransition(model::DensityModel, params) = GradientTransition(params, logdensity_and_gradient(model, params)...)
26+
27+
propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
28+
29+
function propose(
30+
rng::Random.AbstractRNG,
31+
spl::MALA{<:Proposal},
32+
model::DensityModel,
33+
params_prev::GradientTransition
34+
)
35+
proposal = propose(rng, spl.proposal(params_prev.gradient), model, params_prev.params)
36+
return GradientTransition(model, proposal)
37+
end
38+
39+
40+
function q(
41+
spl::MALA{<:Proposal},
42+
t::GradientTransition,
43+
t_cond::GradientTransition
44+
)
45+
return q(spl.proposal(-t_cond.gradient), t.params, t_cond.params)
46+
end
47+
48+
49+
"""
50+
logdensity_and_gradient(model::DensityModel, params)
51+
52+
Efficiently returns the value and gradient of the model
53+
"""
54+
function logdensity_and_gradient(model::DensityModel, params)
55+
res = GradientResult(params)
56+
gradient!(res, model.logdensity, params)
57+
return (value(res), gradient(res))
58+
end
59+
60+
61+
logdensity(model::DensityModel, t::GradientTransition) = t.lp

src/mh-core.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -169,20 +169,23 @@ function q(
169169
end
170170
end
171171

172+
transition(sampler, model, params) = transition(model, params)
173+
transition(model, params) = Transition(model, params)
174+
172175
# Define the first sampling step.
173176
# Return a 2-tuple consisting of the initial sample and the initial state.
174177
# In this case they are identical.
175178
function AbstractMCMC.step(
176179
rng::Random.AbstractRNG,
177180
model::DensityModel,
178-
spl::MetropolisHastings;
181+
spl::MHSampler;
179182
init_params=nothing,
180183
kwargs...
181184
)
182185
if init_params === nothing
183186
transition = propose(rng, spl, model)
184187
else
185-
transition = Transition(model, init_params)
188+
transition = AdvancedMH.transition(spl, model, init_params)
186189
end
187190

188191
return transition, transition
@@ -195,8 +198,8 @@ end
195198
function AbstractMCMC.step(
196199
rng::Random.AbstractRNG,
197200
model::DensityModel,
198-
spl::MetropolisHastings,
199-
params_prev::Transition;
201+
spl::MHSampler,
202+
params_prev::AbstractTransition;
200203
kwargs...
201204
)
202205
# Generate a new proposal.

test/runtests.jl

+16-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using MCMCChains
55

66
using Random
77
using Test
8+
using DiffResults
9+
using ForwardDiff
810

911
@testset "AdvancedMH" begin
1012
# Set a seed
@@ -102,7 +104,20 @@ using Test
102104

103105
@test chain1[1].params == val
104106
end
107+
108+
@testset "MALA" begin
109+
110+
# Set up the sampler.
111+
sigma = 1e-1
112+
spl1 = MALA(x -> MvNormal((sigma^2 / 2) .* x, sigma))
113+
114+
# Sample from the posterior with initial parameters.
115+
chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
116+
117+
@test mean(chain1.μ) 0.0 atol=0.1
118+
@test mean(chain1.σ) 1.0 atol=0.1
119+
end
105120

106121
@testset "EMCEE" begin include("emcee.jl") end
122+
107123
end
108-

0 commit comments

Comments
 (0)