Skip to content

Commit 670e75d

Browse files
authored
Merge pull request #268 from tidymodels/using-parsnip-models
adjust for cases where parsnip predictions are from a parsnip model
2 parents ff954df + 952575a commit 670e75d

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# parsnip (development version)
22

3+
# parsnip 0.0.5.9000
4+
5+
6+
37
# parsnip 0.0.5
48

59
## Fixes

R/predict.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
144144
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
145145
rlang::abort(glue::glue("I don't know about type = '{type}'"))
146146
)
147-
148147
if (!inherits(res, "tbl_spark")) {
149148
res <- switch(
150149
type,
@@ -186,9 +185,11 @@ format_num <- function(x) {
186185
if (inherits(x, "tbl_spark"))
187186
return(x)
188187

189-
if (isTRUE(ncol(x) > 1)) {
188+
if (isTRUE(ncol(x) > 1) | is.data.frame(x)) {
190189
x <- as_tibble(x, .name_repair = "minimal")
191-
names(x) <- paste0(".pred_", names(x))
190+
if (!any(grepl("^\\.pred", names(x)))) {
191+
names(x) <- paste0(".pred_", names(x))
192+
}
192193
} else {
193194
x <- tibble(.pred = x)
194195
}
@@ -204,7 +205,9 @@ format_class <- function(x) {
204205
}
205206

206207
format_classprobs <- function(x) {
207-
names(x) <- paste0(".pred_", names(x))
208+
if (!any(grepl("^\\.pred_", names(x)))) {
209+
names(x) <- paste0(".pred_", names(x))
210+
}
208211
x <- as_tibble(x)
209212
x
210213
}

R/predict_class.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,14 @@ predict_class.model_fit <- function(object, new_data, ...) {
4040
if (is.vector(res) || is.factor(res)) {
4141
res <- factor(as.character(res), levels = object$lvl)
4242
} else {
43-
if (!inherits(res, "tbl_spark"))
44-
res$values <- factor(as.character(res$values), levels = object$lvl)
43+
if (!inherits(res, "tbl_spark")) {
44+
# Now case where a parsnip model generated `res`
45+
if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) {
46+
res <- res[[1]]
47+
} else {
48+
res$values <- factor(as.character(res$values), levels = object$lvl)
49+
}
50+
}
4551
}
4652

4753
res

0 commit comments

Comments
 (0)