Skip to content

Commit d32fc65

Browse files
Merge pull request #1173 from tidymodels/xgboost-sparse-data
make sure xgboost works with sparse data
2 parents 8af5ddf + c6b1bd1 commit d32fc65

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

R/convert_data.R

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,10 @@
124124
options = options
125125
)
126126
} else if (composition == "dgCMatrix") {
127+
y_cols <- attr(mod_terms, "response")
128+
if (length(y_cols) > 0) {
129+
data <- data[, -y_cols, drop = FALSE]
130+
}
127131
x <- sparsevctrs::coerce_to_sparse_matrix(data)
128132
res <-
129133
list(

tests/testthat/test-sparsevctrs.R

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,43 @@ test_that("sparse matrices can be passed to `predict()", {
149149
)
150150
})
151151

152+
test_that("sparse data work with xgboost engine", {
153+
skip_if_not_installed("xgboost")
154+
155+
spec <- boost_tree() %>%
156+
set_mode("regression") %>%
157+
set_engine("xgboost")
158+
159+
hotel_data <- sparse_hotel_rates()
160+
161+
expect_no_error(
162+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
163+
)
164+
165+
expect_no_error(
166+
predict(tree_fit, hotel_data)
167+
)
168+
169+
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
170+
171+
172+
expect_no_error(
173+
tree_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
174+
)
175+
176+
expect_no_error(
177+
predict(tree_fit, hotel_data)
178+
)
179+
180+
expect_no_error(
181+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
182+
)
183+
184+
expect_no_error(
185+
predict(tree_fit, hotel_data)
186+
)
187+
})
188+
152189
test_that("to_sparse_data_frame() is used correctly", {
153190
skip_if_not_installed("xgboost")
154191

0 commit comments

Comments
 (0)