
csml-beach/multirate-sampling


Multirate Sampling

Open in GitHub Codespaces

Code accompanying the multirate SVGD submission. The repository implements fixed and adaptive multirate SVGD variants, compares them against particle and stochastic-gradient MCMC baselines, and evaluates quality-cost tradeoffs across synthetic and UCI tasks.
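For orientation, the plain SVGD update that the multirate variants are named after can be sketched in a few lines of JAX. This is a minimal sketch of standard SVGD (RBF kernel, median-heuristic bandwidth), not the code in jax/samplers.py. The function name svgd_step, the step size, and the toy 2D Gaussian target are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def svgd_step(x, log_prob, step_size):
    """One plain SVGD update (RBF kernel, median heuristic). Illustrative sketch."""
    n = x.shape[0]
    # Score of the target at every particle, shape (n, d).
    score = jax.vmap(jax.grad(log_prob))(x)
    # diffs[j, i] = x_j - x_i, shape (n, n, d).
    diffs = x[:, None, :] - x[None, :, :]
    sq_dists = jnp.sum(diffs ** 2, axis=-1)
    # Median heuristic: h = median(||x_j - x_i||^2) / log(n + 1), floored for safety.
    h = jnp.maximum(jnp.median(sq_dists) / jnp.log(n + 1.0), 1e-8)
    k = jnp.exp(-sq_dists / h)  # k[j, i] = k(x_j, x_i)
    # grad_k[j, i] = d k(x_j, x_i) / d x_j = -2 (x_j - x_i) k(x_j, x_i) / h.
    grad_k = -2.0 / h * diffs * k[:, :, None]
    # phi(x_i) = 1/n * sum_j [k(x_j, x_i) * score_j + grad_{x_j} k(x_j, x_i)].
    phi = (k.T @ score + grad_k.sum(axis=0)) / n
    return x + step_size * phi


def log_prob(x):
    # Toy target: standard 2D Gaussian (unnormalized).
    return -0.5 * jnp.sum(x ** 2)


# 100 particles started far from the mode, then driven by SVGD.
x = jax.random.normal(jax.random.PRNGKey(0), (100, 2)) + 3.0
step = jax.jit(lambda particles: svgd_step(particles, log_prob, 0.1))
for _ in range(1000):
    x = step(x)
```

The first term pulls particles toward high density; the kernel-gradient term pushes them apart, which is what keeps the particle set from collapsing onto the mode.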

Demo animations: Banana, Squiggly, Two Moons, Ring.
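Samplers such as SVGD only need an unnormalized log density and its gradient. As a hedged illustration, two of the demo shapes and an 8-component circular mixture (in the spirit of the mix8 target) can be written as below; the functional forms and parameters (scale, curvature, radius, width, sigma) are assumptions, and the repository's own target definitions may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def banana_log_prob(x, scale=2.0, curvature=0.5):
    # Assumed form: x[0] ~ N(0, scale^2), x[1] | x[0] ~ N(curvature * x[0]^2, 1).
    return -0.5 * (x[0] / scale) ** 2 - 0.5 * (x[1] - curvature * x[0] ** 2) ** 2


def ring_log_prob(x, radius=2.0, width=0.2):
    # Mass concentrated near a circle; the tiny epsilon keeps the gradient finite at the origin.
    r = jnp.sqrt(jnp.sum(x ** 2) + 1e-12)
    return -0.5 * ((r - radius) / width) ** 2


def mix8_log_prob(x, radius=4.0, sigma=0.5):
    # Equal-weight mixture of 8 isotropic Gaussians on a circle (unnormalized).
    angles = 2.0 * jnp.pi * jnp.arange(8) / 8
    centers = radius * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)
    sq = jnp.sum((x - centers) ** 2, axis=-1)
    return logsumexp(-0.5 * sq / sigma ** 2) - jnp.log(8.0)


# Batched scores (gradients of the log density) at a few points.
points = jnp.array([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])
ring_scores = jax.vmap(jax.grad(ring_log_prob))(points)
mix8_scores = jax.vmap(jax.grad(mix8_log_prob))(points)
```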

Highlights

  • Core JAX samplers: SVGD, Strang-SVGD, MR-SVGD, Adapt-MR-SVGD, SGLD, SGHMC.
  • Benchmarks: 50D Gaussian, 2D targets, UCI logistic regression, UCI Bayesian neural network (BNN), 2D mixture (mix8), and large-scale hierarchical logistic regression (HLR) with longtail/uniform group modes.
  • Unified CSV metrics and plotting scripts used for manuscript figures.
  • Paper source in paper/.
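The SGLD and SGHMC baselines follow well-known update rules. A minimal sketch of the textbook forms is below (SGLD as in Welling and Teh; SGHMC in the momentum form of Chen et al. with the gradient-noise estimate set to zero). The hyperparameter names and values and the toy N(0, 1) target with exact gradients are assumptions, and jax/samplers.py may parameterize these samplers differently.

```python
import jax
import jax.numpy as jnp


def sgld_step(key, theta, grad_log_post, step_size):
    # theta <- theta + (eps / 2) * g + N(0, eps), with g an estimate of grad log p(theta).
    noise = jax.random.normal(key, theta.shape)
    return theta + 0.5 * step_size * grad_log_post(theta) + jnp.sqrt(step_size) * noise


def sghmc_step(key, theta, v, grad_log_post, lr, friction):
    # v <- (1 - alpha) v + lr * g + N(0, 2 * alpha * lr); theta <- theta + v.
    noise = jax.random.normal(key, theta.shape)
    v = (1.0 - friction) * v + lr * grad_log_post(theta) + jnp.sqrt(2.0 * friction * lr) * noise
    return theta + v, v


def grad_std_normal(theta):
    # Exact score of N(0, 1); a real run would plug in a minibatch gradient here.
    return -theta


def run_sgld(key, n_chains=2000, n_steps=2000, step_size=0.01):
    def body(theta, k):
        return sgld_step(k, theta, grad_std_normal, step_size), None
    theta, _ = jax.lax.scan(body, jnp.zeros(n_chains), jax.random.split(key, n_steps))
    return theta


def run_sghmc(key, n_chains=2000, n_steps=2000, lr=0.01, friction=0.1):
    def body(carry, k):
        theta, v = carry
        return sghmc_step(k, theta, v, grad_std_normal, lr, friction), None
    init = (jnp.zeros(n_chains), jnp.zeros(n_chains))
    (theta, _), _ = jax.lax.scan(body, init, jax.random.split(key, n_steps))
    return theta


# Many independent chains; their final states should look roughly like N(0, 1).
sgld_samples = run_sgld(jax.random.PRNGKey(0))
sghmc_samples = run_sghmc(jax.random.PRNGKey(1))
```

Unlike SVGD, both baselines inject Gaussian noise and rely on many iterations (or many chains) rather than particle interactions to represent the posterior.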

Repository Layout

  • jax/samplers.py: sampler implementations.
  • jax/benchmarks/: benchmark drivers, metrics, datasets, and plotting scripts.
  • metrics/: generated benchmark CSV outputs.
  • figures/: generated plots.
  • animations/: generated GIFs/animations.
  • docs/README.md: detailed benchmark runbook.
  • paper/: manuscript source.

Run on GitHub Codespaces (Recommended)

  1. Click the Codespaces badge above.
  2. Create a new codespace on the default branch.
  3. Wait for the post-create setup to finish (defined in .devcontainer/postCreate.sh).
  4. Run a quick check:
python -c "import jax; print(jax.devices())"
  5. Run a representative benchmark:
python jax/benchmarks/gauss50/benchmark_gauss50.py
python jax/benchmarks/gauss50/plot_gauss50.py
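Because a Gaussian target has a known mean and covariance, particle quality can be scored in closed form. The sketch below is one hedged way to do that (Euclidean mean error and relative Frobenius covariance error); the metrics actually written to metrics/ may differ, and the diagonal 50D covariance here is an assumption, with exact samples standing in for sampler output.

```python
import jax
import jax.numpy as jnp


def moment_errors(particles, true_mean, true_cov):
    """Euclidean mean error and relative Frobenius covariance error of a particle set."""
    mean_err = jnp.linalg.norm(particles.mean(axis=0) - true_mean)
    emp_cov = jnp.cov(particles, rowvar=False)  # rows are particles
    cov_err = jnp.linalg.norm(emp_cov - true_cov) / jnp.linalg.norm(true_cov)
    return mean_err, cov_err


# Illustrative 50D target with a diagonal covariance (not necessarily the benchmark's).
dim = 50
true_mean = jnp.zeros(dim)
variances = jnp.linspace(0.5, 2.0, dim)
true_cov = jnp.diag(variances)

# Exact draws as a stand-in for sampler output.
z = jax.random.normal(jax.random.PRNGKey(0), (20000, dim))
exact = true_mean + jnp.sqrt(variances) * z
mean_err, cov_err = moment_errors(exact, true_mean, true_cov)
```

With only a handful of particles these errors are dominated by Monte Carlo noise, so comparisons across samplers are most meaningful at a fixed particle count.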

Local Setup

python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt

Notes:

  • UCI datasets are downloaded automatically into data/uci/ on first use.
  • Metrics and figures are generated by running the benchmark scripts in this repository.

Basic Reproduction

Quick HLR experiment (longtail groups, one seed, low compute):

python jax/benchmarks/hlr/benchmark_hlr.py \
  --group-mode longtail \
  --n-samples 4000 \
  --n-features 24 \
  --n-groups 300 \
  --iters 40 \
  --save-every 20 \
  --particles 12 \
  --seeds 0
python jax/benchmarks/hlr/plot_hlr.py
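To make the flags concrete, the sketch below builds a synthetic grouped dataset at the quick-experiment sizes and evaluates the log joint of a simple random-intercept logistic model. Everything here is a hypothetical stand-in: the model (shared weights w plus per-group intercepts b), the prior scale tau, the Zipf-like weights used for "longtail", and the helper names make_group_ids and hlr_log_joint are assumptions, not the benchmark's actual model. Labels are random, purely to exercise the function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_group_ids(rng, n_samples, n_groups, mode):
    """Assign rows to groups; 'longtail' uses Zipf-like weights (an assumption)."""
    if mode == "uniform":
        probs = np.full(n_groups, 1.0 / n_groups)
    elif mode == "longtail":
        probs = 1.0 / np.arange(1, n_groups + 1)
        probs = probs / probs.sum()
    else:
        raise ValueError(f"unknown group mode: {mode}")
    return rng.choice(n_groups, size=n_samples, p=probs)


def hlr_log_joint(params, X, y, group_ids, tau=1.0):
    """Unnormalized log p(w, b, y | X) for a hypothetical random-intercept model."""
    logits = X @ params["w"] + params["b"][group_ids]
    # Stable Bernoulli log-likelihood: y * z - log(1 + exp(z)).
    loglik = jnp.sum(y * logits - jnp.logaddexp(0.0, logits))
    # N(0, 1) prior on weights, N(0, tau^2) prior on group intercepts.
    logprior = -0.5 * jnp.sum(params["w"] ** 2) - 0.5 * jnp.sum(params["b"] ** 2) / tau ** 2
    return loglik + logprior


# Sizes mirror the quick-experiment flags above.
rng = np.random.default_rng(0)
n_samples, n_features, n_groups = 4000, 24, 300
X = jnp.asarray(rng.standard_normal((n_samples, n_features)), dtype=jnp.float32)
group_ids = jnp.asarray(make_group_ids(rng, n_samples, n_groups, "longtail"))
y = jnp.asarray(rng.integers(0, 2, size=n_samples), dtype=jnp.float32)
params = {"w": jnp.zeros(n_features), "b": jnp.zeros(n_groups)}
value, grads = jax.value_and_grad(hlr_log_joint)(params, X, y, group_ids)
```

Under a long-tailed split, a few groups hold most rows while many intercepts are informed by only a handful of observations, which is one reason gradient magnitudes can differ sharply across parameters in this kind of model.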

HLR (both group modes, five seeds):

python jax/benchmarks/hlr/benchmark_hlr.py --group-mode both --seeds 0,1,2,3,4
python jax/benchmarks/hlr/plot_hlr.py
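Multi-seed runs are usually reported as a mean with a standard error across seeds. A small standard-library sketch is below; the column names sampler and seed, the helper name summarize_by_sampler, and the idea of passing the metric column by name are assumptions about a CSV layout, not the repository's actual schema.

```python
import csv
import io
import math
import statistics
from collections import defaultdict


def summarize_by_sampler(csv_text, metric):
    """Mean and standard error of `metric` across seeds, per sampler.

    Assumes hypothetical columns `sampler`, `seed`, and a numeric `metric` column.
    """
    per_sampler = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        per_sampler[row["sampler"]].append(float(row[metric]))
    summary = {}
    for sampler, vals in per_sampler.items():
        mean = statistics.mean(vals)
        # Sample standard deviation / sqrt(#seeds); undefined for one seed, reported as 0.
        sem = statistics.stdev(vals) / math.sqrt(len(vals)) if len(vals) > 1 else 0.0
        summary[sampler] = (mean, sem)
    return summary
```

To use it on a file, pass the file contents, e.g. `summarize_by_sampler(open(path).read(), "some_metric")`, where the path and metric name are whatever the CSV in metrics/ actually contains.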

Regenerate plots for existing metrics:

bash scripts/run_all_plots.sh

For full benchmark-by-benchmark commands and output paths, see docs/README.md.


About

Multirate SVGD benchmark suite in JAX for Bayesian sampling: fixed/adaptive multirate variants vs SVGD, SGLD, and SGHMC across 50D Gaussian, 2D synthetic targets, UCI logistic/BNN tasks, and large-scale hierarchical logistic regression.
