Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ S3method(apply_transformations,matrix)
S3method(diagnostic_factor,neff_ratio)
S3method(diagnostic_factor,rhat)
S3method(log_posterior,CmdStanMCMC)
S3method(log_posterior,draws_array)
S3method(log_posterior,draws_df)
S3method(log_posterior,draws_matrix)
S3method(log_posterior,stanfit)
S3method(log_posterior,stanreg)
S3method(melt_mcmc,matrix)
Expand All @@ -21,6 +24,9 @@ S3method(num_iters,mcmc_array)
S3method(num_params,data.frame)
S3method(num_params,mcmc_array)
S3method(nuts_params,CmdStanMCMC)
S3method(nuts_params,draws_array)
S3method(nuts_params,draws_df)
S3method(nuts_params,draws_matrix)
S3method(nuts_params,list)
S3method(nuts_params,stanfit)
S3method(nuts_params,stanreg)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# bayesplot (development version)

* Added `log_posterior()` and `nuts_params()` methods for `draws_array`, `draws_df`, and `draws_matrix` objects.
* `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated.
* Fixed missing `drop = FALSE` in `nuts_params.CmdStanMCMC()`.
* Replace `apply()` with `storage.mode()` for integer-to-numeric matrix conversion in `validate_predictions()`.
Expand Down
65 changes: 60 additions & 5 deletions R/bayesplot-extractors.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,37 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) {
#' @export
#' @method log_posterior CmdStanMCMC
log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) {
lp <- object$draws("lp__", inc_warmup = inc_warmup)
lp <- reshape2::melt(lp)
log_posterior.draws_array(object$draws("lp__", inc_warmup = inc_warmup), ...)
}

#' @rdname bayesplot-extractors
#' @export
#' @method log_posterior draws_array
log_posterior.draws_array <- function(object, ...) {
if (!"lp__" %in% posterior::variables(object)) {
abort("draws object does not contain an 'lp__' variable.")
}
lp <- reshape2::melt(object[, , "lp__", drop = FALSE])
lp$variable <- NULL
lp <- dplyr::rename_with(lp, capitalize_first)
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
c("integer", "integer", "numeric"))
}

#' @rdname bayesplot-extractors
#' @export
#' @method log_posterior draws_df
log_posterior.draws_df <- function(object, ...) {
log_posterior.draws_array(posterior::as_draws_array(object), ...)
}

#' @rdname bayesplot-extractors
#' @export
#' @method log_posterior draws_matrix
log_posterior.draws_matrix <- function(object, ...) {
log_posterior.draws_array(posterior::as_draws_array(object), ...)
}


#' @rdname bayesplot-extractors
#' @export
Expand Down Expand Up @@ -173,17 +196,49 @@ nuts_params.list <- function(object, pars = NULL, ...) {
#' @export
#' @method nuts_params CmdStanMCMC
nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) {
arr <- object$sampler_diagnostics()
if (!is.null(pars)) {
arr <- arr[,, pars, drop = FALSE]
nuts_params.draws_array(object$sampler_diagnostics(), pars = pars, ...)
}

#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params draws_array
nuts_params.draws_array <- function(object, pars = NULL, ...) {
vars <- posterior::variables(object)
if (is.null(pars)) {
pars <- grep("__$", vars, value = TRUE)
pars <- setdiff(pars, "lp__")
if (!length(pars)) {
abort("draws object does not contain any NUTS sampler diagnostic variables (names ending in '__').")
}
} else {
missing_pars <- setdiff(pars, vars)
if (length(missing_pars)) {
abort(paste0("Variables not found in draws object: ",
paste(missing_pars, collapse = ", "), "."))
}
}
arr <- object[, , pars, drop = FALSE]
out <- reshape2::melt(arr)
colnames(out)[colnames(out) == "variable"] <- "parameter"
out <- dplyr::rename_with(out, capitalize_first)
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
c("integer", "integer", "factor", "numeric"))
}

#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params draws_df
nuts_params.draws_df <- function(object, pars = NULL, ...) {
nuts_params.draws_array(posterior::as_draws_array(object), pars = pars, ...)
}

#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params draws_matrix
nuts_params.draws_matrix <- function(object, pars = NULL, ...) {
nuts_params.draws_array(posterior::as_draws_array(object), pars = pars, ...)
}


#' @rdname bayesplot-extractors
#' @export
Expand Down
18 changes: 18 additions & 0 deletions man/bayesplot-extractors.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 52 additions & 0 deletions tests/testthat/test-extractors.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,55 @@ test_that("cmdstanr methods work", {
expect_equal(range(np_one$Chain), c(1, 2))
expect_true(all(np_one$Value == 0))
})


# draws object methods ----------------------------------------------------
make_draws_array <- function(iter = 50, chains = 2) {
vars <- c("mu", "sigma", "lp__", "accept_stat__", "stepsize__",
"treedepth__", "n_leapfrog__", "divergent__", "energy__")
arr <- array(stats::rnorm(iter * chains * length(vars)),
dim = c(iter, chains, length(vars)),
dimnames = list(NULL, NULL, vars))
posterior::as_draws_array(arr)
}

test_that("log_posterior methods for draws objects return correct structure", {
d <- make_draws_array(iter = 50, chains = 2)

lp_arr <- log_posterior(d)
expect_identical(colnames(lp_arr), c("Chain", "Iteration", "Value"))
expect_equal(length(unique(lp_arr$Iteration)), 50)
expect_equal(length(unique(lp_arr$Chain)), 2)

lp_df <- log_posterior(posterior::as_draws_df(d))
lp_mat <- log_posterior(posterior::as_draws_matrix(d))
expect_equal(lp_df$Value, lp_arr$Value)
expect_equal(lp_mat$Value, lp_arr$Value)
})

test_that("nuts_params methods for draws objects return correct structure", {
d <- make_draws_array(iter = 50, chains = 2)

np <- nuts_params(d)
expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value"))
expect_setequal(
levels(np$Parameter),
c("accept_stat__", "stepsize__", "treedepth__",
"n_leapfrog__", "divergent__", "energy__")
)
expect_false("lp__" %in% levels(np$Parameter))

np_one <- nuts_params(d, pars = "divergent__")
expect_identical(levels(np_one$Parameter), "divergent__")

np_df <- nuts_params(posterior::as_draws_df(d))
expect_equal(np_df$Value, np$Value)
})

test_that("draws-object extractors error on missing variables", {
d <- make_draws_array()
bare <- d[, , c("mu", "sigma"), drop = FALSE]
expect_error(log_posterior(bare), "lp__")
expect_error(nuts_params(bare), "sampler diagnostic")
expect_error(nuts_params(d, pars = "nope__"), "nope__")
})
Loading