Skip to content

Commit 58e4329

Browse files
topepo‘topepo’simonpcouch
authored
better call routing for errors (#1214)
* pass call through user-facing predict methods * pass call for internal code; probably will never be surfaced by user * show error is from autoplot() instead of map_glmnet_coefs() * un-used bartMachine code * fix bug in condense_control and route user-facing call * unit tests for one-hot encodings * pass calls through data conversion functions * redoc * small formatting changes * route some glmnet checking calls * route some spec updating calls * some predict call routing * make dev function as internal * redoc * revert passing in predict * Apply suggestions from code review Co-authored-by: Simon P. Couch <[email protected]> * update snapshots * redoc --------- Co-authored-by: ‘topepo’ <‘[email protected]’> Co-authored-by: Simon P. Couch <[email protected]>
1 parent 33f621c commit 58e4329

40 files changed

+262
-164
lines changed

NAMESPACE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ export(bag_mars)
202202
export(bag_mlp)
203203
export(bag_tree)
204204
export(bart)
205-
export(bartMachine_interval_calc)
206205
export(boost_tree)
207206
export(case_weights_allowed)
208207
export(cforest_train)

R/arguments.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ make_form_call <- function(object, env = NULL) {
258258
}
259259

260260
# TODO we need something to indicate that case weights are being used.
261-
make_xy_call <- function(object, target, env) {
261+
make_xy_call <- function(object, target, env, call = rlang::caller_env()) {
262262
fit_args <- object$method$fit$args
263263
uses_weights <- has_weights(env)
264264

@@ -283,7 +283,7 @@ make_xy_call <- function(object, target, env) {
283283
data.frame = rlang::expr(maybe_data_frame(x)),
284284
matrix = rlang::expr(maybe_matrix(x)),
285285
dgCMatrix = rlang::expr(maybe_sparse_matrix(x)),
286-
cli::cli_abort("Invalid data type target: {target}.")
286+
cli::cli_abort("Invalid data type target: {target}.", call = call)
287287
)
288288
if (uses_weights) {
289289
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)

R/autoplot.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
4141
}
4242

4343

44-
map_glmnet_coefs <- function(x) {
44+
map_glmnet_coefs <- function(x, call = rlang::caller_env()) {
4545
coefs <- coef(x)
4646
# If parsnip is used to fit the model, glmnet should be attached and this will
4747
# work. If an object is loaded from a new session, they will need to load the
4848
# package.
4949
if (is.null(coefs)) {
5050
cli::cli_abort(
51-
"Please load the {.pkg glmnet} package before running {.fun autoplot}."
51+
"Please load the {.pkg glmnet} package before running {.fun autoplot}.",
52+
call = call
5253
)
5354
}
5455
p <- x$dim[1]
@@ -89,9 +90,10 @@ top_coefs <- function(x, top_n = 5) {
8990
dplyr::slice(seq_len(top_n))
9091
}
9192

92-
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
93+
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
94+
call = rlang::caller_env(), ...) {
9395
tidy_coefs <-
94-
map_glmnet_coefs(x) %>%
96+
map_glmnet_coefs(x, call = call) %>%
9597
dplyr::filter(penalty >= min_penalty)
9698

9799
actual_min_penalty <- min(tidy_coefs$penalty)

R/bart.R

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -130,61 +130,13 @@ update.bart <-
130130
)
131131
}
132132

133-
134133
#' Developer functions for predictions via BART models
135-
#' @export
136-
#' @keywords internal
137134
#' @name bart-internal
138135
#' @inheritParams predict.model_fit
139136
#' @param obj A parsnip object.
140-
#' @param ci Confidence (TRUE) or prediction interval (FALSE)
141137
#' @param level Confidence level.
142138
#' @param std_err Attach column for standard error of prediction or not.
143-
bartMachine_interval_calc <- function(new_data, obj, ci = TRUE, level = 0.95) {
144-
if (obj$spec$mode == "classification") {
145-
cli::cli_abort(
146-
"Prediction intervals are not possible for classification"
147-
)
148-
}
149-
get_std_err <- obj$spec$method$pred$pred_int$extras$std_error
150-
151-
if (ci) {
152-
cl <-
153-
rlang::call2(
154-
"calc_credible_intervals",
155-
.ns = "bartMachine",
156-
bart_machine = rlang::expr(obj$fit),
157-
new_data = rlang::expr(new_data),
158-
ci_conf = level
159-
)
160-
161-
} else {
162-
cl <-
163-
rlang::call2(
164-
"calc_prediction_intervals",
165-
.ns = "bartMachine",
166-
bart_machine = rlang::expr(obj$fit),
167-
new_data = rlang::expr(new_data),
168-
pi_conf = level
169-
)
170-
}
171-
res <- rlang::eval_tidy(cl)
172-
if (!ci) {
173-
if (get_std_err) {
174-
.std_error <- apply(res$all_prediction_samples, 1, stats::sd, na.rm = TRUE)
175-
}
176-
res <- res$interval
177-
}
178-
res <- tibble::as_tibble(res)
179-
names(res) <- c(".pred_lower", ".pred_upper")
180-
if (!ci & get_std_err) {
181-
res$.std_err <- .std_error
182-
}
183-
res
184-
}
185-
186139
#' @export
187-
#' @rdname bart-internal
188140
#' @keywords internal
189141
dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALSE) {
190142
types <- c("numeric", "class", "prob", "conf_int", "pred_int")

R/condense_control.R

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#'
1111
#' @return A control object with the same elements and classes of `ref`, with
1212
#' values of `x`.
13+
#' @param call The execution environment of a currently running function, e.g.
14+
#' `caller_env()`. The function will be mentioned in error messages as the
15+
#' source of the error. See the call argument of [rlang::abort()] for more
16+
#' information.
1317
#' @keywords internal
1418
#' @export
1519
#'
@@ -20,16 +24,17 @@
2024
#'
2125
#' ctrl <- condense_control(ctrl, control_parsnip())
2226
#' str(ctrl)
23-
condense_control <- function(x, ref) {
27+
condense_control <- function(x, ref, ..., call = rlang::caller_env()) {
28+
check_dots_empty()
2429
mismatch <- setdiff(names(ref), names(x))
2530
if (length(mismatch)) {
2631
cli::cli_abort(
2732
c(
28-
"Object of class {.cls class(x)[1]} cannot be coerced to
29-
object of class {.cls class(ref)[1]}.",
33+
"{.obj_type_friendly {x}} cannot be coerced to {.obj_type_friendly {ref}}.",
3034
"i" = "{cli::qty(mismatch)} The argument{?s} {.arg {mismatch}}
3135
{?is/are} missing."
32-
)
36+
),
37+
call = call
3338
)
3439
}
3540
res <- x[names(ref)]

R/contr_one_hot.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
#' This contrast function produces a model matrix with indicator columns for
44
#' each level of each factor.
55
#'
6-
#' @param n A vector of character factor levels or the number of unique levels.
6+
#' @param n A vector of character factor levels (of length >=1) or the number
7+
#' of unique levels (>= 1).
78
#' @param contrasts This argument is for backwards compatibility and only the
89
#' default of `TRUE` is supported.
910
#' @param sparse This argument is for backwards compatibility and only the
@@ -24,9 +25,13 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
2425
}
2526

2627
if (is.character(n)) {
28+
if (length(n) < 1) {
29+
cli::cli_abort("{.arg n} cannot be empty.")
30+
}
2731
names <- n
2832
n <- length(names)
2933
} else if (is.numeric(n)) {
34+
check_number_whole(n, min = 1)
3035
n <- as.integer(n)
3136

3237
if (length(n) != 1L) {
@@ -35,7 +40,7 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
3540

3641
names <- as.character(seq_len(n))
3742
} else {
38-
cli::cli_abort("{.arg n} must be a character vector or an integer of size 1.")
43+
check_number_whole(n, min = 1)
3944
}
4045

4146
out <- diag(n)

R/convert_data.R

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,21 @@
4040
na.action = na.omit,
4141
indicators = "traditional",
4242
composition = "data.frame",
43-
remove_intercept = TRUE) {
43+
remove_intercept = TRUE,
44+
call = rlang::caller_env()) {
4445
if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) {
4546
cli::cli_abort(
4647
"{.arg composition} should be either {.val data.frame}, {.val matrix}, or
47-
{.val dgCMatrix}."
48+
{.val dgCMatrix}.",
49+
call = call
4850
)
4951
}
5052

5153
if (sparsevctrs::has_sparse_elements(data)) {
5254
cli::cli_abort(
53-
"Sparse data cannot be used with formula interface. Please use
54-
{.fn fit_xy} instead."
55+
"Sparse data cannot be used with formula interface. Please use
56+
{.fn fit_xy} instead.",
57+
call = call
5558
)
5659
}
5760

@@ -84,7 +87,7 @@
8487

8588
w <- as.vector(model.weights(mod_frame))
8689
if (!is.null(w) && !is.numeric(w)) {
87-
cli::cli_abort("{.arg weights} must be a numeric vector.")
90+
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
8891
}
8992

9093
# TODO: Do we actually use the offset when fitting?
@@ -175,10 +178,12 @@
175178
.convert_form_to_xy_new <- function(object,
176179
new_data,
177180
na.action = na.pass,
178-
composition = "data.frame") {
181+
composition = "data.frame",
182+
call = rlang::caller_env()) {
179183
if (!(composition %in% c("data.frame", "matrix"))) {
180184
cli::cli_abort(
181-
"{.arg composition} should be either {.val data.frame} or {.val matrix}."
185+
"{.arg composition} should be either {.val data.frame} or {.val matrix}.",
186+
call = call
182187
)
183188
}
184189

@@ -244,9 +249,10 @@
244249
y,
245250
weights = NULL,
246251
y_name = "..y",
247-
remove_intercept = TRUE) {
252+
remove_intercept = TRUE,
253+
call = rlang::caller_env()) {
248254
if (is.vector(x)) {
249-
cli::cli_abort("{.arg x} cannot be a vector.")
255+
cli::cli_abort("{.arg x} cannot be a vector.", call = call)
250256
}
251257

252258
if (remove_intercept) {
@@ -279,10 +285,10 @@
279285

280286
if (!is.null(weights)) {
281287
if (!is.numeric(weights)) {
282-
cli::cli_abort("{.arg weights} must be a numeric vector.")
288+
cli::cli_abort("{.arg weights} must be a numeric vector.", call = call)
283289
}
284290
if (length(weights) != nrow(x)) {
285-
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.")
291+
cli::cli_abort("{.arg weights} should have {nrow(x)} elements.", call = call)
286292
}
287293

288294
form <- patch_formula_environment_with_case_weights(

R/descriptors.R

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,23 @@ NULL
103103

104104
# Descriptor retrievers --------------------------------------------------------
105105

106-
get_descr_form <- function(formula, data) {
106+
get_descr_form <- function(formula, data, call = rlang::caller_env()) {
107107
if (inherits(data, "tbl_spark")) {
108108
res <- get_descr_spark(formula, data)
109109
} else {
110-
res <- get_descr_df(formula, data)
110+
res <- get_descr_df(formula, data, call = call)
111111
}
112112
res
113113
}
114114

115-
get_descr_df <- function(formula, data) {
115+
get_descr_df <- function(formula, data, call = rlang::caller_env()) {
116116

117117
tmp_dat <-
118118
.convert_form_to_xy_fit(formula,
119119
data,
120120
indicators = "none",
121-
remove_intercept = TRUE)
121+
remove_intercept = TRUE,
122+
call = call)
122123

123124
if(is.factor(tmp_dat$y)) {
124125
.lvls <- function() {
@@ -136,7 +137,8 @@ get_descr_df <- function(formula, data) {
136137
formula,
137138
data,
138139
indicators = "traditional",
139-
remove_intercept = TRUE
140+
remove_intercept = TRUE,
141+
call = call
140142
)$x
141143
)
142144
}
@@ -263,7 +265,7 @@ get_descr_spark <- function(formula, data) {
263265
)
264266
}
265267

266-
get_descr_xy <- function(x, y) {
268+
get_descr_xy <- function(x, y, call = rlang::caller_env()) {
267269

268270
.lvls <- if (is.factor(y)) {
269271
function() table(y, dnn = NULL)
@@ -291,7 +293,7 @@ get_descr_xy <- function(x, y) {
291293
}
292294

293295
.dat <- function() {
294-
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE)$data
296+
.convert_xy_to_form_fit(x, y, remove_intercept = TRUE, call = call)$data
295297
}
296298

297299
.x <- function() {

R/fit.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ fit.model_spec <-
157157
}
158158

159159
if (all(c("x", "y") %in% names(dots))) {
160-
cli::cli_abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.")
160+
cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.")
161161
}
162162
cl <- match.call(expand.dots = TRUE)
163163
# Create an environment with the evaluated argument objects. This will be
@@ -307,7 +307,8 @@ fit_xy.model_spec <-
307307

308308
if (object$engine == "spark") {
309309
cli::cli_abort(
310-
"spark objects can only be used with the formula interface to {.fn fit} with a spark data object."
310+
"spark objects can only be used with the formula interface to {.fn fit}
311+
with a spark data object."
311312
)
312313
}
313314

R/fit_helpers.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ form_form <-
2727

2828
# if descriptors are needed, update descr_env with the calculated values
2929
if (requires_descrs(object)) {
30-
data_stats <- get_descr_form(env$formula, env$data)
30+
data_stats <- get_descr_form(env$formula, env$data, call = call)
3131
scoped_descrs(data_stats)
3232
}
3333

@@ -86,7 +86,7 @@ xy_xy <- function(object,
8686

8787
# if descriptors are needed, update descr_env with the calculated values
8888
if (requires_descrs(object)) {
89-
data_stats <- get_descr_xy(env$x, env$y)
89+
data_stats <- get_descr_xy(env$x, env$y, call = call)
9090
scoped_descrs(data_stats)
9191
}
9292

@@ -96,7 +96,7 @@ xy_xy <- function(object,
9696
# sub in arguments to actual syntax for corresponding engine
9797
object <- translate(object, engine = object$engine)
9898

99-
fit_call <- make_xy_call(object, target, env)
99+
fit_call <- make_xy_call(object, target, env, call)
100100

101101
res <- list(lvl = levels(env$y), spec = object)
102102

@@ -141,7 +141,8 @@ form_xy <- function(object, control, env,
141141
...,
142142
composition = target,
143143
indicators = indicators,
144-
remove_intercept = remove_intercept
144+
remove_intercept = remove_intercept,
145+
call = call
145146
)
146147
env$x <- data_obj$x
147148
env$y <- data_obj$y

0 commit comments

Comments
 (0)