Skip to content

Commit bed7523

Browse files
committed
standardize class prediction names for nnet
1 parent e039f07 commit bed7523

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

R/mlp.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ nnet_softmax <- function(results, object) {
381381

382382
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
383383
results <- t(results)
384-
names(results) <- paste0(".pred_", object$lvl)
384+
colnames(results) <- object$lvl
385385
results <- as_tibble(results)
386386
results
387387
}

tests/testthat/test_predict_formats.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ lr_fit_2 <-
2828
set_engine("glm") %>%
2929
fit(Ozone ~ ., data = class_dat2)
3030

31+
lr_fit_3 <-
32+
mlp(mode = 'classification') %>%
33+
set_engine("nnet") %>%
34+
fit(Ozone ~ ., data = class_dat2[1:5, ])
35+
36+
3137
# ------------------------------------------------------------------------------
3238

3339
test_that('regression predictions', {
@@ -54,8 +60,11 @@ test_that('non-standard levels', {
5460

5561
expect_true(is_tibble(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")))
5662
expect_true(is_tibble(parsnip:::predict_classprob.model_fit(lr_fit_2, new_data = class_dat2[1:5,-1])))
63+
final_colnames <- c(".pred_2low", ".pred_high+values")
5764
expect_equal(names(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")),
58-
c(".pred_2low", ".pred_high+values"))
65+
final_colnames)
66+
expect_equal(names(predict(lr_fit_3, new_data = class_dat2, type = 'prob')),
67+
final_colnames)
5968
expect_equal(names(parsnip:::predict_classprob.model_fit(lr_fit_2, new_data = class_dat2[1:5,-1])),
6069
c("2low", "high+values"))
6170
})

0 commit comments

Comments
 (0)