Skip to content

Commit 227ab92

Browse files
use type checkers in remaining functions (#1186)
--------- Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent a188159 commit 227ab92

File tree

5 files changed

+12
-28
lines changed

5 files changed

+12
-28
lines changed

R/autoplot.R

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ autoplot.model_fit <- function(object, ...) {
3434
#' @rdname autoplot.model_fit
3535
autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
3636
top_n = 3L) {
37+
check_number_decimal(min_penalty, min = 0, max = 1)
38+
check_number_decimal(best_penalty, min = 0, max = 1, allow_null = TRUE)
39+
check_number_whole(top_n, min = 1, max = Inf, allow_infinite = TRUE)
3740
autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...)
3841
}
3942

@@ -87,8 +90,6 @@ top_coefs <- function(x, top_n = 5) {
8790
}
8891

8992
autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
90-
check_penalty_value(min_penalty)
91-
9293
tidy_coefs <-
9394
map_glmnet_coefs(x) %>%
9495
dplyr::filter(penalty >= min_penalty)
@@ -138,7 +139,6 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
138139
}
139140

140141
if (!is.null(best_penalty)) {
141-
check_penalty_value(best_penalty)
142142
p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3)
143143
}
144144

@@ -159,13 +159,4 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L,
159159
p
160160
}
161161

162-
check_penalty_value <- function(x) {
163-
cl <- match.call()
164-
arg_val <- as.character(cl$x)
165-
if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) {
166-
cli::cli_abort("{.arg {arg_val}} should be a single, non-negative value.")
167-
}
168-
invisible(x)
169-
}
170-
171162
# nocov end

R/control_parsnip.R

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,8 @@ check_control <- function(x, call = rlang::caller_env()) {
3737
and {.field catch}.",
3838
call = call
3939
)
40-
# based on ?is.integer
41-
int_check <- function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol
42-
if (!int_check(x$verbosity))
43-
cli::cli_abort("{.arg verbosity} should be an integer.", call = call)
44-
if (!is.logical(x$catch))
45-
cli::cli_abort("{.arg catch} should be a logical.", call = call)
40+
check_number_whole(x$verbosity, call = call)
41+
check_bool(x$catch, call = call)
4642
x
4743
}
4844

R/required_pkgs.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ required_pkgs.model_spec <- function(x, infra = TRUE, ...) {
2626
if (is.null(x$engine)) {
2727
cli::cli_abort("Please set an engine.")
2828
}
29+
check_bool(infra)
2930
get_pkgs(x, infra)
3031
}
3132

3233
#' @export
3334
#' @rdname required_pkgs.model_spec
3435
required_pkgs.model_fit <- function(x, infra = TRUE, ...) {
36+
check_bool(infra)
3537
get_pkgs(x$spec, infra)
3638
}
3739

R/tidy_glmnet.R

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,9 @@ get_glmn_coefs <- function(x, penalty = 0.01) {
5555
res
5656
}
5757

58-
tidy_glmnet <- function(x, penalty = NULL, ...) {
58+
tidy_glmnet <- function(x, penalty = NULL, ..., call = caller_env()) {
5959
check_installs(x$spec)
6060
load_libs(x$spec, quiet = TRUE, attach = TRUE)
61-
if (is.null(penalty)) {
62-
if (isTRUE(is.numeric(x$spec$args$penalty))){
63-
penalty <- x$spec$args$penalty
64-
} else {
65-
rlang::abort("Please pick a single value of `penalty`.")
66-
}
67-
}
61+
check_number_decimal(penalty, min = 0, max = 1, allow_null = TRUE, call = call)
6862
get_glmn_coefs(x$fit, penalty = penalty)
6963
}

R/tune_args.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@ tune_tbl <- function(name = character(),
5656
source = character(),
5757
component = character(),
5858
component_id = character(),
59-
full = FALSE) {
60-
59+
full = FALSE,
60+
call = caller_env()) {
6161

62+
check_bool(full, call = call)
6263
complete_id <- id[!is.na(id)]
6364
dups <- duplicated(complete_id)
6465
if (any(dups)) {

0 commit comments

Comments
 (0)