Updated docstrings #147

Merged 1 commit on Sep 8, 2024
1 change: 1 addition & 0 deletions NEWS.md
@@ -3,6 +3,7 @@
## Documentation

- More compact README.
- Updated function description.

# kernelshap 0.7.0

27 changes: 14 additions & 13 deletions R/additive_shap.R
@@ -9,19 +9,20 @@
#' - `gam::gam()`,
#' - [survival::coxph()], and
#' - [survival::survreg()].
#'
#'
#' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`,
#' a logic heavily inspired by `fastshap:::explain.lm(..., exact = TRUE)`.
#' a logic adopted from `fastshap:::explain.lm(..., exact = TRUE)`.
#' Models with interactions (specified via `:` or `*`), or with terms of
#' multiple features like `log(x1/x2)` are not supported.
#'
#'
#' Note that the SHAP values obtained by [additive_shap()] are expected to
#' match those of [permshap()] and [kernelshap()] as long as their background
#' data equals the full training data (which is typically not feasible).
#'
#' @inheritParams kernelshap
#' @param X Dataframe with rows to be explained. Will be used like
#' @param object Fitted additive model.
#' @param X Dataframe with rows to be explained. Passed to
#' `predict(object, newdata = X, type = "terms")`.
#' @param verbose Set to `FALSE` to suppress messages.
#' @param ... Currently unused.
#' @returns
#' An object of class "kernelshap" with the following components:
@@ -38,15 +39,15 @@
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- additive_shap(fit, head(iris))
#' s
#'
#'
#' # MODEL TWO: More complicated (but not very clever) formula
#' fit <- lm(
#' Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
#' data = iris
#' )
#' s_add <- additive_shap(fit, head(iris))
#' s_add
#'
#'
#' # Equals kernelshap()/permshap() when background data is full training data
#' s_kernel <- kernelshap(
#' fit, head(iris[c("Sepal.Width", "Petal.Length")]), bg_X = iris
@@ -59,28 +60,28 @@ additive_shap <- function(object, X, verbose = TRUE, ...) {
if (any(attr(stats::terms(object), "order") > 1)) {
stop("Additive SHAP not appropriate for models with interactions.")
}

txt <- "Exact additive SHAP via predict(..., type = 'terms')"
if (verbose) {
message(txt)
}

S <- stats::predict(object, newdata = X, type = "terms")
rownames(S) <- NULL

# Baseline value
b <- as.vector(attr(S, "constant"))
if (is.null(b)) {
b <- 0
}

# Which columns of X are used in each column of S?
s_names <- colnames(S)
cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z)))
if (any(lengths(cols_used) > 1L)) {
stop("The formula contains terms with multiple features (not supported).")
}

# Collapse all columns in S using the same column in X and rename accordingly
mapping <- split(
s_names, factor(unlist(cols_used), levels = colnames(X)), drop = TRUE
@@ -89,7 +90,7 @@ additive_shap <- function(object, X, verbose = TRUE, ...) {
cbind,
lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
)

structure(
list(
S = S,
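The extraction logic this docstring describes can be sketched in a few lines of base R (a minimal illustration with a made-up `lm()` fit on `iris`, not the package implementation):

```r
# Sketch of the additive-SHAP idea: for an interaction-free additive model,
# predict(..., type = "terms") returns per-feature contributions, and
# attr(S, "constant") holds the baseline (the average training prediction).
fit <- lm(Sepal.Length ~ Sepal.Width + log(Petal.Length), data = iris)
X <- head(iris)

S <- predict(fit, newdata = X, type = "terms")
b <- attr(S, "constant")

# Efficiency property: baseline + row sums of S reproduce the predictions
stopifnot(all.equal(unname(b + rowSums(S)),
                    unname(predict(fit, newdata = X))))
```

As the docstring notes, these centered term contributions match `permshap()`/`kernelshap()` only when the background data equals the full training data.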
203 changes: 105 additions & 98 deletions R/kernelshap.R

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions R/permshap.R
@@ -2,7 +2,7 @@
#'
#' Exact permutation SHAP algorithm with respect to a background dataset,
#' see Strumbelj and Kononenko. The function works for up to 14 features.
#' For eight or more features, we recomment to switch to [kernelshap()].
#' For more than eight features, we recommend [kernelshap()] due to its higher speed.
#'
#' @inheritParams kernelshap
#' @returns
@@ -16,12 +16,12 @@
#' - `bg_w`: The background case weights.
#' - `m_exact`: Integer providing the effective number of exact on-off vectors used.
#' - `exact`: Logical flag indicating whether calculations are exact or not
#' (currently `TRUE`).
#' (currently always `TRUE`).
#' - `txt`: Summary text.
#' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
#' - `algorithm`: "permshap".
#' @references
#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual
#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual
#' predictions with feature contributions. Knowledge and Information Systems 41, 2014.
#' @export
#' @examples
@@ -80,7 +80,7 @@ permshap.default <- function(
if (verbose) {
message(txt)
}

basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
bg_X <- prep_bg$bg_X
@@ -92,32 +92,32 @@
bg_preds <- align_pred(pred_fun(object, bg_X, ...))
v0 <- wcolMeans(bg_preds, w = bg_w) # Average pred of bg data: 1 x K
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K

# Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
# Predictions will never be applied directly to bg_X anymore
if (!identical(colnames(bg_X), feature_names)) {
bg_X <- bg_X[, feature_names, drop = FALSE]
}

# Precalculations that are identical for each row to be explained
Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE)
m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row
precalc <- list(
Z = Z,
Z_code = rowpaste(Z),
Z_code = rowpaste(Z),
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
)

if (m_exact * bg_n > 2e5) {
warning_burden(m_exact, bg_n = bg_n)
}

# Apply permutation SHAP to each row of X
if (isTRUE(parallel)) {
parallel_args <- c(list(i = seq_len(n)), parallel_args)
res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
x = X[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
object = object,
pred_fun = pred_fun,
bg_w = bg_w,
@@ -133,7 +133,7 @@
for (i in seq_len(n)) {
res[[i]] <- permshap_one(
x = X[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
object = object,
pred_fun = pred_fun,
bg_w = bg_w,
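The exact permutation algorithm referenced in this docstring can be illustrated for two features, where only four on-off vectors exist (a standalone sketch with synthetic data, not the package code):

```r
# Exact permutation SHAP for two features: average each feature's
# contribution over both orderings, using a background sample to
# estimate the value function v(S).
set.seed(1)
bg <- data.frame(x1 = rnorm(100), x2 = rnorm(100))
dat <- transform(bg, y = x1 + 2 * x2 + x1 * x2)
fit <- lm(y ~ x1 * x2, data = dat)
x <- data.frame(x1 = 1, x2 = 0.5)  # row to explain

v <- function(keep) {
  z <- bg                       # features outside 'keep' stay at background values
  for (k in keep) z[[k]] <- x[[k]]
  mean(predict(fit, newdata = z))
}

phi1 <- ((v("x1") - v(NULL)) + (v(c("x1", "x2")) - v("x2"))) / 2
phi2 <- ((v("x2") - v(NULL)) + (v(c("x1", "x2")) - v("x1"))) / 2

# Efficiency: SHAP values sum to prediction minus baseline
stopifnot(all.equal(phi1 + phi2,
                    unname(predict(fit, newdata = x)) - v(NULL)))
```

With p features there are p! orderings, which is why the exact approach is limited to small p (up to 14 here) and `kernelshap()` is recommended beyond eight features.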
17 changes: 8 additions & 9 deletions README.md
@@ -15,19 +15,18 @@

The package contains three functions to crunch SHAP values:

- `permshap()`: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features.
- `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features.
- `additive_shap()`: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.
- **`permshap()`**: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features.
- **`kernelshap()`**: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features.
- **`additive_shap()`**: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.

To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.
To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data, feature columns only) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.

**Remarks for `permshap()` and `kernelshap()`**
**Remarks on `permshap()` and `kernelshap()`**

- `X` should only contain feature columns.
- Both algorithms need a representative background data `bg_X` to calculate marginal means (up to 500 rows from the training data). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If unspecified, 200 rows are randomly sampled from `X`.
- By changing the defaults in `kernelshap()`, the iterative pure sampling approach of [3] can be enforced.
- `permshap()` vs. `kernelshap()`: For models with interactions of order up to two, exact Kernel SHAP agrees with exact permutation SHAP.
- `additive_shap()` vs. the model-agnostic explainers: The results would agree if the full training data would be used as background data.
- Exact Kernel SHAP is an approximation to exact permutation SHAP. Since exact calculations are usually sufficiently fast for up to eight features, we recommend `permshap()` in this case. With more features, `kernelshap()` switches to a comparably fast, almost exact algorithm. That is why we recommend `kernelshap()` in this case.
- For models with interactions of order up to two, SHAP values of exact permutation SHAP and exact Kernel SHAP agree.
- `permshap()` and `kernelshap()` give the same results as `additive_shap()` if the full training data is used as background data.

## Installation

8 changes: 4 additions & 4 deletions man/additive_shap.Rd

Some generated files are not rendered by default.

11 changes: 8 additions & 3 deletions man/kernelshap.Rd

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions man/permshap.Rd

Some generated files are not rendered by default.
