File tree Expand file tree Collapse file tree 3 files changed +12
-1
lines changed Expand file tree Collapse file tree 3 files changed +12
-1
lines changed Original file line number Diff line number Diff line change 8
8
9
9
* The model ` udpate() ` methods gained a ` parameters ` argument for cases when the parameters are contained in a tibble or list.
10
10
11
+ # [ A bug] ( https://github.com/tidymodels/parsnip/issues/174 ) was fixed standardizing the column names of ` nnet ` class probability predictions.
12
+
13
+
11
14
# parsnip 0.0.3.1
12
15
13
16
Test case update due to CRAN running extra tests [ (#202 )] ( https://github.com/tidymodels/parsnip/issues/202 )
Original file line number Diff line number Diff line change @@ -389,7 +389,7 @@ nnet_softmax <- function(results, object) {
389
389
390
390
results <- apply(results , 1 , function (x ) exp(x )/ sum(exp(x )))
391
391
results <- t(results )
392
- names (results ) <- paste0( " .pred_ " , object $ lvl )
392
+ colnames (results ) <- object $ lvl
393
393
results <- as_tibble(results )
394
394
results
395
395
}
Original file line number Diff line number Diff line change @@ -170,3 +170,11 @@ test_that('bad input', {
170
170
expect_error(translate(mlp(mode = " regression" , formula = y ~ x ) %> % set_engine()))
171
171
})
172
172
173
+ test_that(" nnet_softmax" , {
174
+ obj <- mlp(mode = ' classification' )
175
+ obj $ lvls <- c(" a" , " b" )
176
+ res <- nnet_softmax(matrix (c(.8 , .2 )), obj )
177
+ expect_equal(names(res ), obj $ lvls )
178
+ expect_equal(res $ b , 1 - res $ a )
179
+ })
180
+
You can’t perform that action at this time.
0 commit comments