Skip to content

Commit 0b09e78

Browse files
Merge pull request #1174 from tidymodels/sparse-matrix-fit-error
Make sure all sparse data errors look nice
2 parents 474152f + 236a39b commit 0b09e78

File tree

4 files changed

+52
-12
lines changed

4 files changed

+52
-12
lines changed

R/fit.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ fit.model_spec <-
137137
cli::cli_abort(msg)
138138
}
139139

140+
if (is_sparse_matrix(data)) {
141+
data <- sparsevctrs::coerce_to_sparse_tibble(data)
142+
}
143+
140144
dots <- quos(...)
141145

142146
if (length(possible_engines(object)) == 0) {

R/sparsevctrs.R

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
to_sparse_data_frame <- function(x, object) {
2-
if (methods::is(x, "sparseMatrix")) {
1+
to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) {
2+
if (is_sparse_matrix(x)) {
33
if (allow_sparse(object)) {
44
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
55
} else {
@@ -8,8 +8,10 @@ to_sparse_data_frame <- function(x, object) {
88
}
99

1010
cli::cli_abort(
11-
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
12-
engine {.code {object$engine}} doesn't accept that.")
11+
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
12+
engine {.val {object$engine}} doesn't accept that.",
13+
call = call
14+
)
1315
}
1416
} else if (is.data.frame(x)) {
1517
x <- materialize_sparse_tibble(x, object, "x")
@@ -21,6 +23,10 @@ is_sparse_tibble <- function(x) {
2123
any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
2224
}
2325

26+
is_sparse_matrix <- function(x) {
27+
methods::is(x, "sparseMatrix")
28+
}
29+
2430
materialize_sparse_tibble <- function(x, object, input) {
2531
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
2632
if (inherits(object, "model_fit")) {
@@ -29,7 +35,7 @@ materialize_sparse_tibble <- function(x, object, input) {
2935

3036
cli::cli_warn(
3137
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
32-
engine {.code {object$engine}} doesn't accept that. Converting to
38+
engine {.val {object$engine}} doesn't accept that. Converting to
3339
non-sparse."
3440
)
3541
for (i in seq_along(ncol(x))) {

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,47 @@
44
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
55
Condition
66
Warning:
7-
`data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
7+
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
8+
9+
# sparse matrix can be passed to `fit()
10+
11+
Code
12+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
13+
Condition
14+
Warning:
15+
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
816

917
# sparse tibble can be passed to `fit_xy()
1018

1119
Code
1220
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
1321
Condition
1422
Warning:
15-
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
23+
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
1624

1725
# sparse matrices can be passed to `fit_xy()
1826

1927
Code
2028
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
2129
Condition
22-
Error in `to_sparse_data_frame()`:
23-
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
30+
Error in `fit_xy()`:
31+
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
2432

2533
# sparse tibble can be passed to `predict()
2634

2735
Code
2836
preds <- predict(lm_fit, sparse_mtcars)
2937
Condition
3038
Warning:
31-
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
39+
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
3240

3341
# sparse matrices can be passed to `predict()
3442

3543
Code
3644
predict(lm_fit, sparse_mtcars)
3745
Condition
38-
Error in `to_sparse_data_frame()`:
39-
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
46+
Error in `predict()`:
47+
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
4048

4149
# to_sparse_data_frame() is used correctly
4250

tests/testthat/test-sparsevctrs.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@ test_that("sparse tibble can be passed to `fit()", {
2121
)
2222
})
2323

24+
test_that("sparse matrix can be passed to `fit()", {
25+
skip_if_not_installed("xgboost")
26+
27+
hotel_data <- sparse_hotel_rates()
28+
29+
spec <- boost_tree() %>%
30+
set_mode("regression") %>%
31+
set_engine("xgboost")
32+
33+
expect_no_error(
34+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
35+
)
36+
37+
spec <- linear_reg() %>%
38+
set_mode("regression") %>%
39+
set_engine("lm")
40+
41+
expect_snapshot(
42+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
43+
)
44+
})
45+
2446
test_that("sparse tibble can be passed to `fit_xy()", {
2547
skip_if_not_installed("xgboost")
2648

0 commit comments

Comments
 (0)