Commit 3c62686

feat: expose hyper tuning module in model_selection (#179)
* feat: expose hyper tuning module in model_selection
* Move to a folder

Co-authored-by: Luis Moreno <[email protected]>
1 parent 9c59e37 commit 3c62686

File tree

3 files changed: +65 / -49 lines changed

src/model_selection/hyper_tuning.rs renamed to src/model_selection/hyper_tuning/grid_search.rs

Lines changed: 61 additions & 49 deletions
@@ -1,3 +1,12 @@
+use crate::{
+    api::Predictor,
+    error::{Failed, FailedError},
+    linalg::Matrix,
+    math::num::RealNumber,
+};
+
+use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
+
 /// grid search results.
 #[derive(Clone, Debug)]
 pub struct GridSearchResult<T: RealNumber, I: Clone> {
@@ -60,58 +69,61 @@ where
 
 #[cfg(test)]
 mod tests {
-    use crate::linear::logistic_regression::{
-        LogisticRegression, LogisticRegressionSearchParameters,
-    };
+    use crate::{
+        linalg::naive::dense_matrix::DenseMatrix,
+        linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
+        metrics::accuracy,
+        model_selection::{hyper_tuning::grid_search, KFold},
+    };
 
-    #[test]
-    fn test_grid_search() {
-        let x = DenseMatrix::from_2d_array(&[
-            &[5.1, 3.5, 1.4, 0.2],
-            &[4.9, 3.0, 1.4, 0.2],
-            &[4.7, 3.2, 1.3, 0.2],
-            &[4.6, 3.1, 1.5, 0.2],
-            &[5.0, 3.6, 1.4, 0.2],
-            &[5.4, 3.9, 1.7, 0.4],
-            &[4.6, 3.4, 1.4, 0.3],
-            &[5.0, 3.4, 1.5, 0.2],
-            &[4.4, 2.9, 1.4, 0.2],
-            &[4.9, 3.1, 1.5, 0.1],
-            &[7.0, 3.2, 4.7, 1.4],
-            &[6.4, 3.2, 4.5, 1.5],
-            &[6.9, 3.1, 4.9, 1.5],
-            &[5.5, 2.3, 4.0, 1.3],
-            &[6.5, 2.8, 4.6, 1.5],
-            &[5.7, 2.8, 4.5, 1.3],
-            &[6.3, 3.3, 4.7, 1.6],
-            &[4.9, 2.4, 3.3, 1.0],
-            &[6.6, 2.9, 4.6, 1.3],
-            &[5.2, 2.7, 3.9, 1.4],
-        ]);
-        let y = vec![
-            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
-        ];
+    #[test]
+    fn test_grid_search() {
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4],
+        ]);
+        let y = vec![
+            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+        ];
 
-        let cv = KFold {
-            n_splits: 5,
-            ..KFold::default()
-        };
+        let cv = KFold {
+            n_splits: 5,
+            ..KFold::default()
+        };
 
-        let parameters = LogisticRegressionSearchParameters {
-            alpha: vec![0., 1.],
-            ..Default::default()
-        };
+        let parameters = LogisticRegressionSearchParameters {
+            alpha: vec![0., 1.],
+            ..Default::default()
+        };
 
-        let results = grid_search(
-            LogisticRegression::fit,
-            &x,
-            &y,
-            parameters.into_iter(),
-            cv,
-            &accuracy,
-        )
-        .unwrap();
+        let results = grid_search(
+            LogisticRegression::fit,
+            &x,
+            &y,
+            parameters.into_iter(),
+            cv,
+            &accuracy,
+        )
+        .unwrap();
 
-        assert!([0., 1.].contains(&results.parameters.alpha));
-    }
+        assert!([0., 1.].contains(&results.parameters.alpha));
+    }
 }
src/model_selection/hyper_tuning/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+mod grid_search;
+pub use grid_search::{grid_search, GridSearchResult};
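The new mod.rs is a thin facade: the grid_search implementation module stays private while its public items are re-exported one level up. Below is a standalone sketch of the same pattern, with illustrative names and `self::` spelled out so it compiles on any edition; it is not code from this commit.

// Facade sketch: keep the implementation module private, re-export its public API.
mod hyper_tuning {
    mod grid_search {
        // Stand-in for the real grid_search function; illustrative only.
        pub fn grid_search() -> &'static str {
            "running grid search"
        }
    }

    // The real file writes `pub use grid_search::{grid_search, GridSearchResult};`.
    pub use self::grid_search::grid_search;
}

fn main() {
    // Callers never name the inner `grid_search` module directly.
    println!("{}", hyper_tuning::grid_search());
}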

src/model_selection/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
 use crate::rand::get_rng_impl;
 use rand::seq::SliceRandom;
 
+pub(crate) mod hyper_tuning;
 pub(crate) mod kfold;
 
+pub use hyper_tuning::{grid_search, GridSearchResult};
 pub use kfold::{KFold, KFoldIter};
 
 /// An interface for the K-Folds cross-validator
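With `pub use hyper_tuning::{grid_search, GridSearchResult};` in place, downstream code can reach grid search through `model_selection` directly. The sketch below mirrors the in-repo test above and assumes the crate is consumed under the package name `smartcore` at this commit; only the `use smartcore::model_selection::{grid_search, KFold};` line depends on the change.

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionSearchParameters,
};
use smartcore::metrics::accuracy;
// The shorter path exposed by the `pub use` added in model_selection/mod.rs.
use smartcore::model_selection::{grid_search, KFold};

fn main() {
    // Same toy iris-like data as the in-repo test.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    let y = vec![
        0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
    ];

    // 5-fold cross-validation over the two candidate alpha values.
    let cv = KFold {
        n_splits: 5,
        ..KFold::default()
    };
    let parameters = LogisticRegressionSearchParameters {
        alpha: vec![0., 1.],
        ..Default::default()
    };

    let results = grid_search(
        LogisticRegression::fit,
        &x,
        &y,
        parameters.into_iter(),
        cv,
        &accuracy,
    )
    .unwrap();

    println!("best alpha: {}", results.parameters.alpha);
}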
