@@ -1,3 +1,12 @@
+use crate::{
+    api::Predictor,
+    error::{Failed, FailedError},
+    linalg::Matrix,
+    math::num::RealNumber,
+};
+
+use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
+
 /// grid search results.
 #[derive(Clone, Debug)]
 pub struct GridSearchResult<T: RealNumber, I: Clone> {
@@ -60,58 +69,61 @@ where
|
60 | 69 |
|
61 | 70 | #[cfg(test)]
|
62 | 71 | mod tests {
|
63 |
| - use crate::linear::logistic_regression::{ |
64 |
| - LogisticRegression, LogisticRegressionSearchParameters, |
65 |
| -}; |
| 72 | + use crate::{ |
| 73 | + linalg::naive::dense_matrix::DenseMatrix, |
| 74 | + linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters}, |
| 75 | + metrics::accuracy, |
| 76 | + model_selection::{hyper_tuning::grid_search, KFold}, |
| 77 | + }; |
66 | 78 |
|
67 |
| - #[test] |
68 |
| - fn test_grid_search() { |
69 |
| - let x = DenseMatrix::from_2d_array(&[ |
70 |
| - &[5.1, 3.5, 1.4, 0.2], |
71 |
| - &[4.9, 3.0, 1.4, 0.2], |
72 |
| - &[4.7, 3.2, 1.3, 0.2], |
73 |
| - &[4.6, 3.1, 1.5, 0.2], |
74 |
| - &[5.0, 3.6, 1.4, 0.2], |
75 |
| - &[5.4, 3.9, 1.7, 0.4], |
76 |
| - &[4.6, 3.4, 1.4, 0.3], |
77 |
| - &[5.0, 3.4, 1.5, 0.2], |
78 |
| - &[4.4, 2.9, 1.4, 0.2], |
79 |
| - &[4.9, 3.1, 1.5, 0.1], |
80 |
| - &[7.0, 3.2, 4.7, 1.4], |
81 |
| - &[6.4, 3.2, 4.5, 1.5], |
82 |
| - &[6.9, 3.1, 4.9, 1.5], |
83 |
| - &[5.5, 2.3, 4.0, 1.3], |
84 |
| - &[6.5, 2.8, 4.6, 1.5], |
85 |
| - &[5.7, 2.8, 4.5, 1.3], |
86 |
| - &[6.3, 3.3, 4.7, 1.6], |
87 |
| - &[4.9, 2.4, 3.3, 1.0], |
88 |
| - &[6.6, 2.9, 4.6, 1.3], |
89 |
| - &[5.2, 2.7, 3.9, 1.4], |
90 |
| - ]); |
91 |
| - let y = vec![ |
92 |
| - 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., |
93 |
| - ]; |
| 79 | + #[test] |
| 80 | + fn test_grid_search() { |
| 81 | + let x = DenseMatrix::from_2d_array(&[ |
| 82 | + &[5.1, 3.5, 1.4, 0.2], |
| 83 | + &[4.9, 3.0, 1.4, 0.2], |
| 84 | + &[4.7, 3.2, 1.3, 0.2], |
| 85 | + &[4.6, 3.1, 1.5, 0.2], |
| 86 | + &[5.0, 3.6, 1.4, 0.2], |
| 87 | + &[5.4, 3.9, 1.7, 0.4], |
| 88 | + &[4.6, 3.4, 1.4, 0.3], |
| 89 | + &[5.0, 3.4, 1.5, 0.2], |
| 90 | + &[4.4, 2.9, 1.4, 0.2], |
| 91 | + &[4.9, 3.1, 1.5, 0.1], |
| 92 | + &[7.0, 3.2, 4.7, 1.4], |
| 93 | + &[6.4, 3.2, 4.5, 1.5], |
| 94 | + &[6.9, 3.1, 4.9, 1.5], |
| 95 | + &[5.5, 2.3, 4.0, 1.3], |
| 96 | + &[6.5, 2.8, 4.6, 1.5], |
| 97 | + &[5.7, 2.8, 4.5, 1.3], |
| 98 | + &[6.3, 3.3, 4.7, 1.6], |
| 99 | + &[4.9, 2.4, 3.3, 1.0], |
| 100 | + &[6.6, 2.9, 4.6, 1.3], |
| 101 | + &[5.2, 2.7, 3.9, 1.4], |
| 102 | + ]); |
| 103 | + let y = vec![ |
| 104 | + 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., |
| 105 | + ]; |
94 | 106 |
|
95 |
| - let cv = KFold { |
96 |
| - n_splits: 5, |
97 |
| - ..KFold::default() |
98 |
| - }; |
| 107 | + let cv = KFold { |
| 108 | + n_splits: 5, |
| 109 | + ..KFold::default() |
| 110 | + }; |
99 | 111 |
|
100 |
| - let parameters = LogisticRegressionSearchParameters { |
101 |
| - alpha: vec![0., 1.], |
102 |
| - ..Default::default() |
103 |
| - }; |
| 112 | + let parameters = LogisticRegressionSearchParameters { |
| 113 | + alpha: vec![0., 1.], |
| 114 | + ..Default::default() |
| 115 | + }; |
104 | 116 |
|
105 |
| - let results = grid_search( |
106 |
| - LogisticRegression::fit, |
107 |
| - &x, |
108 |
| - &y, |
109 |
| - parameters.into_iter(), |
110 |
| - cv, |
111 |
| - &accuracy, |
112 |
| - ) |
113 |
| - .unwrap(); |
| 117 | + let results = grid_search( |
| 118 | + LogisticRegression::fit, |
| 119 | + &x, |
| 120 | + &y, |
| 121 | + parameters.into_iter(), |
| 122 | + cv, |
| 123 | + &accuracy, |
| 124 | + ) |
| 125 | + .unwrap(); |
114 | 126 |
|
115 |
| - assert!([0., 1.].contains(&results.parameters.alpha)); |
116 |
| - } |
| 127 | + assert!([0., 1.].contains(&results.parameters.alpha)); |
| 128 | + } |
117 | 129 | }