Skip to content

Commit 5e2b6d1

Browse files
authored
Merge pull request #142 from ModelOriented/auto-select-background-data
auto-select background data
2 parents ad10ea4 + f50a25a commit 5e2b6d1

File tree

12 files changed

+270
-131
lines changed

12 files changed

+270
-131
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.6.1
3+
Version: 0.7.0
44
Authors@R: c(
55
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"),
66
comment = c(ORCID = "0009-0007-2540-9629")),

NEWS.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
1-
# kernelshap 0.6.1
1+
# kernelshap 0.7.0
2+
3+
This release is intended to be the last before stable version 1.0.0.
4+
5+
## Major change
6+
7+
Passing a background dataset `bg_X` is now optional.
8+
9+
If the explanation data `X` is sufficiently large (>= 50 rows), `bg_X` is derived as a random sample of `bg_n = 200` rows from `X`. If `X` has fewer than `bg_n` rows, then simply
10+
`bg_X = X`. If `X` has too few rows (< 50), you will have to pass an explicit `bg_X`.
11+
12+
## Minor changes
213

314
- `ranger()` survival models now also work out-of-the-box without passing a tailored prediction function. Use the new argument `survival = "chf"` in `kernelshap()` and `permshap()` to distinguish cumulative hazards (default) and survival probabilities per time point.
15+
- The resulting object of `kernelshap()` and `permshap()` now contain `bg_X` and `bg_w` used to calculate the SHAP values.
416

517
# kernelshap 0.6.0
618

7-
This release is intended to be the last before stable version 1.0.0.
8-
919
## Major changes
1020

1121
- Factor-valued predictions are not supported anymore.

