Skip to content

Commit c6b1bd1

Browse files
committed
make sure xgboost works with sparse data
1 parent a9aadfb commit c6b1bd1

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

R/convert_data.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@
124124
options = options
125125
)
126126
} else if (composition == "dgCMatrix") {
127+
y_cols <- attr(mod_terms, "response")
128+
if (length(y_cols) > 0) {
129+
data <- data[, -y_cols, drop = FALSE]
130+
}
127131
x <- sparsevctrs::coerce_to_sparse_matrix(data)
128132
res <-
129133
list(

tests/testthat/test-sparsevctrs.R

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,43 @@ test_that("sparse matrices can be passed to `predict()", {
127127
)
128128
})
129129

130+
test_that("sparse data work with xgboost engine", {
131+
skip_if_not_installed("xgboost")
132+
133+
spec <- boost_tree() %>%
134+
set_mode("regression") %>%
135+
set_engine("xgboost")
136+
137+
hotel_data <- sparse_hotel_rates()
138+
139+
expect_no_error(
140+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
141+
)
142+
143+
expect_no_error(
144+
predict(tree_fit, hotel_data)
145+
)
146+
147+
hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data)
148+
149+
150+
expect_no_error(
151+
tree_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
152+
)
153+
154+
expect_no_error(
155+
predict(tree_fit, hotel_data)
156+
)
157+
158+
expect_no_error(
159+
tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
160+
)
161+
162+
expect_no_error(
163+
predict(tree_fit, hotel_data)
164+
)
165+
})
166+
130167
test_that("to_sparse_data_frame() is used correctly", {
131168
skip_if_not_installed("xgboost")
132169

0 commit comments

Comments
 (0)