Skip to content

Commit 3749df0

Browse files
authored
Make AdvancedMH compatible with AbstractMCMC 5 (#92)
* Make AdvancedMH compatible with AbstractMCMC 5 * Fix typo * Change is breaking
1 parent 8ddb81e commit 3749df0

File tree

4 files changed

+18
-18
lines changed

4 files changed

+18
-18
lines changed

Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.7.6"
3+
version = "0.8.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -23,18 +23,18 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
2323
AdvancedMHStructArraysExt = "StructArrays"
2424

2525
[compat]
26-
AbstractMCMC = "4, 5"
26+
AbstractMCMC = "5"
2727
DiffResults = "1"
28-
Distributions = "0.20 - 0.25"
29-
LinearAlgebra = "1.6 - 1.11"
30-
Random = "1.6 - 1.11"
28+
Distributions = "0.25"
3129
FillArrays = "1"
3230
ForwardDiff = "0.10"
3331
LogDensityProblems = "2"
34-
MCMCChains = "5, 6"
32+
MCMCChains = "6.0.4"
3533
Requires = "1"
3634
StructArrays = "0.6"
3735
julia = "1.6"
36+
LinearAlgebra = "1.6"
37+
Random = "1.6"
3838

3939
[extras]
4040
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,21 +138,21 @@ AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/Turi
138138

139139
```julia
140140
# Sample 4 chains from the posterior serially, without thread or process parallelism.
141-
chain = sample(model, RWMH(init_params), MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
141+
chain = sample(model, spl, MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
142142

143143
# Sample 4 chains from the posterior using multiple threads.
144-
chain = sample(model, RWMH(init_params), MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
144+
chain = sample(model, spl, MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
145145

146146
# Sample 4 chains from the posterior using multiple processes.
147-
chain = sample(model, RWMH(init_params), MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
147+
chain = sample(model, spl, MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
148148
```
149149

150150
## Metropolis-adjusted Langevin algorithm (MALA)
151151

152152
AdvancedMH.jl also offers an implementation of [MALA](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm) if the `ForwardDiff` and `DiffResults` packages are available.
153153

154154
A `MALA` sampler can be constructed by `MALA(proposal)` where `proposal` is a function that
155-
takes the gradient computed at the current sample. It is required to specify an initial sample `init_params` when calling `sample`.
155+
takes the gradient computed at the current sample. It is required to specify an initial sample `initial_params` when calling `sample`.
156156

157157
```julia
158158
# Import the package.
@@ -180,7 +180,7 @@ model = DensityModel(density)
180180
spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
181181

182182
# Sample from the posterior.
183-
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
183+
chain = sample(model, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
184184
```
185185

186186
### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
@@ -192,5 +192,5 @@ Using our implementation of the `LogDensityProblems.jl` interface above:
192192
```julia
193193
using LogDensityProblemsAD
194194
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
195-
sample(model_with_ad, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
195+
sample(model_with_ad, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
196196
```

src/mh-core.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ spl = MetropolisHastings(proposal)
3333
When using `MetropolisHastings` with the function `sample`, the following keyword
3434
arguments are allowed:
3535
36-
- `init_params` defines the initial parameterization for your model. If
36+
- `initial_params` defines the initial parameterization for your model. If
3737
none is given, the initial parameters will be drawn from the sampler's proposals.
3838
- `param_names` is a vector of strings to be assigned to parameters. This is only
3939
used if `chain_type=Chains`.
@@ -77,10 +77,10 @@ function AbstractMCMC.step(
7777
rng::Random.AbstractRNG,
7878
model::DensityModelOrLogDensityModel,
7979
sampler::MHSampler;
80-
init_params=nothing,
80+
initial_params=nothing,
8181
kwargs...
8282
)
83-
params = init_params === nothing ? propose(rng, sampler, model) : init_params
83+
params = initial_params === nothing ? propose(rng, sampler, model) : initial_params
8484
transition = AdvancedMH.transition(sampler, model, params)
8585
return transition, transition
8686
end

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ include("util.jl")
182182
val = [0.4, 1.2]
183183

184184
# Sample from the posterior.
185-
chain1 = sample(model, spl1, 10, init_params = val)
185+
chain1 = sample(model, spl1, 10, initial_params = val)
186186

187187
@test chain1[1].params == val
188188
end
@@ -265,7 +265,7 @@ include("util.jl")
265265
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
266266

267267
# Sample from the posterior with initial parameters.
268-
chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
268+
chain1 = sample(model, spl1, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
269269

270270
@test mean(chain1.μ) 0.0 atol=0.1
271271
@test mean(chain1.σ) 1.0 atol=0.1
@@ -276,7 +276,7 @@ include("util.jl")
276276
admodel,
277277
spl1,
278278
100000;
279-
init_params=ones(2),
279+
initial_params=ones(2),
280280
chain_type=StructArray,
281281
param_names=["μ", "σ"]
282282
)

0 commit comments

Comments
 (0)