R/kernelshap.R

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding
66
#' the selected background data. For larger \eqn{p}, an almost exact
77
#' hybrid algorithm involving iterative sampling is used, see Details.
8+
#' For up to eight features, however, we recommend to use [permshap()].
89
#'
910
#' Pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:
1011
#'
@@ -63,10 +64,11 @@
6364
#' The columns should only represent model features, not the response
6465
#' (but see `feature_names` on how to overrule this).
6566
#' @param bg_X Background data used to integrate out "switched off" features,
66-
#' often a subset of the training data (typically 50 to 500 rows)
67-
#' It should contain the same columns as `X`.
67+
#' often a subset of the training data (typically 50 to 500 rows).
6868
#' In cases with a natural "off" value (like MNIST digits),
6969
#' this can also be a single row with all values set to the off value.
70+
#' If no `bg_X` is passed (the default) and if `X` is sufficiently large,
71+
#' a random sample of `bg_n` rows from `X` serves as background data.
7072
#' @param pred_fun Prediction function of the form `function(object, X, ...)`,
7173
#' providing \eqn{K \ge 1} predictions per row. Its first argument
7274
#' represents the model `object`, its second argument a data structure like `X`.
@@ -76,6 +78,8 @@
7678
#' SHAP values. By default, this equals `colnames(X)`. Not supported if `X`
7779
#' is a matrix.
7880
#' @param bg_w Optional vector of case weights for each row of `bg_X`.
81+
#' If `bg_X = NULL`, must be of same length as `X`. Set to `NULL` for no weights.
82+
#' @param bg_n If `bg_X = NULL`: Size of background data to be sampled from `X`.
7983
#' @param exact If `TRUE`, the algorithm will produce exact Kernel SHAP values
8084
#' with respect to the background data. In this case, the arguments `hybrid_degree`,
8185
#' `m`, `paired_sampling`, `tol`, and `max_iter` are ignored.
@@ -130,6 +134,8 @@
130134
#' - `X`: Same as input argument `X`.
131135
#' - `baseline`: Vector of length K representing the average prediction on the
132136
#' background data.
137+
#' - `bg_X`: The background data.
138+
#' - `bg_w`: The background case weights.
133139
#' - `SE`: Standard errors corresponding to `S` (and organized like `S`).
134140
#' - `n_iter`: Integer vector of length n providing the number of iterations
135141
#' per row of `X`.
@@ -155,28 +161,25 @@
155161
#' @examples
156162
#' # MODEL ONE: Linear regression
157163
#' fit <- lm(Sepal.Length ~ ., data = iris)
158-
#'
164+
#'
159165
#' # Select rows to explain (only feature columns)
160-
#' X_explain <- iris[1:2, -1]
161-
#'
162-
#' # Select small background dataset (could use all rows here because iris is small)
163-
#' set.seed(1)
164-
#' bg_X <- iris[sample(nrow(iris), 100), ]
165-
#'
166+
#' X_explain <- iris[-1]
167+
#'
166168
#' # Calculate SHAP values
167-
#' s <- kernelshap(fit, X_explain, bg_X = bg_X)
169+
#' s <- kernelshap(fit, X_explain)
168170
#' s
169-
#'
171+
#'
170172
#' # MODEL TWO: Multi-response linear regression
171173
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
172-
#' s <- kernelshap(fit, iris[1:4, 3:5], bg_X = bg_X)
173-
#' summary(s)
174-
#'
175-
#' # Non-feature columns can be dropped via 'feature_names'
174+
#' s <- kernelshap(fit, iris[3:5])
175+
#' s
176+
#'
177+
#' # Note 1: Feature columns can also be selected via 'feature_names'
178+
#' # Note 2: Especially when X is small, pass a sufficiently large background data bg_X
176179
#' s <- kernelshap(
177-
#' fit,
180+
#' fit,
178181
#' iris[1:4, ],
179-
#' bg_X = bg_X,
182+
#' bg_X = iris,
180183
#' feature_names = c("Petal.Length", "Petal.Width", "Species")
181184
#' )
182185
#' s
@@ -189,10 +192,11 @@ kernelshap <- function(object, ...){
189192
kernelshap.default <- function(
190193
object,
191194
X,
192-
bg_X,
195+
bg_X = NULL,
193196
pred_fun = stats::predict,
194197
feature_names = colnames(X),
195198
bg_w = NULL,
199+
bg_n = 200L,
196200
exact = length(feature_names) <= 8L,
197201
hybrid_degree = 1L + length(feature_names) %in% 4:16,
198202
paired_sampling = TRUE,
@@ -204,24 +208,24 @@ kernelshap.default <- function(
204208
verbose = TRUE,
205209
...
206210
) {
207-
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
208211
p <- length(feature_names)
212+
basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
209213
stopifnot(
210214
exact %in% c(TRUE, FALSE),
211215
p == 1L || exact || hybrid_degree %in% 0:(p / 2),
212216
paired_sampling %in% c(TRUE, FALSE),
213217
"m must be even" = trunc(m / 2) == m / 2
214218
)
215-
n <- nrow(X)
219+
prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
220+
bg_X <- prep_bg$bg_X
221+
bg_w <- prep_bg$bg_w
216222
bg_n <- nrow(bg_X)
217-
if (!is.null(bg_w)) {
218-
bg_w <- prep_w(bg_w, bg_n = bg_n)
219-
}
223+
n <- nrow(X)
220224

221225
# Calculate v1 and v0
222-
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
223-
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
226+
bg_preds <- align_pred(pred_fun(object, bg_X, ...))
224227
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
228+
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
225229

226230
# For p = 1, exact Shapley values are returned
227231
if (p == 1L) {
@@ -231,18 +235,25 @@ kernelshap.default <- function(
231235
return(out)
232236
}
233237

238+
txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
239+
if (verbose) {
240+
message(txt)
241+
}
242+
234243
# Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
235244
# In what follows, predictions will never be applied directly to bg_X anymore
236245
if (!identical(colnames(bg_X), feature_names)) {
237246
bg_X <- bg_X[, feature_names, drop = FALSE]
238247
}
239248

240-
# Precalculations for the real Kernel SHAP
249+
# Precalculations that are identical for each row to be explained
241250
if (exact || hybrid_degree >= 1L) {
242251
if (exact) {
243252
precalc <- input_exact(p, feature_names = feature_names)
244253
} else {
245-
precalc <- input_partly_exact(p, deg = hybrid_degree, feature_names = feature_names)
254+
precalc <- input_partly_exact(
255+
p, deg = hybrid_degree, feature_names = feature_names
256+
)
246257
}
247258
m_exact <- nrow(precalc[["Z"]])
248259
prop_exact <- sum(precalc[["w"]])
@@ -256,11 +267,6 @@ kernelshap.default <- function(
256267
precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
257268
}
258269

259-
# Some infos
260-
txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
261-
if (verbose) {
262-
message(txt)
263-
}
264270
if (max(m, m_exact) * bg_n > 2e5) {
265271
warning_burden(max(m, m_exact), bg_n = bg_n)
266272
}
@@ -319,11 +325,18 @@ kernelshap.default <- function(
319325
if (verbose && !all(converged)) {
320326
warning("\nNon-convergence for ", sum(!converged), " rows.")
321327
}
328+
329+
if (verbose) {
330+
cat("\n")
331+
}
332+
322333
out <- list(
323-
S = reorganize_list(lapply(res, `[[`, "beta")),
324-
X = X,
325-
baseline = as.vector(v0),
326-
SE = reorganize_list(lapply(res, `[[`, "sigma")),
334+
S = reorganize_list(lapply(res, `[[`, "beta")),
335+
X = X,
336+
baseline = as.vector(v0),
337+
bg_X = bg_X,
338+
bg_w = bg_w,
339+
SE = reorganize_list(lapply(res, `[[`, "sigma")),
327340
n_iter = vapply(res, `[[`, "n_iter", FUN.VALUE = integer(1L)),
328341
converged = converged,
329342
m = m,
@@ -343,10 +356,11 @@ kernelshap.default <- function(
343356
kernelshap.ranger <- function(
344357
object,
345358
X,
346-
bg_X,
359+
bg_X = NULL,
347360
pred_fun = NULL,
348361
feature_names = colnames(X),
349362
bg_w = NULL,
363+
bg_n = 200L,
350364
exact = length(feature_names) <= 8L,
351365
hybrid_degree = 1L + length(feature_names) %in% 4:16,
352366
paired_sampling = TRUE,
@@ -371,6 +385,7 @@ kernelshap.ranger <- function(
371385
pred_fun = pred_fun,
372386
feature_names = feature_names,
373387
bg_w = bg_w,
388+
bg_n = bg_n,
374389
exact = exact,
375390
hybrid_degree = hybrid_degree,
376391
paired_sampling = paired_sampling,

R/permshap.R

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#'
33
#' Exact permutation SHAP algorithm with respect to a background dataset,
44
#' see Strumbelj and Kononenko. The function works for up to 14 features.
5+
#' For eight or more features, we recommend to switch to [kernelshap()].
56
#'
67
#' @inheritParams kernelshap
78
#' @returns
@@ -11,6 +12,8 @@
1112
#' - `X`: Same as input argument `X`.
1213
#' - `baseline`: Vector of length K representing the average prediction on the
1314
#' background data.
15+
#' - `bg_X`: The background data.
16+
#' - `bg_w`: The background case weights.
1417
#' - `m_exact`: Integer providing the effective number of exact on-off vectors used.
1518
#' - `exact`: Logical flag indicating whether calculations are exact or not
1619
#' (currently `TRUE`).
@@ -26,26 +29,23 @@
2629
#' fit <- lm(Sepal.Length ~ ., data = iris)
2730
#'
2831
#' # Select rows to explain (only feature columns)
29-
#' X_explain <- iris[1:2, -1]
30-
#'
31-
#' # Select small background dataset (could use all rows here because iris is small)
32-
#' set.seed(1)
33-
#' bg_X <- iris[sample(nrow(iris), 100), ]
32+
#' X_explain <- iris[-1]
3433
#'
3534
#' # Calculate SHAP values
36-
#' s <- permshap(fit, X_explain, bg_X = bg_X)
35+
#' s <- permshap(fit, X_explain)
3736
#' s
3837
#'
3938
#' # MODEL TWO: Multi-response linear regression
4039
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
41-
#' s <- permshap(fit, iris[1:4, 3:5], bg_X = bg_X)
40+
#' s <- permshap(fit, iris[3:5])
4241
#' s
4342
#'
44-
#' # Non-feature columns can be dropped via 'feature_names'
43+
#' # Note 1: Feature columns can also be selected via 'feature_names'
44+
#' # Note 2: Especially when X is small, pass a sufficiently large background data bg_X
4545
#' s <- permshap(
4646
#' fit,
4747
#' iris[1:4, ],
48-
#' bg_X = bg_X,
48+
#' bg_X = iris,
4949
#' feature_names = c("Petal.Length", "Petal.Width", "Species")
5050
#' )
5151
#' s
@@ -58,37 +58,40 @@ permshap <- function(object, ...) {
5858
permshap.default <- function(
5959
object,
6060
X,
61-
bg_X,
61+
bg_X = NULL,
6262
pred_fun = stats::predict,
6363
feature_names = colnames(X),
6464
bg_w = NULL,
65+
bg_n = 200L,
6566
parallel = FALSE,
6667
parallel_args = NULL,
6768
verbose = TRUE,
6869
...
6970
) {
70-
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
7171
p <- length(feature_names)
7272
if (p <= 1L) {
7373
stop("Case p = 1 not implemented. Use kernelshap() instead.")
7474
}
7575
if (p > 14L) {
7676
stop("Permutation SHAP only supported for up to 14 features")
7777
}
78-
n <- nrow(X)
79-
bg_n <- nrow(bg_X)
80-
if (!is.null(bg_w)) {
81-
bg_w <- prep_w(bg_w, bg_n = bg_n)
82-
}
78+
8379
txt <- "Exact permutation SHAP"
8480
if (verbose) {
8581
message(txt)
8682
}
8783

84+
basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
85+
prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
86+
bg_X <- prep_bg$bg_X
87+
bg_w <- prep_bg$bg_w
88+
bg_n <- nrow(bg_X)
89+
n <- nrow(X)
90+
8891
# Baseline and predictions on explanation data
89-
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
90-
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
91-
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
92+
bg_preds <- align_pred(pred_fun(object, bg_X, ...))
93+
v0 <- wcolMeans(bg_preds, w = bg_w) # Average pred of bg data: 1 x K
94+
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
9295

9396
# Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
9497
# Predictions will never be applied directly to bg_X anymore
@@ -143,10 +146,15 @@ permshap.default <- function(
143146
}
144147
}
145148
}
149+
if (verbose) {
150+
cat("\n")
151+
}
146152
out <- list(
147-
S = reorganize_list(res),
148-
X = X,
153+
S = reorganize_list(res),
154+
X = X,
149155
baseline = as.vector(v0),
156+
bg_X = bg_X,
157+
bg_w = bg_w,
150158
m_exact = m_exact,
151159
exact = TRUE,
152160
txt = txt,
@@ -162,10 +170,11 @@ permshap.default <- function(
162170
permshap.ranger <- function(
163171
object,
164172
X,
165-
bg_X,
173+
bg_X = NULL,
166174
pred_fun = NULL,
167175
feature_names = colnames(X),
168176
bg_w = NULL,
177+
bg_n = 200L,
169178
parallel = FALSE,
170179
parallel_args = NULL,
171180
verbose = TRUE,
@@ -184,6 +193,7 @@ permshap.ranger <- function(
184193
pred_fun = pred_fun,
185194
feature_names = feature_names,
186195
bg_w = bg_w,
196+
bg_n = bg_n,
187197
parallel = parallel,
188198
parallel_args = parallel_args,
189199
verbose = verbose,

0 commit comments

Comments
 (0)