Skip to content

Commit e347727

Browse files
committed
better appraoch to creating formulas from recipes based on the model type
1 parent a68cf23 commit e347727

File tree

3 files changed

+77
-28
lines changed

3 files changed

+77
-28
lines changed

R/fitter.R

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -226,18 +226,17 @@ formula_to_matrix <- function(object, formula, data, control, ...) {
226226

227227
#' @importFrom recipes prep juice all_predictors all_outcomes
228228

229-
recipe_data <- function(recipe, data, control, output = "matrix", combine = FALSE) {
229+
# add case weights as extra object returned (out$weights)
230+
recipe_data <- function(recipe, data, object, control, output = "matrix", combine = FALSE) {
230231
recipe <-
231232
prep(recipe, training = data, retain = TRUE, verbose = control$verbosity > 1)
232233

233234
if (combine) {
234-
out <- list(data = juice(recipe, all_predictors(), all_outcomes(), composition = output))
235-
data_info <- summary(recipe)
236-
y_names <- data_info$variable[data_info$role == "outcome"]
237-
if (length(y_names) > 1)
238-
out$form <- paste0("cbind(", paste0(y_names, collapse = ","), ")~.")
239-
else
240-
out$form <- paste0(y_names, "~.")
235+
out <- list(
236+
data = juice(recipe, composition = output),
237+
form = formula(object, (recipe))
238+
)
239+
241240
} else {
242241
out <-
243242
list(
@@ -257,16 +256,16 @@ recipe_data <- function(recipe, data, control, output = "matrix", combine = FALS
257256

258257
recipe_to_formula <-
259258
function(object, recipe, data, control, ...) {
260-
info <- recipe_data(recipe, data, control, output = "tibble", combine = TRUE)
259+
info <- recipe_data(recipe, data, object, control, output = "tibble", combine = TRUE)
261260
formula_to_formula(object, info$form, info$data, control, ...)
262261
}
263262

264263
recipe_to_data.frame <- function(object, recipe, data, control, ...) {
265-
info <- recipe_data(recipe, data, control, output = "tibble", combine = FALSE)
264+
info <- recipe_data(recipe, data, object, control, output = "tibble", combine = FALSE)
266265
xy_to_xy(object, info$x, info$y, control, ...)
267266
}
268267

269268
recipe_to_matrix <- function(object, recipe, data, control, ...) {
270-
info <- recipe_data(recipe, data, control, output = "matrix", combine = FALSE)
269+
info <- recipe_data(recipe, data, object, control, output = "matrix", combine = FALSE)
271270
xy_to_xy(object, info$x, info$y, control, ...)
272271
}

R/form_recipe.R

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Class for creating an appropriate formula for a model based on the roles
2+
# contained in the recipe.
3+
4+
# maybe use `formula` generic with `formula.model_spec` acting as the generic?
5+
6+
formula.model_spec <- function(x, recipe, ...) {
7+
rec_vars <- summary(recipe)
8+
y_names <- rec_vars$variable[rec_vars$role == "outcome"]
9+
if (length(y_names) > 1)
10+
form_text <- paste0("cbind(", paste0(y_names, collapse = ","), ")~.")
11+
else
12+
form_text <- paste0(y_names, "~.")
13+
14+
form <- try(as.formula(form_text), silent = TRUE)
15+
if(inherits(form, "try-error"))
16+
stop("Could not parse the model formula: ", form_text, call. = FALSE)
17+
form
18+
}
19+
20+
21+
formula.surv_reg <- function(x, recipe, ...) {
22+
rec_vars <- summary(recipe)
23+
y_names <- rec_vars$variable[rec_vars$role == "outcome"]
24+
if (length(y_names) > 2 | length(y_names) < 1)
25+
stop("There should be 1-2 variables in the `outcome` role.", call. = FALSE)
26+
cens_names <- rec_vars$variable[rec_vars$role == "censoring var"]
27+
if (length(cens_names) > 1)
28+
stop("There should be 0-1 variables in the `censoring` role.", call. = FALSE)
29+
x_names <- rec_vars$variable[rec_vars$role == "predictor"]
30+
31+
# construct basic formula
32+
form_text <- paste0("Surv(", paste0(y_names, collapse = ", "))
33+
if (length(cens_names) == 1)
34+
form_text <- paste0(form_text, ", ", cens_names, ") ~")
35+
else
36+
form_text <- paste0(form_text, ") ~")
37+
38+
if (length(x_names) == 0)
39+
form_text <- paste0(form_text, "1")
40+
else
41+
form_text <- paste0(form_text, paste0(x_names, collapse = "+"))
42+
43+
# engine-speciifc options (e.g. spark needing censor var in text)
44+
if (!is.null(x$engine) && x$engine == "flexsurv") {
45+
extra_ind <- which(rec_vars$role %in% flexsurv_params)
46+
if (length(extra_ind) > 0) {
47+
extra_terms <- paste0(
48+
rec_vars$role[extra_ind], "(",
49+
rec_vars$variable[extra_ind], ")"
50+
)
51+
form_text <- paste0(form_text, "+", paste0(extra_terms, collapse = "+"))
52+
}
53+
if (any(rec_vars$role == "strata"))
54+
warning(
55+
"`flexsurv` does not use the `strata` function; instead use ",
56+
"the parameter roles for differential values (e.g. `sigma`).",
57+
call. = FALSE
58+
)
59+
}
60+
61+
form <- try(as.formula(form_text), silent = TRUE)
62+
if(inherits(form, "try-error"))
63+
stop("Could not parse the model formula: ", form_text, call. = FALSE)
64+
form
65+
}
66+
67+
flexsurv_params <- c("sigma", "shape", "sdlog", "Q", "k", "P", "S1", "s2")

R/misc.R

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,3 @@ resolve_args <- function(args, ...) {
101101
}
102102
args
103103
}
104-
105-
106-
# A function to convert a R formula to spark format
107-
surv_to_spark_formula <- function(f) {
108-
if (!inherits(f, "formula"))
109-
stop("A formula is required.")
110-
if (length(f[[2]]) != 3)
111-
stop("spark requires the `Surv` object to have a ",
112-
"censoring indicator.")
113-
if (!all.equal(f[[2]][[1]], as.name("Surv")))
114-
stop("The formula sould contain a `Surv` object.")
115-
f2 <- f
116-
f2[[2]] <- f[[2]][[2]]
117-
list(formula = f2, censor = deparse(f[[2]][[3]]))
118-
}
119-
120-

0 commit comments

Comments
 (0)