1
1
# Kernel SHAP algorithm for a single row x
2
2
# If exact, a single call to predict() is necessary.
3
3
# If sampling is involved, we need at least two additional calls to predict().
4
- permshap_one <- function (x , v1 , object , pred_fun , feature_names , bg_w , exact , deg ,
5
- paired , m , tol , max_iter , v0 , precalc , ... ) {
4
+ kernelshap_one <- function (x , v1 , object , pred_fun , feature_names , bg_w , exact , deg ,
5
+ paired , m , tol , max_iter , v0 , precalc , ... ) {
6
6
p <- length(feature_names )
7
7
8
8
# Calculate A_exact and b_exact
@@ -12,28 +12,28 @@ permshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, de
12
12
Z <- precalc [[" Z" ]] # (m_ex x p)
13
13
m_exact <- nrow(Z )
14
14
v0_m_exact <- v0 [rep(1L , m_exact ), , drop = FALSE ] # (m_ex x K)
15
-
15
+
16
16
# Most expensive part
17
17
vz <- get_vz( # (m_ex x K)
18
18
X = x [rep(1L , times = nrow(bg_X_exact )), , drop = FALSE ], # (m_ex*n_bg x p)
19
19
bg = bg_X_exact , # (m_ex*n_bg x p)
20
20
Z = Z , # (m_ex x p)
21
- object = object ,
21
+ object = object ,
22
22
pred_fun = pred_fun ,
23
23
feature_names = feature_names ,
24
- w = bg_w ,
24
+ w = bg_w ,
25
25
...
26
26
)
27
27
# Note: w is correctly replicated along columns of (vz - v0_m_exact)
28
28
b_exact <- crossprod(Z , precalc [[" w" ]] * (vz - v0_m_exact )) # (p x K)
29
-
29
+
30
30
# Some of the hybrid cases are exact as well
31
31
if (exact || trunc(p / 2 ) == deg ) {
32
32
beta <- solver(A_exact , b_exact , constraint = v1 - v0 ) # (p x K)
33
- return (list (beta = beta , sigma = 0 * beta , n_iter = 1L , converged = TRUE ))
33
+ return (list (beta = beta , sigma = 0 * beta , n_iter = 1L , converged = TRUE ))
34
34
}
35
- }
36
-
35
+ }
36
+
37
37
# Iterative sampling part, always using A_exact and b_exact to fill up the weights
38
38
bg_X_m <- precalc [[" bg_X_m" ]] # (m*n_bg x p)
39
39
X <- x [rep(1L , times = nrow(bg_X_m )), , drop = FALSE ] # (m*n_bg x p)
@@ -48,32 +48,32 @@ permshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, de
48
48
A_exact <- A_sum
49
49
b_exact <- b_sum
50
50
}
51
-
51
+
52
52
while (! isTRUE(converged ) && n_iter < max_iter ) {
53
53
n_iter <- n_iter + 1L
54
54
input <- input_sampling(p = p , m = m , deg = deg , paired = paired )
55
55
Z <- input [[" Z" ]]
56
-
56
+
57
57
# Expensive # (m x K)
58
58
vz <- get_vz(
59
- X = X ,
60
- bg = bg_X_m ,
61
- Z = Z ,
62
- object = object ,
63
- pred_fun = pred_fun ,
64
- feature_names = feature_names ,
65
- w = bg_w ,
59
+ X = X ,
60
+ bg = bg_X_m ,
61
+ Z = Z ,
62
+ object = object ,
63
+ pred_fun = pred_fun ,
64
+ feature_names = feature_names ,
65
+ w = bg_w ,
66
66
...
67
67
)
68
-
68
+
69
69
# The sum of weights of A_exact and input[["A"]] is 1, same for b
70
70
A_temp <- A_exact + input [[" A" ]] # (p x p)
71
71
b_temp <- b_exact + crossprod(Z , input [[" w" ]] * (vz - v0_m )) # (p x K)
72
72
A_sum <- A_sum + A_temp # (p x p)
73
73
b_sum <- b_sum + b_temp # (p x K)
74
-
75
- # Least-squares with constraint that beta_1 + ... + beta_p = v_1 - v_0.
76
- # The additional constraint beta_0 = v_0 is dealt via offset
74
+
75
+ # Least-squares with constraint that beta_1 + ... + beta_p = v_1 - v_0.
76
+ # The additional constraint beta_0 = v_0 is dealt via offset
77
77
est_m [[n_iter ]] <- solver(A_temp , b_temp , constraint = v1 - v0 ) # (p x K)
78
78
79
79
# Covariance calculation would fail in the first iteration
@@ -116,7 +116,7 @@ ginv <- function (X, tol = sqrt(.Machine$double.eps)) {
116
116
} else if (! any(Positive )) {
117
117
array (0 , dim(X )[2L : 1L ])
118
118
} else {
119
- Xsvd $ v [, Positive , drop = FALSE ] %*%
119
+ Xsvd $ v [, Positive , drop = FALSE ] %*%
120
120
((1 / Xsvd $ d [Positive ]) * t(Xsvd $ u [, Positive , drop = FALSE ]))
121
121
}
122
122
}
@@ -126,11 +126,11 @@ get_vz <- function(X, bg, Z, object, pred_fun, feature_names, w, ...) {
126
126
m <- nrow(Z )
127
127
not_Z <- ! Z
128
128
n_bg <- nrow(bg ) / m # because bg was replicated m times
129
-
129
+
130
130
# Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p)
131
131
g <- rep(seq_len(m ), each = n_bg )
132
132
not_Z <- not_Z [g , , drop = FALSE ]
133
-
133
+
134
134
if (is.matrix(X )) {
135
135
# Remember that columns of X and bg are perfectly aligned in this case
136
136
X [not_Z ] <- bg [not_Z ]
@@ -143,7 +143,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, feature_names, w, ...) {
143
143
}
144
144
}
145
145
preds <- check_pred(pred_fun(object , X , ... ), n = nrow(X ))
146
-
146
+
147
147
# Aggregate
148
148
if (is.null(w )) {
149
149
return (rowsum(preds , group = g , reorder = FALSE ) / n_bg )
@@ -162,15 +162,15 @@ weighted_colMeans <- function(x, w = NULL, ...) {
162
162
if (nrow(x ) != length(w )) {
163
163
stop(" Weights w not compatible with matrix x" )
164
164
}
165
- out <- colSums(x * w , ... ) / sum(w )
165
+ out <- colSums(x * w , ... ) / sum(w )
166
166
}
167
167
matrix (out , nrow = 1L )
168
168
}
169
169
170
170
# Binds list of matrices along new first axis
171
171
abind1 <- function (a ) {
172
172
out <- array (
173
- dim = c(length(a ), dim(a [[1L ]])),
173
+ dim = c(length(a ), dim(a [[1L ]])),
174
174
dimnames = c(list (NULL ), dimnames(a [[1L ]]))
175
175
)
176
176
for (i in seq_along(a )) {
@@ -196,9 +196,9 @@ reorganize_list <- function(alist, nms) {
196
196
# Checks and reshapes predictions to (n x K) matrix
197
197
check_pred <- function (x , n ) {
198
198
if (
199
- ! is.vector(x ) &&
200
- ! is.matrix(x ) &&
201
- ! is.data.frame(x ) &&
199
+ ! is.vector(x ) &&
200
+ ! is.matrix(x ) &&
201
+ ! is.data.frame(x ) &&
202
202
! (is.array(x ) && length(dim(x )) < = 2L )
203
203
) {
204
204
stop(" Predictions must be a vector, matrix, data.frame, or <=2D array" )
@@ -235,7 +235,7 @@ summarize_strategy <- function(p, exact, deg) {
235
235
}
236
236
if (deg == 0L ) {
237
237
return (" Kernel SHAP values by iterative sampling" )
238
- }
238
+ }
239
239
paste(" Kernel SHAP values by the hybrid strategy of degree" , deg )
240
240
}
241
241
0 commit comments