Skip to content

Commit ad10ea4

Browse files
authored
Merge pull request #141 from ModelOriented/fix-ranger-survival
fix problematic argument survival in ranger models
2 parents a1ed340 + 74be45d commit ad10ea4

File tree

4 files changed

+53
-23
lines changed

4 files changed

+53
-23
lines changed

R/kernelshap.R

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,11 @@ kernelshap.ranger <- function(
359359
survival = c("chf", "prob"),
360360
...
361361
) {
362-
survival <- match.arg(survival)
363-
362+
364363
if (is.null(pred_fun)) {
365-
pred_fun <- pred_ranger
364+
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
366365
}
367-
366+
368367
kernelshap.default(
369368
object = object,
370369
X = X,
@@ -381,7 +380,6 @@ kernelshap.ranger <- function(
381380
parallel = parallel,
382381
parallel_args = parallel_args,
383382
verbose = verbose,
384-
survival = survival,
385383
...
386384
)
387385
}

R/permshap.R

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,11 @@ permshap.ranger <- function(
172172
survival = c("chf", "prob"),
173173
...
174174
) {
175-
survival <- match.arg(survival)
176-
175+
177176
if (is.null(pred_fun)) {
178-
pred_fun <- pred_ranger
177+
pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival))
179178
}
180-
179+
181180
permshap.default(
182181
object = object,
183182
X = X,
@@ -188,7 +187,6 @@ permshap.ranger <- function(
188187
parallel = parallel,
189188
parallel_args = parallel_args,
190189
verbose = verbose,
191-
survival = survival,
192190
...
193191
)
194192
}

R/pred_fun.R

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
11
#' Predict Function for Ranger
22
#'
3-
#' Internal function that prepares the predictions of different types of ranger models,
4-
#' including survival models.
3+
#' Returns prediction function for different modes of ranger.
54
#'
65
#' @noRd
76
#' @keywords internal
8-
#' @param model Fitted ranger model.
9-
#' @param newdata Data to predict on.
7+
#' @param treetype The value of `fit$treetype` in a fitted ranger model.
108
#' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time.
11-
#' @param ... Additional arguments passed to ranger's predict function.
129
#'
13-
#' @returns A vector or matrix with predictions.
14-
pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) {
10+
#' @returns A function with signature f(model, newdata, ...).
11+
create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) {
1512
survival <- match.arg(survival)
1613

17-
pred <- stats::predict(model, newdata, ...)
14+
if (treetype != "Survival") {
15+
pred_fun <- function(model, newdata, ...) {
16+
stats::predict(model, newdata, ...)$predictions
17+
}
18+
return(pred_fun)
19+
}
20+
21+
if (survival == "prob") {
22+
survival <- "survival"
23+
}
1824

19-
if (model$treetype == "Survival") {
20-
out <- if (survival == "chf") pred$chf else pred$survival
25+
pred_fun <- function(model, newdata, ...) {
26+
pred <- stats::predict(model, newdata, ...)
27+
out <- pred[[survival]]
2128
colnames(out) <- paste0("t", pred$unique.death.times)
22-
} else {
23-
out <- pred$predictions
29+
return(out)
2430
}
25-
return(out)
31+
return(pred_fun)
2632
}
2733

backlog/test_ranger.R

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
library(ranger)
2+
library(survival)
3+
library(kernelshap)
4+
5+
set.seed(1)
6+
7+
fit <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 20)
8+
fit2 <- ranger(time ~ . - status, data = veteran, num.trees = 20)
9+
fit3 <- ranger(time ~ . - status, data = veteran, quantreg = TRUE, num.trees = 20)
10+
fit4 <- ranger(status ~ . - time, data = veteran, probability = TRUE, num.trees = 20)
11+
12+
xvars <- setdiff(colnames(veteran), c("time", "status"))
13+
14+
kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran)
15+
permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran)
16+
17+
kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob")
18+
permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob")
19+
20+
kernelshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran)
21+
permshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran)
22+
23+
kernelshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles")
24+
permshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles")
25+
26+
kernelshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran)
27+
permshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran)
28+

0 commit comments

Comments
 (0)