@@ -5,6 +5,7 @@
#' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding
#' the selected background data. For larger \eqn{p}, an almost exact
#' hybrid algorithm involving iterative sampling is used, see Details.
+ #' For up to eight features, however, we recommend using [permshap()].
#'
#' Pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:
#'
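To make that recommendation concrete, a minimal sketch using the iris setup from the examples further down (assuming `permshap()` accepts the same arguments as `kernelshap()`):

```r
# Exact permutation SHAP for p <= 8 features (here: 4), as recommended above.
# A sketch; permshap() is assumed to share kernelshap()'s interface.
fit <- lm(Sepal.Length ~ ., data = iris)
s <- permshap(fit, iris[-1], bg_X = iris)
```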
@@ -63,10 +64,11 @@
#' The columns should only represent model features, not the response
#' (but see `feature_names` on how to overrule this).
#' @param bg_X Background data used to integrate out "switched off" features,
- #' often a subset of the training data (typically 50 to 500 rows)
- #' It should contain the same columns as `X`.
+ #' often a subset of the training data (typically 50 to 500 rows).
#' In cases with a natural "off" value (like MNIST digits),
#' this can also be a single row with all values set to the off value.
+ #' If no `bg_X` is passed (the default) and if `X` is sufficiently large,
+ #' a random sample of `bg_n` rows from `X` serves as background data.
#' @param pred_fun Prediction function of the form `function(object, X, ...)`,
#' providing \eqn{K \ge 1} predictions per row. Its first argument
#' represents the model `object`, its second argument a data structure like `X`.
@@ -76,6 +78,8 @@
#' SHAP values. By default, this equals `colnames(X)`. Not supported if `X`
#' is a matrix.
#' @param bg_w Optional vector of case weights for each row of `bg_X`.
+ #' If `bg_X = NULL`, it must contain one weight per row of `X`. Set to `NULL` for no weights.
+ #' @param bg_n If `bg_X = NULL`: size of the background data to be sampled from `X`.
#' @param exact If `TRUE`, the algorithm will produce exact Kernel SHAP values
#' with respect to the background data. In this case, the arguments `hybrid_degree`,
#' `m`, `paired_sampling`, `tol`, and `max_iter` are ignored.
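A small usage sketch of the `bg_X = NULL` path documented above (the weight vector is made up for illustration):

```r
# Sketch: background data sampled internally from X (bg_X = NULL is the default).
fit <- lm(Sepal.Length ~ ., data = iris)
X <- iris[-1L]

s1 <- kernelshap(fit, X, bg_n = 100L)  # samples 100 of the 150 rows of X

# With case weights: since bg_X is NULL, bg_w is aligned with the rows of X
w <- rep(1:3, each = 50L)  # hypothetical weights, one per row of X
s2 <- kernelshap(fit, X, bg_w = w, bg_n = 100L)
```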
@@ -130,6 +134,8 @@
#' - `X`: Same as input argument `X`.
#' - `baseline`: Vector of length K representing the average prediction on the
#'   background data.
+ #' - `bg_X`: The background data.
+ #' - `bg_w`: The background case weights.
#' - `SE`: Standard errors corresponding to `S` (and organized like `S`).
#' - `n_iter`: Integer vector of length n providing the number of iterations
#'   per row of `X`.
@@ -155,28 +161,25 @@
#' @examples
#' # MODEL ONE: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
- #'
+ #'
#' # Select rows to explain (only feature columns)
- #' X_explain <- iris[1:2, -1]
- #'
- #' # Select small background dataset (could use all rows here because iris is small)
- #' set.seed(1)
- #' bg_X <- iris[sample(nrow(iris), 100), ]
- #'
+ #' X_explain <- iris[-1]
+ #'
#' # Calculate SHAP values
- #' s <- kernelshap(fit, X_explain, bg_X = bg_X)
+ #' s <- kernelshap(fit, X_explain)
#' s
- #'
+ #'
#' # MODEL TWO: Multi-response linear regression
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
- #' s <- kernelshap(fit, iris[1:4, 3:5], bg_X = bg_X)
- #' summary(s)
- #'
- #' # Non-feature columns can be dropped via 'feature_names'
+ #' s <- kernelshap(fit, iris[3:5])
+ #' s
+ #'
+ #' # Note 1: Feature columns can also be selected via 'feature_names'
+ #' # Note 2: Especially when X is small, pass sufficiently large background data bg_X
#' s <- kernelshap(
- #'   fit,
+ #'   fit,
#'   iris[1:4, ],
- #'   bg_X = bg_X,
+ #'   bg_X = iris,
#'   feature_names = c("Petal.Length", "Petal.Width", "Species")
#' )
#' s
@@ -189,10 +192,11 @@ kernelshap <- function(object, ...){
kernelshap.default <- function(
  object,
  X,
-   bg_X,
+   bg_X = NULL,
  pred_fun = stats::predict,
  feature_names = colnames(X),
  bg_w = NULL,
+   bg_n = 200L,
  exact = length(feature_names) <= 8L,
  hybrid_degree = 1L + length(feature_names) %in% 4:16,
  paired_sampling = TRUE,
@@ -204,24 +208,24 @@ kernelshap.default <- function(
  verbose = TRUE,
  ...
) {
-   basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
  p <- length(feature_names)
+   basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
  stopifnot(
    exact %in% c(TRUE, FALSE),
    p == 1L || exact || hybrid_degree %in% 0:(p / 2),
    paired_sampling %in% c(TRUE, FALSE),
    "m must be even" = trunc(m / 2) == m / 2
  )
-   n <- nrow(X)
+   prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
+   bg_X <- prep_bg$bg_X
+   bg_w <- prep_bg$bg_w
  bg_n <- nrow(bg_X)
-   if (!is.null(bg_w)) {
-     bg_w <- prep_w(bg_w, bg_n = bg_n)
-   }
+   n <- nrow(X)

  # Calculate v1 and v0
-   v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K
-   bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
+   bg_preds <- align_pred(pred_fun(object, bg_X, ...))
  v0 <- wcolMeans(bg_preds, bg_w)  # Average pred of bg data: 1 x K
+   v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K

  # For p = 1, exact Shapley values are returned
  if (p == 1L) {
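The new helper `prepare_bg()` is not shown in this diff. Below is a minimal sketch of its assumed behavior, reconstructed from the documented semantics of `bg_X`, `bg_n`, and `bg_w`; the package's actual helper may differ in its messages and validation:

```r
# Assumed behavior of prepare_bg() (hypothetical reconstruction, not the
# package source): resolve background data and align case weights with it.
prepare_bg <- function(X, bg_X, bg_n, bg_w, verbose) {
  if (is.null(bg_X)) {
    n <- nrow(X)
    if (n <= bg_n) {
      ix <- seq_len(n)  # X is small enough: use all rows as background data
    } else {
      ix <- sample(n, bg_n)  # otherwise, sample bg_n rows from X
      if (verbose) {
        message("Sampling ", bg_n, " rows of X as background data.")
      }
    }
    bg_X <- X[ix, , drop = FALSE]
    if (!is.null(bg_w)) {
      stopifnot("bg_w must have one weight per row of X" = length(bg_w) == n)
      bg_w <- bg_w[ix]  # keep weights aligned with the sampled rows
    }
  } else if (!is.null(bg_w)) {
    stopifnot("bg_w must have one weight per row of bg_X" = length(bg_w) == nrow(bg_X))
  }
  list(bg_X = bg_X, bg_w = bg_w)
}
```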
@@ -231,18 +235,25 @@ kernelshap.default <- function(
    return(out)
  }

+   txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
+   if (verbose) {
+     message(txt)
+   }
+
  # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
  # In what follows, predictions will never be applied directly to bg_X anymore
  if (!identical(colnames(bg_X), feature_names)) {
    bg_X <- bg_X[, feature_names, drop = FALSE]
  }

-   # Precalculations for the real Kernel SHAP
+   # Precalculations that are identical for each row to be explained
  if (exact || hybrid_degree >= 1L) {
    if (exact) {
      precalc <- input_exact(p, feature_names = feature_names)
    } else {
-       precalc <- input_partly_exact(p, deg = hybrid_degree, feature_names = feature_names)
+       precalc <- input_partly_exact(
+         p, deg = hybrid_degree, feature_names = feature_names
+       )
    }
    m_exact <- nrow(precalc[["Z"]])
    prop_exact <- sum(precalc[["w"]])
@@ -256,11 +267,6 @@ kernelshap.default <- function(
    precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
  }

-   # Some infos
-   txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
-   if (verbose) {
-     message(txt)
-   }
  if (max(m, m_exact) * bg_n > 2e5) {
    warning_burden(max(m, m_exact), bg_n = bg_n)
  }
@@ -319,11 +325,18 @@ kernelshap.default <- function(
  if (verbose && !all(converged)) {
    warning("\nNon-convergence for ", sum(!converged), " rows.")
  }
+
+   if (verbose) {
+     cat("\n")
+   }
+
  out <- list(
-     S = reorganize_list(lapply(res, `[[`, "beta")),
-     X = X,
-     baseline = as.vector(v0),
-     SE = reorganize_list(lapply(res, `[[`, "sigma")),
+     S = reorganize_list(lapply(res, `[[`, "beta")),
+     X = X,
+     baseline = as.vector(v0),
+     bg_X = bg_X,
+     bg_w = bg_w,
+     SE = reorganize_list(lapply(res, `[[`, "sigma")),
    n_iter = vapply(res, `[[`, "n_iter", FUN.VALUE = integer(1L)),
    converged = converged,
    m = m,
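With the two new list elements, the background data travels with the result. A small downstream sketch (hypothetical usage, mirroring the linear-model example above):

```r
fit <- lm(Sepal.Length ~ ., data = iris)
s <- kernelshap(fit, iris[-1L])
head(s$bg_X)  # background rows actually used (here: sampled from iris[-1L])
s$bg_w        # NULL, since no case weights were supplied

# baseline is the (weighted) average prediction on the stored background data:
mean(predict(fit, newdata = s$bg_X))  # ~ s$baseline
```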
@@ -343,10 +356,11 @@ kernelshap.ranger <- function(
kernelshap.ranger <- function(
  object,
  X,
-   bg_X,
+   bg_X = NULL,
  pred_fun = NULL,
  feature_names = colnames(X),
  bg_w = NULL,
+   bg_n = 200L,
  exact = length(feature_names) <= 8L,
  hybrid_degree = 1L + length(feature_names) %in% 4:16,
  paired_sampling = TRUE,
@@ -371,6 +385,7 @@ kernelshap.ranger <- function(
    pred_fun = pred_fun,
    feature_names = feature_names,
    bg_w = bg_w,
+     bg_n = bg_n,
    exact = exact,
    hybrid_degree = hybrid_degree,
    paired_sampling = paired_sampling,