Skip to content

Commit eeaf82f

Browse files
Merge pull request #1168 from tidymodels/fix1166
save x column names from fit_xy()
2 parents d3744d2 + b5fee7b commit eeaf82f

File tree

4 files changed

+24
-1
lines changed

4 files changed

+24
-1
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
* Ensure that `knit_engine_docs()` has the required packages installed (#1156).
1616

17+
* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
1718

1819
# parsnip 1.2.1
1920

R/fit_helpers.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ xy_xy <- function(object,
115115
} else {
116116
y_name <- colnames(env$y)
117117
}
118-
res$preproc <- list(y_var = y_name)
118+
res$preproc <- list(y_var = y_name, x_names = colnames(env$x))
119119
res$elapsed <- list(elapsed = elapsed, print = control$verbosity > 1L)
120120
res
121121
}

R/predict.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,17 @@ prepare_data <- function(object, new_data) {
440440
preproc_names <- names(object$preproc)
441441
translate_from_formula_to_xy <- any(preproc_names == "terms", na.rm = TRUE)
442442
translate_from_xy_to_formula <- any(preproc_names == "x_var", na.rm = TRUE)
443+
# For backwards compatibility, only do this if `y_var` is missing and
444+
# `x_names` is present
445+
translate_from_xy_to_xy <- any(preproc_names == "x_names", na.rm = TRUE) &&
446+
identical(object$preproc$y_var, character(0))
443447

444448
if (translate_from_formula_to_xy) {
445449
new_data <- .convert_form_to_xy_new(object$preproc, new_data)$x
446450
} else if (translate_from_xy_to_formula) {
447451
new_data <- .convert_xy_to_form_new(object$preproc, new_data)
452+
} else if (translate_from_xy_to_xy) {
453+
new_data <- new_data[, object$preproc$x_names]
448454
}
449455

450456
encodings <- get_encoding(class(object$spec)[1])

tests/testthat/test-predict_formats.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,20 @@ test_that('non-factor classification', {
106106
)
107107
})
108108

109+
test_that("predict() works for model fit with fit_xy() (#1166)", {
110+
skip_if_not_installed("xgboost")
109111

112+
spec <- boost_tree() %>%
113+
set_mode("regression") %>%
114+
set_engine("xgboost")
115+
116+
tree_fit <- fit(spec, mpg ~ ., data = mtcars)
117+
118+
exp <- predict(tree_fit, mtcars)
119+
120+
tree_fit <- fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
121+
122+
res <- predict(tree_fit, mtcars)
123+
124+
expect_identical(exp, res)
125+
})

0 commit comments

Comments
 (0)