
Add Stochastic Gradient HMC #428


Open · wants to merge 11 commits into main

Conversation

@ErikQQY (Collaborator) commented Apr 23, 2025

Part of #60

This is a WIP PR; we still need to configure the implementation of SGHMC in Turing.jl, mainly regarding the usage of DynamicPPL.jl.

@ErikQQY marked this pull request as draft April 23, 2025 11:14
@yebai requested a review from sunxd3 April 25, 2025 09:21
@sunxd3 (Member) commented Apr 30, 2025

a bump on this

@ErikQQY marked this pull request as ready for review May 18, 2025 09:35
Contributor

AdvancedHMC.jl documentation for PR #428 is available at:
https://TuringLang.github.io/AdvancedHMC.jl/previews/PR428/

@ErikQQY (Collaborator, Author) commented May 18, 2025

The basic SGHMC algorithm is intuitive, but in the AbstractMCMC tests there are Inf values in the final result after transforming back to the original space. I could use some advice on what's going wrong here @sunxd3

@sunxd3 (Member) commented May 18, 2025

Let me take a look. I haven't dealt with this part of the code before, so give me a bit of time.

@sunxd3 (Member) commented May 21, 2025

The following version passes the test

function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC,
    state::SGHMCState;
    n_adapts::Int=0,
    kwargs...,
)
    if haskey(kwargs, :nadapts)
        throw(
            ArgumentError(
                "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
            ),
        )
    end

    i = state.i + 1
    t_old = state.transition
    adaptor = state.adaptor
    κ = state.κ
    metric = state.metric

    # Reconstruct hamiltonian.
    h = Hamiltonian(metric, model)

    # Compute gradient of log density.
    logdensity_and_gradient = Base.Fix1(
        LogDensityProblems.logdensity_and_gradient, model.logdensity
    )
    θ = copy(t_old.z.θ)
    grad = last(logdensity_and_gradient(θ))

    # Update latent variables and velocity according to
    # equation (15) of Chen et al. (2014)
    v = state.velocity
    η = spl.learning_rate
    α = spl.momentum_decay
    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
    θ .+= newv

    # Make new transition.
    z = phasepoint(h, θ, v)
    t = transition(rng, h, κ, z)

    # Adapt h and spl.
    tstat = stat(t)
    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
    tstat = merge(tstat, (is_adapt=isadapted,))

    # Compute next sample and state.
    sample = Transition(t.z, tstat)
    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)

    return sample, newstate
end

I made three updates:

  1. copy θ
  2. update θ using newv
  3. create t using the updated θ

I was a bit uneasy about the in-place update, thinking there might be some mismatch between the parameters and logp. So I made the above changes, and I think they at least make the test case correct.

I know the logic is copied from Turing.jl, and the order of updating v and θ was okay as it was, but I had to switch the order to make it work here. Ideally I would get to the bottom of this, but I thought you might know better.
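
For reference, the update the code comment cites (equation (15) of Chen et al., 2014), written with the symbols used in the step function above (η the learning rate, α the momentum decay), with Ũ the stochastic potential energy so that ∇Ũ(θ) = -grad, and with the friction estimate β̂ taken as zero (which is what the sqrt(2 * η * α) noise scale in the code corresponds to):

    \begin{aligned}
    \Delta\theta &= v \\
    \Delta v     &= -\eta \, \nabla \tilde{U}(\theta) \;-\; \alpha \, v \;+\; \mathcal{N}\!\left(0,\; 2(\alpha - \hat{\beta})\,\eta\right)
    \end{aligned}

As I read the paper, θ is advanced with the previous velocity and the velocity is refreshed afterwards (the Turing ordering), whereas the version above refreshes the velocity first and then advances θ with the new value.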

@ErikQQY (Collaborator, Author) commented May 22, 2025

@sunxd3 Thanks a lot! I see what's going wrong here: the parameters had not been updated to the right state, and the in-place update of the parameter theta is not suitable here. The issue has now been fixed!

If all the CI checks are green, this PR should be ready now.

@ErikQQY requested review from yebai and sunxd3 May 22, 2025 04:31
@sunxd3 (Member) commented May 22, 2025

Looks fine to me. Now that the tests pass, I think the algorithm is very likely implemented correctly.

Still curious why this works (I am referring to the order of updating theta first).

sunxd3 previously approved these changes May 22, 2025
@ErikQQY (Collaborator, Author) commented May 22, 2025

Still curious why this works (I am referring to the order of updating theta first).

While the order of updating theta differs between AHMC and Turing, it seems they both work fine? But it seems the Turing one should be the correct one, though: in each step, theta is only updated using the velocity from the previous step.
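
To make the difference concrete, here is a minimal sketch of the two orderings (not the actual PR or Turing code; θ, v, grad, η, α, and noise stand in for the corresponding quantities in the step function above):

# Turing.jl ordering: advance θ with the previous velocity, then refresh v.
function turing_order(θ, v, grad, η, α, noise)
    θ = θ .+ v                              # uses the old velocity
    v = (1 - α) .* v .+ η .* grad .+ noise
    return θ, v
end

# Ordering in this PR: refresh the velocity first, then advance θ with it.
function pr_order(θ, v, grad, η, α, noise)
    v = (1 - α) .* v .+ η .* grad .+ noise
    θ = θ .+ v                              # uses the new velocity
    return θ, v
end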

@sunxd3 (Member) commented May 22, 2025

That is what confuses me, because in this PR, the order can't be reversed. (I think it will fail if we use the old velocity.)

@sunxd3 (Member) commented May 26, 2025

Tagging @yebai for awareness and review. I think this is ready (my certainty of correctness is quite high, but not 100 percent).

@sunxd3 (Member) commented May 26, 2025

@ErikQQY version bump, maybe? (We would need to bump the minor version because SGHMC will be exported. Probably lump several changes together or do something like TuringLang/Turing.jl#2517.)

@ErikQQY (Collaborator, Author) commented May 26, 2025

Probably lump several changes together or do something like TuringLang/Turing.jl#2517.

I think there might be some method ambiguity when using AdvancedHMC inside Turing: since they both export SGHMC, we need to take care of that.
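
For what it's worth, on the user side the clash between two exported SGHMC names can be sidestepped by qualifying the binding; this is just a generic Julia illustration (the alias names are only for the example):

using AdvancedHMC
using Turing

# If both modules export a binding named `SGHMC`, the bare name is ambiguous
# and Julia errors when it is first used; the module-qualified names are not.
const AHMCSGHMC   = AdvancedHMC.SGHMC
const TuringSGHMC = Turing.SGHMC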
