Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "7.1.0"
version = "7.2.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
32 changes: 32 additions & 0 deletions docs/src/statsplots.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,36 @@ plot(chn, seriestype = :violin)
corner(chn)
```

## Energy Plot

The energy plot is a diagnostic tool for HMC-based samplers (like NUTS) that helps diagnose sampling efficiency by visualizing the energy and energy transition distributions. This plot requires that the chain contains the internal sampler statistics `:hamiltonian_energy` and `:hamiltonian_energy_error`.

```@example statsplots
# First, we generate a chain that includes the required sampler parameters.
n_iter = 1000
n_chain = 4
val_params = randn(n_iter, 2, n_chain)
val_energy = randn(n_iter, 1, n_chain) .+ 20
val_energy_error = randn(n_iter, 1, n_chain) .* 0.5
full_val = hcat(val_params, val_energy, val_energy_error)

parameter_names = [:a, :b, :hamiltonian_energy, :hamiltonian_energy_error]
section_map = (
parameters=[:a, :b],
internals=[:hamiltonian_energy, :hamiltonian_energy_error],
)

chn_energy = Chains(full_val, parameter_names, section_map)

# Generate the energy plot (default is a density plot).
energyplot(chn_energy)
```

```@example statsplots
# The plot can also be generated as a histogram.
energyplot(chn_energy, kind=:histogram)
```

For plotting multiple parameters, ridgeline, forest and caterpillar plots can be useful.

## Ridgeline
Expand All @@ -156,6 +186,8 @@ forestplot(chn, chn.name_map[:parameters], hpd_val = [0.05, 0.15, 0.25], ordered
## API

```@docs
energyplot
energyplot!
ridgelineplot
ridgelineplot!
forestplot
Expand Down
66 changes: 66 additions & 0 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@
@shorthands corner
@shorthands violinplot

"""
energyplot(chains::Chains; kind=:density, kwargs...)

Generate an energy plot for the samples in `chains`.

The energy plot is a diagnostic tool for HMC-based samplers like NUTS. It displays the distributions of the Hamiltonian energy and the energy transition (error) to diagnose sampler efficiency and identify divergences.

This plot is only available for chains that contain the `:hamiltonian_energy` and `:hamiltonian_energy_error` statistics in their `:internals` section.

# Keywords
- `kind::Symbol` (default: `:density`): The type of plot to generate. Can be `:density` or `:histogram`.
"""
@userplot EnergyPlot

"""
ridgelineplot(chains::Chains[, params::Vector{Symbol}]; kwargs...)

Expand Down Expand Up @@ -252,6 +266,58 @@ end
end
end

@recipe function f(p::EnergyPlot; kind = :density)
chains = p.args[1]

if kind ∉ (:density, :histogram)
error("`kind` must be one of `:density` or `:histogram`")
end

internal_names = names(chains, :internals)
required_params = [:hamiltonian_energy, :hamiltonian_energy_error]
for param in required_params
if param ∉ internal_names
error(
"`$param` not found in chain's internal parameters. Energy plots are only available for HMC/NUTS samplers.",
)
end
end

pooled = pool_chain(chains)
energy = vec(pooled[:, :hamiltonian_energy, :])
energy_error = vec(pooled[:, :hamiltonian_energy_error, :])

mean_energy = mean(energy)
std_energy = std(energy)
centered_energy = (energy .- mean_energy) ./ std_energy
scaled_energy_error = energy_error ./ std_energy

title := "Energy Plot"
xaxis := "Standardized Energy"
yaxis := "Density"
legend := :topright

@series begin
seriestype := kind
label := "Marginal Energy"
fillrange --> 0
fillalpha --> 0.5
normalize --> true
bins --> 50
centered_energy
end

@series begin
seriestype := kind
label := "Energy Transition"
fillrange --> 0
fillalpha --> 0.5
normalize --> true
bins --> 50
scaled_energy_error
end
end

@recipe function f(chains::Chains, parameters::AbstractVector{Symbol}; colordim = :chain)
colordim != :chain && error(
"Symbol names are interpreted as parameter names, only compatible with ",
Expand Down
24 changes: 24 additions & 0 deletions test/plot_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ n_chain = 3
val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
val = hcat(val, rand(1:2, n_iter, 1, n_chain))

# This chain is missing the required energy parameters for the energyplot.
chn = Chains(val)

# Silence all warnings.
Expand Down Expand Up @@ -92,6 +93,29 @@ Logging.disable_logging(Logging.Warn)
display(plot(chn, 2))
display(plot(chn, 2, colordim = :parameter))
println()

@testset "Energy plot" begin
# Construct a chain with the required internal parameters.
val_params = randn(n_iter, 2, n_chain)
val_energy = rand(n_iter, 1, n_chain) .* 10 .+ 20
val_energy_error = randn(n_iter, 1, n_chain) .* 0.1
full_val = hcat(val_params, val_energy, val_energy_error)

parameter_names = [:a, :b, :hamiltonian_energy, :hamiltonian_energy_error]
section_map = (
parameters = [:a, :b],
internals = [:hamiltonian_energy, :hamiltonian_energy_error],
)

chn_energy = Chains(full_val, parameter_names, section_map)

println("energyplot")
display(energyplot(chn_energy))
display(energyplot(chn_energy, kind = :histogram))
println()

@test_throws ErrorException energyplot(chn)
end
end

# Reset log level.
Expand Down
Loading