Skip to content

Commit 7ed0dbb

Browse files
committed
conditional logic for data frames and matrices
1 parent 1d3658f commit 7ed0dbb

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

R/fitter.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# TODO protect engine = "spark" with non-spark data object
88

9+
# TODO formula method and others pass symbols like formula method
10+
911
fit_interface_matrix <- function(x, y, object, control, ...) {
1012
if (object$engine == "spark")
1113
stop("spark objects can only be used with the formula interface to `fit` ",
@@ -238,8 +240,13 @@ recipe_data <- function(recipe, data, control, output = "matrix", combine = FALS
238240
x = juice(recipe, all_predictors(), composition = output),
239241
y = juice(recipe, all_outcomes(), composition = output)
240242
)
241-
if (ncol(out$y) == 1)
242-
out$y <- out$y[[1]]
243+
if (ncol(out$y) == 1) {
244+
if (is.matrix(out$y))
245+
out$y <- out$y[, 1]
246+
else
247+
out$y <- out$y[[1]]
248+
}
249+
243250
}
244251
out
245252
}

0 commit comments

Comments
 (0)