`automl` automates model selection and training on top of the `smartcore` machine learning library, helping Rust developers quickly build regression, classification, and clustering models.
Install from crates.io:

```toml
# Cargo.toml
[dependencies]
automl = "0.2.9"
```

Or use the GitHub repository for the latest changes:

```toml
# Cargo.toml
[dependencies]
automl = { git = "https://github.com/cmccomb/rust-automl" }
```
Build a regression model from an in-memory matrix:

```rust
use automl::{RegressionModel, RegressionSettings};
use smartcore::linalg::basic::matrix::DenseMatrix;

// Three samples with three features each.
let x = DenseMatrix::from_2d_vec(&vec![
    vec![1.0_f64, 2.0, 3.0],
    vec![2.0, 3.0, 4.0],
    vec![3.0, 4.0, 5.0],
]).unwrap();
// One target value per sample.
let y = vec![1.0_f64, 2.0, 3.0];
let _model = RegressionModel::new(x, y, RegressionSettings::default());
```
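This README does not say which metric automl reports when it compares regression candidates. As a general illustration of the kind of score such comparisons rely on, the std-only sketch below computes the coefficient of determination, R² = 1 - SS_res / SS_tot. `r_squared` is a hypothetical helper, not part of the automl API.

```rust
/// Coefficient of determination R² = 1 - SS_res / SS_tot.
/// Hypothetical helper for illustration; not part of the automl API.
/// Returns None for empty or mismatched inputs, or a constant `y_true`.
fn r_squared(y_true: &[f64], y_pred: &[f64]) -> Option<f64> {
    if y_true.is_empty() || y_true.len() != y_pred.len() {
        return None;
    }
    let n = y_true.len() as f64;
    let mean = y_true.iter().sum::<f64>() / n;
    // Residual sum of squares: squared prediction errors.
    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum();
    // Total sum of squares: squared deviations from the mean target.
    let ss_tot: f64 = y_true.iter().map(|t| (t - mean).powi(2)).sum();
    if ss_tot == 0.0 {
        return None;
    }
    Some(1.0 - ss_res / ss_tot)
}
```

With the targets `y = [1.0, 2.0, 3.0]` from the example, perfect predictions score `Some(1.0)`, while always predicting the mean (2.0) scores `Some(0.0)`.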
Use `load_labeled_csv` to read a dataset and separate the target column:

```rust
use automl::{RegressionModel, RegressionSettings};
use automl::utils::load_labeled_csv;

let (x, y) = load_labeled_csv("tests/fixtures/supervised_sample.csv", 2).unwrap();
let mut model = RegressionModel::new(x, y, RegressionSettings::default());
```
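Conceptually, separating a target column means parsing each row and moving one column into a label vector. The std-only sketch below illustrates that idea under stated assumptions: comma-separated numeric fields, an optional single header line, and a zero-based target index. `split_target` is hypothetical and does not describe how `load_labeled_csv` is implemented.

```rust
/// Hypothetical helper: parse numeric CSV text and move column `target`
/// into a separate label vector. Not the automl implementation.
fn split_target(
    csv: &str,
    target: usize,
    has_header: bool,
) -> Result<(Vec<Vec<f64>>, Vec<f64>), String> {
    let mut features = Vec::new();
    let mut labels = Vec::new();
    let skip = if has_header { 1 } else { 0 };
    for (line_no, line) in csv.lines().enumerate().skip(skip) {
        if line.trim().is_empty() {
            continue; // ignore blank lines
        }
        // Parse every field of the row as f64.
        let mut row = Vec::new();
        for field in line.split(',') {
            let value: f64 = field
                .trim()
                .parse()
                .map_err(|e| format!("line {}: {}", line_no + 1, e))?;
            row.push(value);
        }
        if target >= row.len() {
            return Err(format!("line {}: no column {}", line_no + 1, target));
        }
        // Remove the target value; the remaining columns are the features.
        labels.push(row.remove(target));
        features.push(row);
    }
    Ok((features, labels))
}
```

For instance, `split_target("a,b,c\n1,2,3\n4,5,6", 2, true)` yields features `[[1.0, 2.0], [4.0, 5.0]]` and targets `[3.0, 6.0]`.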
Use `load_csv_features` to read unlabeled data for clustering:

```rust
use automl::ClusteringModel;
use automl::settings::ClusteringSettings;
use automl::utils::load_csv_features;

// Read unlabeled feature rows; clustering needs no target column.
let x = load_csv_features("tests/fixtures/clustering_points.csv").unwrap();
// Request two clusters, then fit the model.
let mut model = ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2));
model.train();
// Assign a cluster label to each row.
let clusters: Vec<u8> = model.predict(&x).unwrap();
```
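`with_k(2)` requests two clusters. K-Means is one of the supported clustering algorithms; to show what k controls, here is a std-only sketch of Lloyd's k-means with a fixed iteration count, seeded with the first k points for simplicity (practical implementations usually use better seeding, such as k-means++). `kmeans` and `dist2` are hypothetical and say nothing about how automl or smartcore implement clustering.

```rust
/// Squared Euclidean distance between two points of equal dimension.
fn dist2(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Hypothetical sketch of Lloyd's k-means: returns a cluster index per point.
/// Seeds centroids with the first `k` points; assumes 1 <= k <= points.len().
fn kmeans(points: &[Vec<f64>], k: usize, iterations: usize) -> Vec<usize> {
    let mut centroids: Vec<Vec<f64>> = points[..k].to_vec();
    let mut assignment = vec![0; points.len()];
    for _ in 0..iterations {
        // Assignment step: attach each point to its nearest centroid.
        for (i, p) in points.iter().enumerate() {
            let mut best = 0;
            for c in 1..k {
                if dist2(p, &centroids[c]) < dist2(p, &centroids[best]) {
                    best = c;
                }
            }
            assignment[i] = best;
        }
        // Update step: move each centroid to the mean of its members.
        for c in 0..k {
            let members: Vec<&Vec<f64>> = points
                .iter()
                .zip(&assignment)
                .filter(|&(_, &a)| a == c)
                .map(|(p, _)| p)
                .collect();
            if members.is_empty() {
                continue; // leave an empty cluster's centroid in place
            }
            for d in 0..centroids[c].len() {
                centroids[c][d] =
                    members.iter().map(|m| m[d]).sum::<f64>() / members.len() as f64;
            }
        }
    }
    assignment
}
```

Seeding with the first k points is a weak choice: when those points belong to the same group, the first pass can misassign points, and later iterations have to correct it.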
Customize classification settings, for example the number of trees in the random forest:

```rust
use automl::ClassificationModel;
use automl::settings::{ClassificationSettings, RandomForestClassifierParameters};
use smartcore::linalg::basic::matrix::DenseMatrix;

let x = DenseMatrix::from_2d_vec(&vec![
    vec![0.0_f64, 0.0],
    vec![1.0, 1.0],
    vec![1.0, 0.0],
    vec![0.0, 1.0],
]).unwrap();
let y = vec![0_u32, 1, 1, 0];
// Keep the default settings, but grow 10 trees in the random forest.
let settings = ClassificationSettings::default()
    .with_random_forest_classifier_settings(
        RandomForestClassifierParameters::default().with_n_trees(10),
    );
let _model = ClassificationModel::new(x, y, settings);
```
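`with_n_trees(10)` sets how many trees the random forest grows. Conceptually, a random forest classifier combines the predictions of its trees, classically by majority vote. The std-only sketch below shows only that aggregation step, with each tree's prediction given as a plain label; `majority_vote` is hypothetical, and breaking ties toward the smallest label is an arbitrary choice made here.

```rust
use std::collections::BTreeMap;

/// Hypothetical sketch: combine per-tree class predictions by majority vote.
/// Ties resolve to the smallest label: BTreeMap iterates in key order and
/// only a strictly larger count replaces the current winner.
fn majority_vote(tree_predictions: &[u32]) -> Option<u32> {
    let mut counts: BTreeMap<u32, usize> = BTreeMap::new();
    for &label in tree_predictions {
        *counts.entry(label).or_insert(0) += 1;
    }
    let mut best: Option<(u32, usize)> = None;
    for (&label, &count) in &counts {
        match best {
            Some((_, best_count)) if count <= best_count => {}
            _ => best = Some((label, count)),
        }
    }
    best.map(|(label, _)| label)
}
```

For example, votes `[1, 0, 1, 1, 0]` from five trees yield `Some(1)`.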
Train a clustering model, evaluate it against known labels, and print the results:

```rust
use automl::ClusteringModel;
use automl::settings::ClusteringSettings;
use smartcore::linalg::basic::matrix::DenseMatrix;

// Two well-separated groups of two points each.
let x = DenseMatrix::from_2d_vec(&vec![
    vec![1.0_f64, 1.0],
    vec![1.2, 0.8],
    vec![8.0, 8.0],
    vec![8.2, 8.2],
]).unwrap();
let mut model = ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2));
model.train();
// Compare the fitted clusters with the known group labels.
let truth = vec![1_u8, 1, 2, 2];
model.evaluate(&truth);
println!("{model}");
let _clusters: Vec<u8> = model.predict(&x).expect("prediction");
```
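Cluster numbers are arbitrary, so comparing predictions with `truth` calls for a measure that ignores how clusters are numbered. This README does not say which metric `evaluate` reports; as one illustrative option, here is a std-only sketch of the Rand index, the fraction of point pairs on which two labelings agree about "same cluster" versus "different cluster". `rand_index` is hypothetical.

```rust
/// Hypothetical sketch: Rand index between two labelings of the same points.
/// Counts pairs (i, j) on which both labelings agree about whether the two
/// points share a cluster, divided by the total number of pairs.
fn rand_index(a: &[u8], b: &[u8]) -> Option<f64> {
    if a.len() != b.len() || a.len() < 2 {
        return None;
    }
    let n = a.len();
    let mut agree = 0usize;
    let mut total = 0usize;
    for i in 0..n {
        for j in (i + 1)..n {
            let same_a = a[i] == a[j];
            let same_b = b[i] == b[j];
            if same_a == same_b {
                agree += 1;
            }
            total += 1;
        }
    }
    Some(agree as f64 / total as f64)
}
```

Relabeling does not matter: `rand_index(&[1, 1, 2, 2], &[2, 2, 1, 1])` is `Some(1.0)`.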
Additional runnable examples are available in the examples/ directory, including minimal_classification.rs, maximal_classification.rs, minimal_regression.rs, maximal_regression.rs, minimal_clustering.rs, and maximal_clustering.rs.
Example model comparison output:

```text
┌──────────────────────────┬───────────────────┬───────────────────┬──────────────────┐
│ Model                    │ Time              │ Training Accuracy │ Testing Accuracy │
╞══════════════════════════╪═══════════════════╪═══════════════════╪══════════════════╡
│ Random Forest Classifier │ 835ms 393us 583ns │ 1.00              │ 0.96             │
├──────────────────────────┼───────────────────┼───────────────────┼──────────────────┤
│ Decision Tree Classifier │ 15ms 404us 750ns  │ 1.00              │ 0.93             │
├──────────────────────────┼───────────────────┼───────────────────┼──────────────────┤
│ KNN Classifier           │ 28ms 874us 208ns  │ 0.96              │ 0.92             │
└──────────────────────────┴───────────────────┴───────────────────┴──────────────────┘
```
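The comparison reports both training and testing accuracy. Testing accuracy, measured on held-out data, is the better guide to how a model generalizes; here every model scores at least 0.96 in training. The std-only sketch below picks the candidate with the highest testing accuracy from rows like these and reports the train/test gap. The `Candidate` struct and both functions are hypothetical, not automl types.

```rust
/// Hypothetical row of a model-comparison table.
#[derive(Debug, Clone, PartialEq)]
struct Candidate {
    name: &'static str,
    training_accuracy: f64,
    testing_accuracy: f64,
}

/// Return the candidate with the highest testing accuracy.
/// NaN scores are skipped; ties keep the earlier candidate.
fn best_by_testing_accuracy(candidates: &[Candidate]) -> Option<&Candidate> {
    let mut best: Option<&Candidate> = None;
    for c in candidates.iter().filter(|c| !c.testing_accuracy.is_nan()) {
        match best {
            Some(b) if c.testing_accuracy <= b.testing_accuracy => {}
            _ => best = Some(c),
        }
    }
    best
}

/// Training minus testing accuracy; a large gap hints at overfitting.
fn generalization_gap(c: &Candidate) -> f64 {
    c.training_accuracy - c.testing_accuracy
}
```

Fed the three rows shown above, the sketch selects the Random Forest Classifier (0.96 testing accuracy), and the Decision Tree Classifier's gap comes out at about 0.07.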
Supported features:

- Feature Engineering: PCA, SVD, interaction terms, polynomial terms
- Regression: Decision Tree, KNN, Random Forest, Linear, Ridge, LASSO, Elastic Net, Support Vector Regression
- Classification: Random Forest, Decision Tree, KNN, Logistic Regression, Gaussian Naive Bayes
- Clustering: K-Means, Agglomerative, DBSCAN
- Meta-learning: Blending (experimental)
- Persistence: Save/load settings and models
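To make the interaction and polynomial terms in the feature-engineering list concrete, the std-only sketch below shows what a degree-2 expansion of one row produces: the original features, each feature squared, and every pairwise product. This is not automl's implementation; the function name `degree2_features` and the output column order are assumptions.

```rust
/// Hypothetical degree-2 expansion of one row:
/// original features, then squares x_i^2, then interactions x_i * x_j (i < j).
fn degree2_features(row: &[f64]) -> Vec<f64> {
    let n = row.len();
    let mut out = row.to_vec();
    // Polynomial terms: each feature squared.
    out.extend(row.iter().map(|x| x * x));
    // Interaction terms: each distinct pair of features multiplied.
    for i in 0..n {
        for j in (i + 1)..n {
            out.push(row[i] * row[j]);
        }
    }
    out
}
```

For example, `degree2_features(&[2.0, 3.0])` returns `[2.0, 3.0, 4.0, 9.0, 6.0]`. The number of interaction columns grows quadratically with the number of features.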
Before submitting changes, run:

```sh
cargo fmt --all -- --check
cargo clippy --all-targets -- -D warnings
cargo test
cargo audit
cargo test --doc
```

`cargo audit` is provided by the `cargo-audit` tool (`cargo install cargo-audit`). Security audits run weekly via a scheduled workflow, but running `cargo audit` locally before submitting changes helps catch issues earlier.
Pull requests are welcome!
Licensed under the MIT OR Apache-2.0 license.