Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions examples/examples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,19 @@ static void run_nuts(const F& target_logp_grad, const VectorS& theta_init,
template <typename F, typename RNG>
static void run_walnuts(const F& target_logp_grad, VectorS theta_init, RNG& rng,
Integer D, Integer N, S macro_step_size,
Integer max_nuts_depth, S max_error, VectorS inv_mass) {
Integer max_nuts_depth, Integer min_micro_steps,
S max_error, VectorS inv_mass) {
std::cout << "\nRUN WALNUTS"
<< "; D = " << D << "; N = " << N
<< "; macro_step_size = " << macro_step_size
<< "; max_nuts_depth = " << max_nuts_depth
<< "; min_micro_steps = " << min_micro_steps
<< "; max_error = " << max_error << std::endl;
global_start_timer();
nuts::Random<double, RNG> rand(rng);
nuts::WalnutsSampler sample(rand, target_logp_grad, theta_init, inv_mass,
macro_step_size, max_nuts_depth, max_error);
macro_step_size, max_nuts_depth, min_micro_steps,
max_error);
MatrixS draws(D, N);
for (Integer n = 0; n < N; ++n) {
draws.col(n) = sample();
Expand All @@ -129,15 +132,16 @@ static void run_walnuts(const F& target_logp_grad, VectorS theta_init, RNG& rng,
template <typename F, typename RNG>
static void run_adaptive_walnuts(const F& target_logp_grad,
const VectorS& theta_init, RNG& rng, Integer D,
Integer N, Integer max_nuts_depth,
Integer N, double step_size_init,
Integer max_nuts_depth,
Integer min_micro_steps,
S max_error) {
Eigen::VectorXd mass_init = Eigen::VectorXd::Ones(D);
double init_count = 1.1;
double mass_iteration_offset = 1.1;
double additive_smoothing = 0.1;
nuts::MassAdaptConfig mass_cfg(mass_init, init_count, mass_iteration_offset,
additive_smoothing);
double step_size_init = 0.5;
double accept_rate_target = 2.0 / 3.0;
double step_iteration_offset = 2.0;
double learning_rate = 0.95;
Expand All @@ -146,11 +150,13 @@ static void run_adaptive_walnuts(const F& target_logp_grad,
step_iteration_offset, learning_rate,
decay_rate);
Integer max_step_depth = 8;
nuts::WalnutsConfig walnuts_cfg(max_error, max_nuts_depth, max_step_depth);
nuts::WalnutsConfig walnuts_cfg(max_error, max_nuts_depth, max_step_depth,
min_micro_steps);
std::cout << "\nRUN ADAPTIVE WALNUTS"
<< "; D = " << D << "; N = " << N
<< "; step_size_init = " << step_size_init
<< "; max_nuts_depth = " << max_nuts_depth
<< "; min_micro_steps = " << min_micro_steps
<< "; max_error = " << max_error << std::endl;
global_start_timer();
nuts::AdaptiveWalnuts walnuts(rng, target_logp_grad, theta_init, mass_cfg,
Expand All @@ -166,19 +172,20 @@ static void run_adaptive_walnuts(const F& target_logp_grad,
global_end_timer();
summarize(draws);
std::cout << std::endl;
std::cout << "Macro step size = " << sampler.macro_step_size() << std::endl;
std::cout << "Initial micro step size = " << sampler.macro_step_size() << std::endl;
std::cout << "Max error = " << sampler.max_error() << std::endl;
std::cout << "Inverse mass matrix = "
<< sampler.inverse_mass_matrix_diagonal().transpose() << std::endl;
}

int main() {
unsigned int seed = 428763;
unsigned int seed = 83435638;
Integer D = 200;
Integer N = 1000;
S step_size = 0.5;
S step_size = 0.4;
Integer max_depth = 10;
S max_error = 1.0; // 61% Metropolis
Integer min_micro_steps = 4;
S max_error = 1; // 61% Metropolis
VectorS inv_mass = VectorS::Ones(D);
std::mt19937 rng(seed);

Expand All @@ -198,10 +205,10 @@ int main() {
inv_mass);

run_walnuts(target_logp_grad, theta_init, rng, D, N, step_size, max_depth,
max_error, inv_mass);
min_micro_steps, max_error, inv_mass);

run_adaptive_walnuts(target_logp_grad, theta_init, rng, D, N, max_depth,
max_error);
run_adaptive_walnuts(target_logp_grad, theta_init, rng, D, N, step_size,
max_depth, min_micro_steps, max_error);

return 0;
}
16 changes: 10 additions & 6 deletions examples/examples_stan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ template <typename RNG>
static void test_walnuts(const DynamicStanModel& model,
const VectorS& theta_init, RNG& rng, Integer N,
S macro_step_size, Integer max_nuts_depth,
Integer min_micro_steps,
S log_max_error, VectorS inv_mass) {
std::cout << "\nTEST WALNUTS" << std::endl;
nuts::Random<double, RNG> rand(rng);
auto logp = [&model](auto&&... args) { model.logp_grad(args...); };

nuts::WalnutsSampler sample(rand, logp, theta_init, inv_mass, macro_step_size,
max_nuts_depth, log_max_error);
max_nuts_depth, min_micro_steps, log_max_error);
int M = model.constrained_dimensions();

MatrixS draws(M, N);
Expand All @@ -104,7 +105,7 @@ template <typename RNG>
static void test_adaptive_walnuts(const DynamicStanModel& model,
const VectorS& theta_init, RNG& rng,
Integer D, Integer N, Integer max_nuts_depth,
S max_error) {
Integer min_micro_steps, S max_error) {
double logp_time = 0.0;
int logp_count = 0;
auto global_start = std::chrono::high_resolution_clock::now();
Expand All @@ -124,7 +125,8 @@ static void test_adaptive_walnuts(const DynamicStanModel& model,
step_iteration_offset, learning_rate,
decay_rate);
Integer max_step_depth = 8;
nuts::WalnutsConfig walnuts_cfg(max_error, max_nuts_depth, max_step_depth);
nuts::WalnutsConfig walnuts_cfg(max_error, max_nuts_depth, max_step_depth,
min_micro_steps);
std::cout << "\nTEST ADAPTIVE WALNUTS"
<< "; D = " << D << "; N = " << N
<< "; step_size_init = " << step_size_init
Expand Down Expand Up @@ -195,6 +197,7 @@ int main(int argc, char** argv) {
Integer N = 1000;
S step_size = 0.465;
Integer max_depth = 10;
Integer min_micro_steps = 1;
S max_error = 0.5;

char* lib{nullptr};
Expand Down Expand Up @@ -230,10 +233,11 @@ int main(int argc, char** argv) {

test_nuts(model, theta_init, rng, N, step_size, max_depth, inv_mass);

test_walnuts(model, theta_init, rng, N, step_size, max_depth, max_error,
inv_mass);
test_walnuts(model, theta_init, rng, N, step_size, max_depth,
min_micro_steps, max_error, inv_mass);

test_adaptive_walnuts(model, theta_init, rng, D, N, max_depth, max_error);
test_adaptive_walnuts(model, theta_init, rng, D, N, max_depth,
min_micro_steps, max_error);

return 0;
}
14 changes: 11 additions & 3 deletions examples/stan_cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Matrix run_walnuts(DynamicStanModel& model, RNG& rng, const Vector& theta_init,
double step_size_init, double accept_rate_target,
double step_iteration_offset, double learning_rate,
double decay_rate, double max_error, int64_t max_nuts_depth,
int64_t max_step_depth) {
int64_t max_step_depth, int64_t min_micro_steps) {
double logp_time = 0.0;
int logp_count = 0;
auto global_start = std::chrono::high_resolution_clock::now();
Expand All @@ -97,7 +97,8 @@ Matrix run_walnuts(DynamicStanModel& model, RNG& rng, const Vector& theta_init,
nuts::StepAdaptConfig step_cfg(step_size_init, accept_rate_target,
step_iteration_offset, learning_rate,
decay_rate);
nuts::WalnutsConfig walnuts_cfg(max_error, max_nuts_depth, max_step_depth);
nuts::WalnutsConfig walnuts_cfg(max_error, max_nuts_depth, max_step_depth,
min_micro_steps);

std::cout << "Running Adaptive WALNUTS"
<< "; D = " << theta_init.size() << "; W = " << warmup
Expand Down Expand Up @@ -181,6 +182,7 @@ int main(int argc, char** argv) {
int64_t samples = 128;
int64_t max_nuts_depth = 10;
int64_t max_step_depth = 8;
int64_t min_micro_steps = 1;
double max_error = 0.5;
double init = 2.0;
double init_count = 1.1;
Expand Down Expand Up @@ -220,6 +222,11 @@ int main(int argc, char** argv) {
->default_val(max_step_depth)
->check(CLI::PositiveNumber);

app.add_option("--min-micro-steps", min_micro_steps,
"Minimum micro steps per macro step")
->default_val(min_micro_steps)
->check(CLI::PositiveNumber);

app.add_option("--max-error", max_error,
"Maximum error allowed in joint densities")
->default_val(max_error)
Expand Down Expand Up @@ -293,7 +300,8 @@ int main(int argc, char** argv) {
run_walnuts(model, rng, theta_init, warmup, samples, init_count,
mass_iteration_offset, additive_smoothing, step_size_init,
accept_rate_target, step_iteration_offset, learning_rate,
decay_rate, max_error, max_nuts_depth, max_step_depth);
decay_rate, max_error, max_nuts_depth, max_step_depth,
min_micro_steps);

auto names = model.param_names();
summarize(names, draws);
Expand Down
22 changes: 17 additions & 5 deletions include/walnuts/adaptive_walnuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,25 +193,32 @@ struct WalnutsConfig {
* doublings for NUTS.
* @param[in] max_step_depth The maximum number of step doublings
* per macro step.
* @param[in] min_micro_steps The minimum number of micro steps per macro
* step.
* @throw std::invalid_argument If the log max error is not finite and
* positive.
* @throw std::invalid_argument If the maximum tree depth is not positive.
* @throw std::invalid_argument If the maximum step depth is negative.
*/
WalnutsConfig(S log_max_error, Integer max_nuts_depth, Integer max_step_depth)
WalnutsConfig(S log_max_error, Integer max_nuts_depth, Integer max_step_depth,
Integer min_micro_steps)
: log_max_error_(log_max_error),
max_nuts_depth_(max_nuts_depth),
max_step_depth_(max_step_depth) {
max_step_depth_(max_step_depth),
min_micro_steps_(min_micro_steps) {
if (!(log_max_error > 0) || std::isinf(log_max_error)) {
throw std::invalid_argument(
"Log maximum error must be positive and finite.");
}
if (max_nuts_depth < 1) {
if (max_nuts_depth <= 0) {
throw std::invalid_argument("Maximum NUTS depth must be positive.");
}
if (max_step_depth < 0) {
throw std::invalid_argument("Maximum step depth must be non-negative.");
}
if (min_micro_steps <= 0) {
throw std::invalid_argument("Minimum micro steps must be positive.");
}
}

/** The maximum error in Hamiltonian in macro steps. */
Expand All @@ -222,6 +229,9 @@ struct WalnutsConfig {

/** The maximum number of step doublings per macro step. */
const Integer max_step_depth_;

/** The minimum number of micro steps per macro step. */
const Integer min_micro_steps_;
};

/**
Expand Down Expand Up @@ -486,7 +496,8 @@ class AdaptiveWalnuts {
Vec<S> grad_select;
theta_ = transition_w(
rand_, logp_grad_, inv_mass, chol_mass, step_adapt_handler_.step_size(),
walnuts_cfg_.max_nuts_depth_, std::move(theta_), grad_select,
walnuts_cfg_.max_nuts_depth_, walnuts_cfg_.min_micro_steps_,
std::move(theta_), grad_select,
walnuts_cfg_.log_max_error_, step_adapt_handler_);
mass_estimator_.observe(theta_, grad_select, iteration_);
++iteration_;
Expand All @@ -507,7 +518,8 @@ class AdaptiveWalnuts {
return WalnutsSampler<F, S, RNG>(
rand_, logp_grad_.logp_grad_, theta_,
mass_estimator_.inv_mass_estimate(), step_adapt_handler_.step_size(),
walnuts_cfg_.max_nuts_depth_, walnuts_cfg_.log_max_error_);
walnuts_cfg_.max_nuts_depth_, walnuts_cfg_.min_micro_steps_,
walnuts_cfg_.log_max_error_);
}

/**
Expand Down
Loading
Loading