Skip to content

Commit 31506c7

Browse files
committed
switch to fit_xy() if sparse matrix is passed to fit()
1 parent e9ff997 commit 31506c7

File tree

3 files changed

+33
-23
lines changed

3 files changed

+33
-23
lines changed

R/fit.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,13 @@ fit.model_spec <-
175175
eval_env$weights <- wts
176176

177177
if (is_sparse_matrix(data)) {
178-
cli::cli_abort(c(
179-
x = "Sparse matrices cannot be used with {.fn fit}.",
180-
i = "Please use {.fn fit_xy} interface instead."
181-
))
178+
outcome_names <- all.names(rlang::f_lhs(formula))
179+
outcome_ind <- match(outcome_names, colnames(data))
180+
181+
y <- data[, outcome_ind]
182+
x <- data[, -outcome_ind, drop = TRUE]
183+
184+
return(fit_xy(object, x, y, case_weights, control, ...))
182185
}
183186

184187
data <- materialize_sparse_tibble(data, object, "data")

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
Warning:
77
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.
88

9+
# sparse matrix can be passed to `fit()
10+
11+
Code
12+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
13+
Condition
14+
Error in `fit_xy()`:
15+
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
16+
917
# sparse tibble can be passed to `fit_xy()
1018

1119
Code
@@ -22,15 +30,6 @@
2230
Error in `fit_xy()`:
2331
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.
2432

25-
# sparse matrices can not be passed to `fit()
26-
27-
Code
28-
hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
29-
Condition
30-
Error in `fit()`:
31-
x Sparse matrices cannot be used with `fit()`.
32-
i Please use `fit_xy()` interface instead.
33-
3433
# sparse tibble can be passed to `predict()
3534

3635
Code

tests/testthat/test-sparsevctrs.R

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,33 +21,34 @@ test_that("sparse tibble can be passed to `fit()", {
2121
)
2222
})
2323

24-
test_that("sparse tibble can be passed to `fit_xy()", {
24+
test_that("sparse matrix can be passed to `fit()", {
2525
skip_if_not_installed("xgboost")
2626

2727
hotel_data <- sparse_hotel_rates()
28-
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
2928

3029
spec <- boost_tree() %>%
3130
set_mode("regression") %>%
3231
set_engine("xgboost")
3332

3433
expect_no_error(
35-
lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
34+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
3635
)
3736

3837
spec <- linear_reg() %>%
3938
set_mode("regression") %>%
4039
set_engine("lm")
4140

4241
expect_snapshot(
43-
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
42+
error = TRUE,
43+
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
4444
)
4545
})
4646

47-
test_that("sparse matrices can be passed to `fit_xy()", {
47+
test_that("sparse tibble can be passed to `fit_xy()", {
4848
skip_if_not_installed("xgboost")
4949

5050
hotel_data <- sparse_hotel_rates()
51+
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
5152

5253
spec <- boost_tree() %>%
5354
set_mode("regression") %>%
@@ -62,12 +63,11 @@ test_that("sparse matrices can be passed to `fit_xy()", {
6263
set_engine("lm")
6364

6465
expect_snapshot(
65-
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]),
66-
error = TRUE
66+
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
6767
)
6868
})
6969

70-
test_that("sparse matrices can not be passed to `fit()", {
70+
test_that("sparse matrices can be passed to `fit_xy()", {
7171
skip_if_not_installed("xgboost")
7272

7373
hotel_data <- sparse_hotel_rates()
@@ -76,9 +76,17 @@ test_that("sparse matrices can not be passed to `fit()", {
7676
set_mode("regression") %>%
7777
set_engine("xgboost")
7878

79+
expect_no_error(
80+
lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
81+
)
82+
83+
spec <- linear_reg() %>%
84+
set_mode("regression") %>%
85+
set_engine("lm")
86+
7987
expect_snapshot(
80-
error = TRUE,
81-
hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
88+
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]),
89+
error = TRUE
8290
)
8391
})
8492

0 commit comments

Comments
 (0)