Skip to content

Commit

Permalink
Merge pull request #11 from Masterchef365/master
Browse files Browse the repository at this point in the history
Handling malformed matrices gracefully
  • Loading branch information
RLado authored Feb 20, 2025
2 parents 66f361c + 4653647 commit 494a3fb
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 26 deletions.
42 changes: 32 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@
pub mod data;
use data::{Nmrc, Sprs, Symb};

#[derive(Copy, Clone, Debug)]
pub enum Error {
/// Cholesky factorization failed (not positive definite)
NotPositiveDefinite,
/// LU factorization failed (no pivot found)
NoPivot,
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoPivot => write!(f, "Could not find a pivot"),
Self::NotPositiveDefinite => write!(f, "Could not complete Cholesky factorization. Please provide a positive definite matrix"),
}
}
}

impl std::error::Error for Error {}

// --- Public functions --------------------------------------------------------

/// C = alpha * A + beta * B
Expand Down Expand Up @@ -253,7 +272,7 @@ pub fn add(a: &Sprs, b: &Sprs, alpha: f64, beta: f64) -> Sprs {
///
/// See: `schol(...)`
///
pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
pub fn chol(a: &Sprs, s: &mut Symb) -> Result<Nmrc, Error> {
let mut top;
let mut d;
let mut lki;
Expand Down Expand Up @@ -302,7 +321,7 @@ pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
// --- Compute L(k,k) -----------------------------------------------
if d <= 0. {
// not pos def
panic!("Could not complete Cholesky factorization. Please provide a positive definite matrix");
return Err(Error::NotPositiveDefinite);
}
let p = w[wc + k];
w[wc + k] += 1;
Expand All @@ -311,7 +330,7 @@ pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
}
n_mat.l.p[n] = s.cp[n]; // finalize L

n_mat
Ok(n_mat)
}

/// A\b solver using Cholesky factorization.
Expand Down Expand Up @@ -351,16 +370,18 @@ pub fn chol(a: &Sprs, s: &mut Symb) -> Nmrc {
/// println!("{:?}", &b);
/// ```
///
pub fn cholsol(a: &Sprs, b: &mut [f64], order: i8) {
pub fn cholsol(a: &Sprs, b: &mut [f64], order: i8) -> Result<(), Error> {
let n = a.n;
let mut s = schol(a, order); // ordering and symbolic analysis
let n_mat = chol(a, &mut s); // numeric Cholesky factorization
let n_mat = chol(a, &mut s)?; // numeric Cholesky factorization
let mut x = vec![0.; n];

ipvec(n, &s.pinv, b, &mut x[..]); // x = P*b
lsolve(&n_mat.l, &mut x); // x = L\x
ltsolve(&n_mat.l, &mut x); // x = L'\x
pvec(n, &s.pinv, &x[..], &mut b[..]); // b = P'*x

Ok(())
}

/// Generalized A times X Plus Y
Expand Down Expand Up @@ -490,7 +511,7 @@ pub fn ltsolve(l: &Sprs, x: &mut [f64]) {
///
/// See: `sqr(...)`
///
pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Result<Nmrc, Error> {
let n = a.n;
let mut col;
let mut top;
Expand Down Expand Up @@ -556,7 +577,7 @@ pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
}
}
if ipiv == -1 || a_f <= 0. {
panic!("Could not find a pivot");
return Err(Error::NoPivot);
}
if n_mat.pinv.as_ref().unwrap()[col] < 0 && f64::abs(x[col]) >= a_f * tol {
ipiv = col as isize;
Expand Down Expand Up @@ -592,7 +613,7 @@ pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
n_mat.l.quick_trim();
n_mat.u.quick_trim();

n_mat
Ok(n_mat)
}

/// A\b solver using LU factorization.
Expand Down Expand Up @@ -642,16 +663,17 @@ pub fn lu(a: &Sprs, s: &mut Symb, tol: f64) -> Nmrc {
/// ```
///
pub fn lusol(a: &Sprs, b: &mut [f64], order: i8, tol: f64) {
pub fn lusol(a: &Sprs, b: &mut [f64], order: i8, tol: f64) -> Result<(), Error> {
let mut x = vec![0.; a.n];
let mut s;
s = sqr(a, order, false); // ordering and symbolic analysis
let n = lu(a, &mut s, tol); // numeric LU factorization
let n = lu(a, &mut s, tol)?; // numeric LU factorization

ipvec(a.n, &n.pinv, b, &mut x[..]); // x = P*b
lsolve(&n.l, &mut x); // x = L\x
usolve(&n.u, &mut x[..]); // x = U\x
ipvec(a.n, &s.q, &x[..], &mut b[..]); // b = Q*x
Ok(())
}

/// C = A * B
Expand Down
32 changes: 16 additions & 16 deletions tests/solver_tests.rs

Large diffs are not rendered by default.

0 comments on commit 494a3fb

Please sign in to comment.