Skip to content

Commit 7ee3c80

Browse files
Fix feature management and README.
Use references where possible.
1 parent 389e167 commit 7ee3c80

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
A collection of examples leveraging the `ndarray` ecosystem.
44

5+
Each example folder contains a description and instructions on how to run it. Do not run `cargo run` or `cargo build` from the top level folder!
6+
57
Table of contents:
68

79
- [Linear regression](https://github.com/rust-ndarray/ndarray-examples/tree/master/linear_regression)

linear_regression/Cargo.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@ authors = ["LukeMathWalker"]
55
edition = "2018"
66
workspace = ".."
77

8+
[features]
9+
default = []
10+
openblas = ["ndarray-linalg/openblas"]
11+
intel-mkl = ["ndarray-linalg/intel-mkl"]
12+
netlib = ["ndarray-linalg/netlib"]
13+
814
[dependencies]
915
ndarray = {version = "0.12", features = ["blas"]}
10-
ndarray-linalg = {version = "0.11.1", features = ["openblas"]}
16+
ndarray-linalg = {version = "0.11.1", optional = true, default-features = false}
1117
ndarray-stats = {git = "https://github.com/rust-ndarray/ndarray-stats", branch = "master"}
1218
ndarray-rand = "0.9"
1319
rand = "0.6"

linear_regression/src/lib.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl LinearRegression {
3636
/// to match the training data distribution.
3737
///
3838
/// `self` is modified in place, nothing is returned.
39-
pub fn fit<A, B>(&mut self, X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>)
39+
pub fn fit<A, B>(&mut self, X: &ArrayBase<A, Ix2>, y: &ArrayBase<B, Ix1>)
4040
where
4141
A: Data<Elem = f64>,
4242
B: Data<Elem = f64>,
@@ -50,7 +50,7 @@ impl LinearRegression {
5050
self.beta = if self.fit_intercept {
5151
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
5252
let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
53-
Some(LinearRegression::solve_normal_equation(X, y))
53+
Some(LinearRegression::solve_normal_equation(&X, y))
5454
} else {
5555
Some(LinearRegression::solve_normal_equation(X, y))
5656
};
@@ -77,13 +77,13 @@ impl LinearRegression {
7777
}
7878
}
7979

80-
fn solve_normal_equation<A, B>(X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>) -> Array1<f64>
80+
fn solve_normal_equation<A, B>(X: &ArrayBase<A, Ix2>, y: &ArrayBase<B, Ix1>) -> Array1<f64>
8181
where
8282
A: Data<Elem = f64>,
8383
B: Data<Elem = f64>,
8484
{
85-
let rhs = X.t().dot(&y);
86-
let linear_operator = X.t().dot(&X);
85+
let rhs = X.t().dot(y);
86+
let linear_operator = X.t().dot(X);
8787
linear_operator.solve_into(rhs).unwrap()
8888
}
8989

linear_regression/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub fn main() {
3333
let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);
3434

3535
let mut linear_regressor = LinearRegression::new(false);
36-
linear_regressor.fit(X_train, y_train);
36+
linear_regressor.fit(&X_train, &y_train);
3737

3838
let test_predictions = linear_regressor.predict(&X_test);
3939
let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();

0 commit comments

Comments
 (0)