Add Stochastic Gradient HMC #428
Conversation
A bump on this.
AdvancedHMC.jl documentation for PR #428 is available at:
The basic SGHMC algorithm is intuitive, but in the AbstractMCMC tests, there are …
Let me take a look. I haven't dealt with this part of the code before, so give me a bit of time.
The following version passes the test:

```julia
function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC,
    state::SGHMCState;
    n_adapts::Int=0,
    kwargs...,
)
    if haskey(kwargs, :nadapts)
        throw(
            ArgumentError(
                "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
            ),
        )
    end

    i = state.i + 1
    t_old = state.transition
    adaptor = state.adaptor
    κ = state.κ
    metric = state.metric

    # Reconstruct hamiltonian.
    h = Hamiltonian(metric, model)

    # Compute gradient of log density.
    logdensity_and_gradient = Base.Fix1(
        LogDensityProblems.logdensity_and_gradient, model.logdensity
    )
    θ = copy(t_old.z.θ)
    grad = last(logdensity_and_gradient(θ))

    # Update latent variables and velocity according to
    # equation (15) of Chen et al. (2014).
    v = state.velocity
    η = spl.learning_rate
    α = spl.momentum_decay
    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
    θ .+= newv

    # Make new transition.
    z = phasepoint(h, θ, v)
    t = transition(rng, h, κ, z)

    # Adapt h and spl.
    tstat = stat(t)
    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
    tstat = merge(tstat, (is_adapt=isadapted,))

    # Compute next sample and state.
    sample = Transition(t.z, tstat)
    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)
    return sample, newstate
end
```

I made three updates:
I was a bit uneasy about the in-place update, thinking there might be some mismatch between the parameters and logp. So I made the above changes, and I think they at least make the test case correct. I know the logic is copied from Turing.jl, and the order of updating …
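To spell out the in-place unease with a toy example (a hypothetical `CachedPoint` struct, not the actual AdvancedHMC `PhasePoint`): mutating `θ` in place leaves any log density cached alongside it stale.

```julia
# Hypothetical illustration of the in-place hazard; AdvancedHMC's actual
# PhasePoint caches more than this, but the failure mode is the same.
struct CachedPoint
    θ::Vector{Float64}
    logp::Float64            # log density cached at the time θ was set
end

logp(θ) = -0.5 * sum(abs2, θ)

pt = CachedPoint([1.0, 2.0], logp([1.0, 2.0]))
pt.θ .+= 0.1                 # mutates the array stored inside the "old" point
pt.logp == logp(pt.θ)        # false: the cached log density is now stale
```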
@sunxd3 Thanks a lot! I see what's going wrong here: the parameters had not been updated to the right state, and the in-place update of the parameter theta is not suitable here. The issue has now been fixed! If the CIs are all green, this PR should be ready.
Looks fine to me. Now that the tests pass, I think the algorithm is very likely implemented correctly. Still curious why this works (I am referring to the order of updating theta first).
While the order of updating theta differs between AHMC and Turing, they both seem to work fine? The Turing one seems like it should be the correct one, though, since in each step theta is only updated using the velocity from the previous step.
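For reference, equation (15) of Chen et al. (2014), which the code comment points at, reads as follows (as I transcribe it, writing $\nabla \log \tilde p$ for the stochastic gradient; $\hat\beta$ is the noise-estimate term, which the code takes to be zero, giving the $\sqrt{2\eta\alpha}$ noise scale):

$$
\begin{aligned}
\Delta\theta &= v, \\
\Delta v &= \eta \, \nabla \log \tilde p(\theta) \;-\; \alpha v \;+\; \mathcal{N}\bigl(0,\; 2(\alpha - \hat\beta)\,\eta\bigr).
\end{aligned}
$$

Read literally, $\Delta\theta = v$ moves $\theta$ with the velocity from before the refresh, which is the Turing ordering.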
That is what confuses me, because in this PR, the order can't be reversed. (I think it will fail if we use the old velocity.)
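To make the two orderings concrete, a minimal sketch with illustrative names (`grad`, `η`, `α`, and `noise` as in the `step` above; neither function is the actual Turing.jl or AdvancedHMC code):

```julia
# Turing.jl-style ordering: θ moves with the velocity from the previous
# step, then the velocity is refreshed (matches eq. (15) read literally).
function update_theta_first(θ, v, grad, η, α, noise)
    θ = θ .+ v
    v = (1 - α) .* v .+ η .* grad .+ noise
    return θ, v
end

# This PR's ordering: the velocity is refreshed first, and θ moves with
# the *new* velocity.
function update_velocity_first(θ, v, grad, η, α, noise)
    v = (1 - α) .* v .+ η .* grad .+ noise
    θ = θ .+ v
    return θ, v
end
```

Note that the velocity recursion is identical in both sketches; they differ only in whether the position moves with the pre-refresh or post-refresh velocity, which may be why both versions pass the tests.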
Tagging @yebai for awareness and review. I think this is ready (my confidence in its correctness is quite high, but not 100 percent).
@ErikQQY version bump, maybe? (We would need to bump the minor version because …)
I think there might be a name clash when using AdvancedHMC inside Turing, since both packages export SGHMC; we need to take care of that.
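One way to sidestep that, as a sketch (the `const` aliases are illustrative, not a committed design): avoid pulling both exports into scope and qualify the name by module instead.

```julia
# With plain `import`, neither package's SGHMC shadows the other, and each
# sampler is referred to by a qualified name. By contrast, `using
# AdvancedHMC, Turing` would make a bare `SGHMC` ambiguous, and Julia would
# require qualification at the use site anyway.
import AdvancedHMC
import Turing

const ExternalSGHMC = AdvancedHMC.SGHMC
const TuringSGHMC = Turing.SGHMC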
Part of #60
This is a WIP PR; we still need to configure the implementation of SGHMC in Turing.jl, mainly regarding the usage of DynamicPPL.jl.
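For reference, a hypothetical end-to-end usage sketch of the AdvancedHMC side under this PR. The `SGHMC` constructor keywords are an assumption inferred from the `spl.learning_rate` and `spl.momentum_decay` fields used in `step` above, not a confirmed signature:

```julia
using AbstractMCMC, AdvancedHMC, LogDensityProblems, Random

# A toy standard-normal target implementing the LogDensityProblems API.
struct Gaussian end
LogDensityProblems.logdensity(::Gaussian, θ) = -0.5 * sum(abs2, θ)
LogDensityProblems.dimension(::Gaussian) = 2
LogDensityProblems.capabilities(::Type{Gaussian}) =
    LogDensityProblems.LogDensityOrder{1}()
function LogDensityProblems.logdensity_and_gradient(::Gaussian, θ)
    return LogDensityProblems.logdensity(Gaussian(), θ), -θ
end

model = AbstractMCMC.LogDensityModel(Gaussian())
spl = SGHMC(; learning_rate=1e-2, momentum_decay=0.1)  # assumed keyword form
chain = sample(Random.default_rng(), model, spl, 1_000; n_adapts=0)
```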