Skip to content

Commit 93463c2

Browse files
Merge branch 'main' into doc-sparse-data
2 parents 44403fc + 8af5ddf commit 93463c2

File tree

8 files changed

+143
-11
lines changed

8 files changed

+143
-11
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165).
66

7+
* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).
8+
79
* Transitioned package errors and warnings to use cli (#1147 and #1148 by
810
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
911
#1161, #1081).
File renamed without changes.

R/fit.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ fit.model_spec <-
137137
cli::cli_abort(msg)
138138
}
139139

140+
if (is_sparse_matrix(data)) {
141+
data <- sparsevctrs::coerce_to_sparse_tibble(data)
142+
}
143+
140144
dots <- quos(...)
141145

142146
if (length(possible_engines(object)) == 0) {
@@ -444,6 +448,10 @@ check_xy_interface <- function(x, y, cl, model) {
444448
}
445449

446450
allow_sparse <- function(x) {
451+
if (inherits(x, "model_fit")) {
452+
x <- x$spec
453+
}
454+
447455
res <- get_from_env(paste0(class(x)[1], "_encoding"))
448456
all(res$allow_sparse_x[res$engine == x$engine])
449457
}

R/predict.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
160160
}
161161
check_pred_type_dots(object, type, ...)
162162

163+
new_data <- to_sparse_data_frame(new_data, object)
164+
163165
res <- switch(
164166
type,
165167
numeric = predict_numeric(object = object, new_data = new_data, ...),
@@ -450,7 +452,7 @@ prepare_data <- function(object, new_data) {
450452
} else if (translate_from_xy_to_formula) {
451453
new_data <- .convert_xy_to_form_new(object$preproc, new_data)
452454
} else if (translate_from_xy_to_xy) {
453-
new_data <- new_data[, object$preproc$x_names]
455+
new_data <- new_data[, object$preproc$x_names, drop = FALSE]
454456
}
455457

456458
encodings <- get_encoding(class(object$spec)[1])

R/sparsevctrs.R

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@
1616
#' @name sparse_data
1717
NULL
1818

19-
to_sparse_data_frame <- function(x, object) {
20-
if (methods::is(x, "sparseMatrix")) {
19+
to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) {
20+
if (is_sparse_matrix(x)) {
2121
if (allow_sparse(object)) {
2222
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
2323
} else {
24+
if (inherits(object, "model_fit")) {
25+
object <- object$spec
26+
}
27+
2428
cli::cli_abort(
25-
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
26-
engine {.code {object$engine}} doesn't accept that.")
29+
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
30+
engine {.val {object$engine}} doesn't accept that.",
31+
call = call
32+
)
2733
}
2834
} else if (is.data.frame(x)) {
2935
x <- materialize_sparse_tibble(x, object, "x")
@@ -35,11 +41,19 @@ is_sparse_tibble <- function(x) {
3541
any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
3642
}
3743

44+
is_sparse_matrix <- function(x) {
45+
methods::is(x, "sparseMatrix")
46+
}
47+
3848
materialize_sparse_tibble <- function(x, object, input) {
39-
if ((!allow_sparse(object)) && is_sparse_tibble(x)) {
49+
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
50+
if (inherits(object, "model_fit")) {
51+
object <- object$spec
52+
}
53+
4054
cli::cli_warn(
4155
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
42-
engine {.code {object$engine}} doesn't accept that. Converting to
56+
engine {.val {object$engine}} doesn't accept that. Converting to
4357
non-sparse."
4458
)
4559
for (i in seq_along(ncol(x))) {

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,47 @@
44
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
55
Condition
66
Warning:
7-
`data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
7+
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
8+
9+
# sparse matrix can be passed to `fit()
10+
11+
Code
12+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
13+
Condition
14+
Warning:
15+
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
816

917
# sparse tibble can be passed to `fit_xy()
1018

1119
Code
1220
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
1321
Condition
1422
Warning:
15-
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
23+
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
1624

1725
# sparse matrices can be passed to `fit_xy()
1826

1927
Code
2028
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
2129
Condition
22-
Error in `to_sparse_data_frame()`:
23-
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
30+
Error in `fit_xy()`:
31+
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
32+
33+
# sparse tibble can be passed to `predict()
34+
35+
Code
36+
preds <- predict(lm_fit, sparse_mtcars)
37+
Condition
38+
Warning:
39+
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
40+
41+
# sparse matrices can be passed to `predict()
42+
43+
Code
44+
predict(lm_fit, sparse_mtcars)
45+
Condition
46+
Error in `predict()`:
47+
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
2448

2549
# to_sparse_data_frame() is used correctly
2650

tests/testthat/test-sparsevctrs.R

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@ test_that("sparse tibble can be passed to `fit()", {
2121
)
2222
})
2323

24+
test_that("sparse matrix can be passed to `fit()", {
25+
skip_if_not_installed("xgboost")
26+
27+
hotel_data <- sparse_hotel_rates()
28+
29+
spec <- boost_tree() %>%
30+
set_mode("regression") %>%
31+
set_engine("xgboost")
32+
33+
expect_no_error(
34+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
35+
)
36+
37+
spec <- linear_reg() %>%
38+
set_mode("regression") %>%
39+
set_engine("lm")
40+
41+
expect_snapshot(
42+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
43+
)
44+
})
45+
2446
test_that("sparse tibble can be passed to `fit_xy()", {
2547
skip_if_not_installed("xgboost")
2648

@@ -67,6 +89,66 @@ test_that("sparse matrices can be passed to `fit_xy()", {
6789
)
6890
})
6991

92+
test_that("sparse tibble can be passed to `predict()", {
93+
skip_if_not_installed("ranger")
94+
95+
hotel_data <- sparse_hotel_rates()
96+
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
97+
98+
spec <- rand_forest(trees = 10) %>%
99+
set_mode("regression") %>%
100+
set_engine("ranger")
101+
102+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
103+
104+
expect_no_error(
105+
predict(tree_fit, hotel_data)
106+
)
107+
108+
spec <- linear_reg() %>%
109+
set_mode("regression") %>%
110+
set_engine("lm")
111+
112+
lm_fit <- fit(spec, mpg ~ ., data = mtcars)
113+
114+
sparse_mtcars <- mtcars %>%
115+
sparsevctrs::coerce_to_sparse_matrix() %>%
116+
sparsevctrs::coerce_to_sparse_tibble()
117+
118+
expect_snapshot(
119+
preds <- predict(lm_fit, sparse_mtcars)
120+
)
121+
})
122+
123+
test_that("sparse matrices can be passed to `predict()", {
124+
skip_if_not_installed("ranger")
125+
126+
hotel_data <- sparse_hotel_rates()
127+
128+
spec <- rand_forest(trees = 10) %>%
129+
set_mode("regression") %>%
130+
set_engine("ranger")
131+
132+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
133+
134+
expect_no_error(
135+
predict(tree_fit, hotel_data)
136+
)
137+
138+
spec <- linear_reg() %>%
139+
set_mode("regression") %>%
140+
set_engine("lm")
141+
142+
lm_fit <- fit(spec, mpg ~ ., data = mtcars)
143+
144+
sparse_mtcars <- sparsevctrs::coerce_to_sparse_matrix(mtcars)
145+
146+
expect_snapshot(
147+
error = TRUE,
148+
predict(lm_fit, sparse_mtcars)
149+
)
150+
})
151+
70152
test_that("to_sparse_data_frame() is used correctly", {
71153
skip_if_not_installed("xgboost")
72154

0 commit comments

Comments
 (0)