Skip to content

Commit 27df158

Browse files
function to get prediction columns (#1224)
* function to get prediction columns * forgotten pkgdown entry * also, bump version number * fix for workflows * Apply suggestions from code review Co-authored-by: Emil Hvitfeldt <[email protected]> --------- Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent a212f78 commit 27df158

File tree

7 files changed

+174
-1
lines changed

7 files changed

+174
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.2.1.9003
3+
Version: 1.2.1.9004
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ export(.dat)
185185
export(.extract_surv_status)
186186
export(.extract_surv_time)
187187
export(.facts)
188+
export(.get_prediction_column_names)
188189
export(.lvls)
189190
export(.model_param_name_key)
190191
export(.obs)

R/misc.R

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,3 +575,75 @@ is_cran_check <- function() {
575575
}
576576
# nocov end
577577

578+
# ------------------------------------------------------------------------------
579+
580+
#' Obtain names of prediction columns for a fitted model or workflow
581+
#'
582+
#' [.get_prediction_column_names()] returns a list that has the names of the
583+
#' columns for the primary prediction types for a model.
584+
#' @param x A fitted parsnip model (class `"model_fit"`) or a fitted workflow.
585+
#' @param syms Should the column names be converted to symbols? Defaults to `FALSE`.
586+
#' @return A list with elements `"estimate"` and `"probabilities"`.
587+
#' @examplesIf !parsnip:::is_cran_check()
588+
#' library(dplyr)
589+
#' library(modeldata)
590+
#' data("two_class_dat")
591+
#'
592+
#' levels(two_class_dat$Class)
593+
#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
594+
#'
595+
#' .get_prediction_column_names(lr_fit)
596+
#' .get_prediction_column_names(lr_fit, syms = TRUE)
597+
#' @export
598+
.get_prediction_column_names <- function(x, syms = FALSE) {
599+
if (!inherits(x, c("model_fit", "workflow"))) {
600+
cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or
601+
{.cls workflow}, not {.obj_type_friendly {x}}.")
602+
}
603+
604+
if (inherits(x, "workflow")) {
605+
x <- x %>% extract_fit_parsnip(x)
606+
}
607+
model_spec <- extract_spec_parsnip(x)
608+
model_engine <- model_spec$engine
609+
model_mode <- model_spec$mode
610+
model_type <- class(model_spec)[1]
611+
612+
# appropriate populate the model db
613+
inst_res <- purrr::map(required_pkgs(x), rlang::check_installed)
614+
predict_types <-
615+
get_from_env(paste0(model_type, "_predict")) %>%
616+
dplyr::filter(engine == model_engine & mode == model_mode) %>%
617+
purrr::pluck("type")
618+
619+
if (length(predict_types) == 0) {
620+
cli::cli_abort("Prediction information could not be found for this
621+
{.fn {model_type}} with engine {.val {model_engine}} and mode
622+
{.val {model_mode}}. Does a parsnip extension package need to
623+
be loaded?")
624+
}
625+
626+
res <- list(estimate = character(0), probabilities = character(0))
627+
628+
if (model_mode == "regression") {
629+
res$estimate <- ".pred"
630+
} else if (model_mode == "classification") {
631+
res$estimate <- ".pred_class"
632+
if (any(predict_types == "prob")) {
633+
res$probabilities <- paste0(".pred_", x$lvl)
634+
}
635+
} else if (model_mode == "censored regression") {
636+
res$estimate <- ".pred_time"
637+
if (any(predict_types %in% c("survival"))) {
638+
res$probabilities <- ".pred"
639+
}
640+
} else {
641+
# Should be unreachable
642+
cli::cli_abort("Unsupported model mode {model_mode}.")
643+
}
644+
645+
if (syms) {
646+
res <- purrr::map(res, rlang::syms)
647+
}
648+
res
649+
}

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ reference:
111111
- .extract_surv_status
112112
- .extract_surv_time
113113
- .model_param_name_key
114+
- .get_prediction_column_names

man/dot-get_prediction_column_names.Rd

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/misc.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,19 @@
227227
Error in `check_outcome()`:
228228
! For a censored regression model, the outcome should be a <Surv> object, not an integer vector.
229229

230+
# obtaining prediction columns
231+
232+
Code
233+
.get_prediction_column_names(1)
234+
Condition
235+
Error in `.get_prediction_column_names()`:
236+
! `x` should be an object with class <model_fit> or <workflow>, not a number.
237+
238+
---
239+
240+
Code
241+
.get_prediction_column_names(unk_fit)
242+
Condition
243+
Error in `.get_prediction_column_names()`:
244+
! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded?
245+

tests/testthat/test-misc.R

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,53 @@ test_that('check_outcome works as expected', {
249249
check_outcome(1:2, cens_spec)
250250
)
251251
})
252+
253+
# ------------------------------------------------------------------------------
254+
255+
test_that('obtaining prediction columns', {
256+
skip_if_not_installed("modeldata")
257+
data(two_class_dat, package = "modeldata")
258+
259+
### classification
260+
lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
261+
expect_equal(
262+
.get_prediction_column_names(lr_fit),
263+
list(estimate = ".pred_class",
264+
probabilities = c(".pred_Class1", ".pred_Class2"))
265+
)
266+
expect_equal(
267+
.get_prediction_column_names(lr_fit, syms = TRUE),
268+
list(estimate = list(quote(.pred_class)),
269+
probabilities = list(quote(.pred_Class1), quote(.pred_Class2)))
270+
)
271+
272+
### regression
273+
ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars)
274+
expect_equal(
275+
.get_prediction_column_names(ols_fit),
276+
list(estimate = ".pred",
277+
probabilities = character(0))
278+
)
279+
expect_equal(
280+
.get_prediction_column_names(ols_fit, syms = TRUE),
281+
list(estimate = list(quote(.pred)),
282+
probabilities = list())
283+
)
284+
285+
### censored regression
286+
# in extratests
287+
288+
### bad input
289+
expect_snapshot(
290+
.get_prediction_column_names(1),
291+
error = TRUE
292+
)
293+
294+
unk_fit <- ols_fit
295+
unk_fit$spec$mode <- "Depeche"
296+
expect_snapshot(
297+
.get_prediction_column_names(unk_fit),
298+
error = TRUE
299+
)
300+
301+
})

0 commit comments

Comments
 (0)