Skip to content

Commit 2885946

Browse files
committed
better checks for interface issues (made to me model-specific)
1 parent e3a66b5 commit 2885946

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

R/fit.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ fit.model_spec <-
111111
) {
112112
cl <- match.call(expand.dots = TRUE)
113113
fit_interface <-
114-
check_interface(formula, recipe, x, y, data, cl)
114+
check_interface(formula, recipe, x, y, data, cl, object)
115115
object$engine <- engine
116116
object <- check_engine(object)
117117

@@ -210,7 +210,7 @@ show_call <- function(x)
210210
has_both_or_none <- function(a, b)
211211
(!is.null(a) & is.null(b)) | (is.null(a) & !is.null(b))
212212

213-
check_interface <- function(formula, recipe, x, y, data, cl) {
213+
check_interface <- function(formula, recipe, x, y, data, cl, model) {
214214
inher(formula, "formula", cl)
215215
inher(recipe, "recipe", cl)
216216
inher(x, c("data.frame", "matrix", "tbl_spark"), cl)
@@ -238,6 +238,10 @@ check_interface <- function(formula, recipe, x, y, data, cl) {
238238
stop("Too many specifications of arguments; used either 'x/y', ",
239239
"'formula/data', or 'recipe/data' combinations.", call. = FALSE)
240240

241+
if (inherits(model, "surv_reg") &&
242+
(matrix_interface | df_interface))
243+
stop("Survival models must use the formula or recipe interface.", call. = FALSE)
244+
241245
if (matrix_interface)
242246
return("data.frame")
243247
if (df_interface)

tests/testthat/test_fit_interfaces.R

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,30 @@ library(rlang)
66
rec <- recipe(~ ., data = iris)
77
f <- y ~ x
88

9+
smod <- surv_reg()
10+
rmod <- linear_reg()
11+
912
tester <-
10-
function(object, formula = NULL, recipe = NULL, x = NULL, y = NULL, data = NULL)
11-
parsnip:::check_interface(formula, recipe, x, y, data, match.call(expand.dots = TRUE))
13+
function(object, formula = NULL, recipe = NULL, x = NULL, y = NULL, data = NULL, model)
14+
parsnip:::check_interface(formula, recipe, x, y, data, match.call(expand.dots = TRUE), model)
1215

1316
test_that('good args', {
14-
expect_equal(tester(NULL, formula = f, data = iris), "formula")
15-
expect_equal(tester(NULL, recipe = rec, data = iris), "recipe")
16-
expect_equal(tester(NULL, x = iris, y = iris), "data.frame")
17-
expect_equal(tester(NULL, f, data = iris), "formula")
18-
expect_equal(tester(NULL, formula = f, data = iris, y = iris), "formula")
17+
expect_equal(tester(NULL, formula = f, data = iris, model = rmod), "formula")
18+
expect_equal(tester(NULL, recipe = rec, data = iris, model = rmod), "recipe")
19+
expect_equal(tester(NULL, x = iris, y = iris, model = rmod), "data.frame")
20+
expect_equal(tester(NULL, f, data = iris, model = rmod), "formula")
21+
expect_equal(tester(NULL, formula = f, data = iris, y = iris, model = rmod), "formula")
1922
})
2023

2124
test_that('unnamed args', {
22-
expect_error(tester(NULL, rec, data = iris))
23-
expect_error(tester(NULL, iris, y = iris))
24-
expect_error(tester(NULL, data = iris))
25+
expect_error(tester(NULL, rec, data = iris, model = rmod))
26+
expect_error(tester(NULL, iris, y = iris, model = rmod))
27+
expect_error(tester(NULL, data = iris, model = rmod))
2528
})
2629

2730
test_that('wrong args', {
28-
expect_error(tester(NULL, x = iris, data = iris))
31+
expect_error(tester(NULL, x = iris, data = iris, model = rmod))
32+
expect_error(tester(NULL, x = iris, y = iris$Sepal.Length, model = smod))
2933
expect_error(tester(NULL, f, x = iris, y = iris, data = iris))
3034
})
3135

0 commit comments

Comments
 (0)