@@ -28,6 +28,12 @@ lr_fit_2 <-
28
28
set_engine(" glm" ) %> %
29
29
fit(Ozone ~ . , data = class_dat2 )
30
30
31
+ lr_fit_3 <-
32
+ mlp(mode = ' classification' ) %> %
33
+ set_engine(" nnet" ) %> %
34
+ fit(Ozone ~ . , data = class_dat2 [1 : 5 , ])
35
+
36
+
31
37
# ------------------------------------------------------------------------------
32
38
33
39
test_that(' regression predictions' , {
@@ -54,8 +60,11 @@ test_that('non-standard levels', {
54
60
55
61
expect_true(is_tibble(predict(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ], type = " prob" )))
56
62
expect_true(is_tibble(parsnip ::: predict_classprob.model_fit(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ])))
63
+ final_colnames <- c(" .pred_2low" , " .pred_high+values" )
57
64
expect_equal(names(predict(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ], type = " prob" )),
58
- c(" .pred_2low" , " .pred_high+values" ))
65
+ final_colnames )
66
+ expect_equal(names(predict(lr_fit_3 , new_data = class_dat2 , type = ' prob' )),
67
+ final_colnames )
59
68
expect_equal(names(parsnip ::: predict_classprob.model_fit(lr_fit_2 , new_data = class_dat2 [1 : 5 ,- 1 ])),
60
69
c(" 2low" , " high+values" ))
61
70
})
0 commit comments