
Commit 376b885

Revert "Rename things"
1 parent ac6ad11 · commit 376b885

15 files changed: 366 additions, 230 deletions

CRAN-SUBMISSION

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+Version: 0.3.7
+Date: 2023-05-17 06:52:31 UTC
+SHA: b6e4ce87f93a54e5c451cd06315ab810bb29eb8a

DESCRIPTION

Lines changed: 10 additions & 2 deletions
@@ -2,9 +2,17 @@ Package: kernelshap
 Title: Kernel SHAP
 Version: 0.3.8
 Authors@R: c(
-    person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"))
+    person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
+    person("David", "Watson", , "[email protected]", role = "aut"),
+    person("Przemyslaw", "Biecek", , "[email protected]", role = "ctb",
+           comment = c(ORCID = "0000-0001-8423-1823"))
   )
-Description: Implementation of ... The package plays well together
+Description: Efficient implementation of Kernel SHAP, see Lundberg and Lee
+    (2017) <https://dl.acm.org/doi/10.5555/3295222.3295230>, and Covert
+    and Lee (2021) <http://proceedings.mlr.press/v130/covert21a>. For
+    models with up to eight features, the results are exact regarding the
+    selected background data. Otherwise, an almost exact hybrid algorithm
+    involving iterative sampling is used. The package plays well together
     with meta-learning packages like 'tidymodels', 'caret' or 'mlr3'.
     Visualizations can be done using the R package 'shapviz'.
 License: GPL (>= 2)
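
Editor's note: the restored Description corresponds to a workflow like the minimal sketch below. It mirrors the roxygen examples reverted in R/methods.R further down and is illustrative only; with just four features, the result is exact with respect to the chosen background data, as the Description states.

library(kernelshap)

fit <- stats::lm(Sepal.Length ~ ., data = iris)
# Four features only, so Kernel SHAP is exact regarding the background data
s <- kernelshap(fit, iris[1:5, -1], bg_X = iris[-1])
s           # print method, see R/methods.R below
summary(s)  # summary method, see R/methods.R below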

R/exact.R

Lines changed: 51 additions & 5 deletions
@@ -9,7 +9,19 @@ input_exact <- function(p) {
   Z <- exact_Z(p)
   # Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
   w <- kernel_weights(p) / choose(p, 1:(p - 1L))
-  list(Z = Z, w = w[rowSums(Z)])
+  list(Z = Z, w = w[rowSums(Z)], A = exact_A(p))
+}
+
+# Calculates exact A. Notice the difference to the off-diagonals in the Supplement of
+# Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
+# see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
+exact_A <- function(p) {
+  S <- 1:(p - 1L)
+  c_pr <- S * (S - 1) / p / (p - 1)
+  off_diag <- sum(kernel_weights(p) * c_pr)
+  A <- matrix(off_diag, nrow = p, ncol = p)
+  diag(A) <- 0.5
+  A
 }
 
 # Creates (2^p-2) x p matrix with all on-off vectors z of length p
@@ -53,10 +65,10 @@ input_partly_exact <- function(p, deg) {
   if (p < 2L * deg) {
     stop("p must be >=2*deg")
   }
-
+
   kw <- kernel_weights(p)
   Z <- w <- vector("list", deg)
-
+
   for (k in seq_len(deg)) {
     Z[[k]] <- partly_exact_Z(p, k = k)
     n <- nrow(Z[[k]])
@@ -65,6 +77,40 @@ input_partly_exact <- function(p, deg) {
   }
   w <- unlist(w, recursive = FALSE, use.names = FALSE)
   Z <- do.call(rbind, Z)
-
-  list(Z = Z, w = w)
+
+  list(Z = Z, w = w, A = crossprod(Z, w * Z))
 }
+
+# Case p = 1 returns exact Shapley values
+case_p1 <- function(n, nms, v0, v1, X, verbose) {
+  txt <- "Exact Shapley values (p = 1)"
+  if (verbose) {
+    message(txt)
+  }
+  S <- v1 - v0[rep(1L, n), , drop = FALSE]
+  SE <- matrix(numeric(n), dimnames = list(NULL, nms))
+  if (ncol(v1) > 1L) {
+    SE <- replicate(ncol(v1), SE, simplify = FALSE)
+    S <- lapply(
+      asplit(S, MARGIN = 2L), function(M) as.matrix(M, dimnames = list(NULL, nms))
+    )
+  } else {
+    colnames(S) <- nms
+  }
+  out <- list(
+    S = S,
+    X = X,
+    baseline = as.vector(v0),
+    SE = SE,
+    n_iter = integer(n),
+    converged = rep(TRUE, n),
+    m = 0L,
+    m_exact = 0L,
+    prop_exact = 1,
+    exact = TRUE,
+    txt = txt,
+    predictions = v1
+  )
+  class(out) <- "kernelshap"
+  out
+}
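
Editor's note: the new exact_A() builds the weighted second-moment matrix A of the on-off vectors z, which the constrained least-squares step in kernelshap_one() (see R/utils.R below) uses together with b_exact. The standalone sketch below reproduces it outside the package; kernel_weights() is not part of this commit, so the normalized Kernel SHAP size weights used here are an assumption for illustration only.

# Standalone sketch of exact_A(); kernel_weights() is internal to the package and not
# shown in this diff, so the normalized size weights below are an assumption.
kernel_weights_sketch <- function(p) {
  S <- 1:(p - 1L)
  w <- (p - 1) / (S * (p - S))  # classical Kernel SHAP weight per coalition size s
  w / sum(w)
}

exact_A_sketch <- function(p) {
  S <- 1:(p - 1L)
  c_pr <- S * (S - 1) / p / (p - 1)  # P(z_i = z_j = 1 | sum(z) = s) for i != j
  off_diag <- sum(kernel_weights_sketch(p) * c_pr)
  A <- matrix(off_diag, nrow = p, ncol = p)
  diag(A) <- 0.5  # by symmetry of the weights, each feature is "on" half of the time
  A
}

exact_A_sketch(5)  # 0.5 on the diagonal, a constant value below 0.5 off the diagonal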

R/kernelshap.R

Lines changed: 136 additions & 136 deletions
Large diffs are not rendered by default.

R/methods.R

Lines changed: 22 additions & 22 deletions
@@ -1,39 +1,39 @@
 #' Print Method
+#'
+#' Prints the first two rows of the matrix (or matrices) of SHAP values.
 #'
-#' Prints the first two rows of the matrix (or matrices) of SHAP values.
-#'
-#' @param x An object of class "permshap".
+#' @param x An object of class "kernelshap".
 #' @param n Maximum number of rows of SHAP values to print.
 #' @param ... Further arguments passed from other methods.
 #' @returns Invisibly, the input is returned.
 #' @export
 #' @examples
 #' fit <- stats::lm(Sepal.Length ~ ., data = iris)
-#' s <- permshap(fit, iris[1:3, -1], bg_X = iris[-1])
+#' s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[-1])
 #' s
-#' @seealso [permshap()]
-print.permshap <- function(x, n = 2L, ...) {
+#' @seealso [kernelshap()]
+print.kernelshap <- function(x, n = 2L, ...) {
   cat("SHAP values of first", n, "observations:\n")
   print(head_list(getElement(x, "S"), n = n))
   invisible(x)
 }
 
 #' Summary Method
 #'
-#' @param object An object of class "permshap".
-#' @param compact Set to `TRUE` to hide printing the top n SHAP values,
-#'   standard errors and feature values.
-#' @param n Maximum number of rows of SHAP values, standard errors and feature values
+#' @param object An object of class "kernelshap".
+#' @param compact Set to `TRUE` to hide printing the top n SHAP values,
+#'   standard errors and feature values.
+#' @param n Maximum number of rows of SHAP values, standard errors and feature values
 #'   to print.
 #' @param ... Further arguments passed from other methods.
 #' @returns Invisibly, the input is returned.
 #' @export
 #' @examples
 #' fit <- stats::lm(Sepal.Length ~ ., data = iris)
-#' s <- permshap(fit, iris[1:3, -1], bg_X = iris[-1])
+#' s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[-1])
 #' summary(s)
-#' @seealso [permshap()]
-summary.permshap <- function(object, compact = FALSE, n = 2L, ...) {
+#' @seealso [kernelshap()]
+summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) {
   cat(getElement(object, "txt"))
 
   S <- getElement(object, "S")
@@ -68,19 +68,19 @@ summary.permshap <- function(object, compact = FALSE, n = 2L, ...) {
   invisible(object)
 }
 
-#' Check for permshap
+#' Check for kernelshap
 #'
-#' Is object of class "permshap"?
+#' Is object of class "kernelshap"?
 #'
 #' @param object An R object.
-#' @returns `TRUE` if `object` is of class "permshap", and `FALSE` otherwise.
+#' @returns `TRUE` if `object` is of class "kernelshap", and `FALSE` otherwise.
 #' @export
 #' @examples
 #' fit <- stats::lm(Sepal.Length ~ ., data = iris)
-#' s <- permshap(fit, iris[1:2, -1], bg_X = iris[-1])
-#' is.permshap(s)
-#' is.permshap("a")
-#' @seealso [permshap()]
-is.permshap <- function(object){
-  inherits(object, "permshap")
+#' s <- kernelshap(fit, iris[1:2, -1], bg_X = iris[-1])
+#' is.kernelshap(s)
+#' is.kernelshap("a")
+#' @seealso [kernelshap()]
+is.kernelshap <- function(object){
+  inherits(object, "kernelshap")
 }
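
Editor's note: taken together, the reverted roxygen examples amount to this small end-to-end sketch (illustrative only, not part of the diff):

fit <- stats::lm(Sepal.Length ~ ., data = iris)
s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[-1])

s                           # print.kernelshap(): first rows of the SHAP values
summary(s)                  # summary.kernelshap(): SHAP values, standard errors, feature values
summary(s, compact = TRUE)  # hide the top n SHAP values, standard errors and feature values
is.kernelshap(s)            # TRUE
is.kernelshap("a")          # FALSE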

R/utils.R

Lines changed: 32 additions & 32 deletions
@@ -1,8 +1,8 @@
 # Kernel SHAP algorithm for a single row x
 # If exact, a single call to predict() is necessary.
 # If sampling is involved, we need at least two additional calls to predict().
-permshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, deg,
-                         paired, m, tol, max_iter, v0, precalc, ...) {
+kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, deg,
+                           paired, m, tol, max_iter, v0, precalc, ...) {
   p <- length(feature_names)
 
   # Calculate A_exact and b_exact
@@ -12,28 +12,28 @@ permshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, de
   Z <- precalc[["Z"]]  # (m_ex x p)
   m_exact <- nrow(Z)
   v0_m_exact <- v0[rep(1L, m_exact), , drop = FALSE]  # (m_ex x K)
-
+
   # Most expensive part
   vz <- get_vz(  # (m_ex x K)
     X = x[rep(1L, times = nrow(bg_X_exact)), , drop = FALSE],  # (m_ex*n_bg x p)
     bg = bg_X_exact,  # (m_ex*n_bg x p)
     Z = Z,  # (m_ex x p)
-    object = object,
+    object = object,
     pred_fun = pred_fun,
     feature_names = feature_names,
-    w = bg_w,
+    w = bg_w,
     ...
   )
   # Note: w is correctly replicated along columns of (vz - v0_m_exact)
   b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact))  # (p x K)
-
+
   # Some of the hybrid cases are exact as well
   if (exact || trunc(p / 2) == deg) {
     beta <- solver(A_exact, b_exact, constraint = v1 - v0)  # (p x K)
-    return(list(beta = beta, sigma = 0 * beta, n_iter = 1L, converged = TRUE))
+    return(list(beta = beta, sigma = 0 * beta, n_iter = 1L, converged = TRUE))
   }
-  }
-
+  }
+
   # Iterative sampling part, always using A_exact and b_exact to fill up the weights
   bg_X_m <- precalc[["bg_X_m"]]  # (m*n_bg x p)
   X <- x[rep(1L, times = nrow(bg_X_m)), , drop = FALSE]  # (m*n_bg x p)
@@ -48,32 +48,32 @@ permshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, de
     A_exact <- A_sum
     b_exact <- b_sum
   }
-
+
   while(!isTRUE(converged) && n_iter < max_iter) {
     n_iter <- n_iter + 1L
     input <- input_sampling(p = p, m = m, deg = deg, paired = paired)
     Z <- input[["Z"]]
-
+
     # Expensive  # (m x K)
     vz <- get_vz(
-      X = X,
-      bg = bg_X_m,
-      Z = Z,
-      object = object,
-      pred_fun = pred_fun,
-      feature_names = feature_names,
-      w = bg_w,
+      X = X,
+      bg = bg_X_m,
+      Z = Z,
+      object = object,
+      pred_fun = pred_fun,
+      feature_names = feature_names,
+      w = bg_w,
       ...
     )
-
+
     # The sum of weights of A_exact and input[["A"]] is 1, same for b
     A_temp <- A_exact + input[["A"]]  # (p x p)
     b_temp <- b_exact + crossprod(Z, input[["w"]] * (vz - v0_m))  # (p x K)
     A_sum <- A_sum + A_temp  # (p x p)
     b_sum <- b_sum + b_temp  # (p x K)
-
-    # Least-squares with constraint that beta_1 + ... + beta_p = v_1 - v_0.
-    # The additional constraint beta_0 = v_0 is dealt via offset
+
+    # Least-squares with constraint that beta_1 + ... + beta_p = v_1 - v_0.
+    # The additional constraint beta_0 = v_0 is dealt via offset
     est_m[[n_iter]] <- solver(A_temp, b_temp, constraint = v1 - v0)  # (p x K)
 
     # Covariance calculation would fail in the first iteration
@@ -116,7 +116,7 @@ ginv <- function (X, tol = sqrt(.Machine$double.eps)) {
   } else if (!any(Positive)) {
     array(0, dim(X)[2L:1L])
   } else {
-    Xsvd$v[, Positive, drop = FALSE] %*%
+    Xsvd$v[, Positive, drop = FALSE] %*%
       ((1 / Xsvd$d[Positive]) * t(Xsvd$u[, Positive, drop = FALSE]))
   }
 }
@@ -126,11 +126,11 @@ get_vz <- function(X, bg, Z, object, pred_fun, feature_names, w, ...) {
   m <- nrow(Z)
   not_Z <- !Z
   n_bg <- nrow(bg) / m  # because bg was replicated m times
-
+
   # Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p)
   g <- rep(seq_len(m), each = n_bg)
   not_Z <- not_Z[g, , drop = FALSE]
-
+
   if (is.matrix(X)) {
     # Remember that columns of X and bg are perfectly aligned in this case
     X[not_Z] <- bg[not_Z]
@@ -143,7 +143,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, feature_names, w, ...) {
     }
   }
   preds <- check_pred(pred_fun(object, X, ...), n = nrow(X))
-
+
   # Aggregate
   if (is.null(w)) {
     return(rowsum(preds, group = g, reorder = FALSE) / n_bg)
@@ -162,15 +162,15 @@ weighted_colMeans <- function(x, w = NULL, ...) {
     if (nrow(x) != length(w)) {
       stop("Weights w not compatible with matrix x")
     }
-    out <- colSums(x * w, ...) / sum(w)
+    out <- colSums(x * w, ...) / sum(w)
   }
   matrix(out, nrow = 1L)
 }
 
 # Binds list of matrices along new first axis
 abind1 <- function(a) {
   out <- array(
-    dim = c(length(a), dim(a[[1L]])),
+    dim = c(length(a), dim(a[[1L]])),
     dimnames = c(list(NULL), dimnames(a[[1L]]))
   )
   for (i in seq_along(a)) {
@@ -196,9 +196,9 @@ reorganize_list <- function(alist, nms) {
 # Checks and reshapes predictions to (n x K) matrix
 check_pred <- function(x, n) {
   if (
-    !is.vector(x) &&
-    !is.matrix(x) &&
-    !is.data.frame(x) &&
+    !is.vector(x) &&
+    !is.matrix(x) &&
+    !is.data.frame(x) &&
     !(is.array(x) && length(dim(x)) <= 2L)
   ) {
     stop("Predictions must be a vector, matrix, data.frame, or <=2D array")
@@ -235,7 +235,7 @@ summarize_strategy <- function(p, exact, deg) {
   }
   if (deg == 0L) {
     return("Kernel SHAP values by iterative sampling")
-  }
+  }
   paste("Kernel SHAP values by the hybrid strategy of degree", deg)
 }