Skip to content

Commit ee97f9e

Browse files
authored
Merge branch 'master' into Roxygen-dev
2 parents d2589aa + 3fef9c3 commit ee97f9e

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
* The model `udpate()` methods gained a `parameters` argument for cases when the parameters are contained in a tibble or list.
1010

11+
# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed standardizing the column names of `nnet` class probability predictions.
12+
13+
1114
# parsnip 0.0.3.1
1215

1316
Test case update due to CRAN running extra tests [(#202)](https://github.com/tidymodels/parsnip/issues/202)

R/mlp.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ nnet_softmax <- function(results, object) {
389389

390390
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
391391
results <- t(results)
392-
names(results) <- paste0(".pred_", object$lvl)
392+
colnames(results) <- object$lvl
393393
results <- as_tibble(results)
394394
results
395395
}

tests/testthat/test_mlp.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,11 @@ test_that('bad input', {
170170
expect_error(translate(mlp(mode = "regression", formula = y ~ x) %>% set_engine()))
171171
})
172172

173+
test_that("nnet_softmax", {
174+
obj <- mlp(mode = 'classification')
175+
obj$lvls <- c("a", "b")
176+
res <- nnet_softmax(matrix(c(.8, .2)), obj)
177+
expect_equal(names(res), obj$lvls)
178+
expect_equal(res$b, 1 - res$a)
179+
})
180+

0 commit comments

Comments
 (0)