Skip to content

Commit 5e32de2

Browse files
authored
Merge pull request #585 from sebasv/blas-gemv-error
Fix blas mat-vec multiplication on array with only 1 nontrivial dimension
2 parents 35065cb + a5fe624 commit 5e32de2

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

blas-tests/tests/oper.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ fn dot_product() {
6363
assert_eq!(a.dot(&b), dot as i32);
6464
}
6565

66+
#[test]
67+
fn mat_vec_product_1d() {
68+
let a = arr2(&[[1.], [2.]]);
69+
let b = arr1(&[1., 2.]);
70+
let ans = arr1(&[5.]);
71+
assert_eq!(a.t().dot(&b), ans);
72+
}
73+
6674
// test that we can dot product with a broadcast array
6775
#[test]
6876
fn dot_product_0() {

src/linalg/impl_linalg.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,12 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
611611
if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
612612
let a_trans = CblasNoTrans;
613613
let a_stride = match layout {
614-
CBLAS_LAYOUT::CblasRowMajor => a.strides()[0] as blas_index,
615-
CBLAS_LAYOUT::CblasColMajor => a.strides()[1] as blas_index,
614+
CBLAS_LAYOUT::CblasRowMajor => {
615+
a.strides()[0].max(k as isize) as blas_index
616+
}
617+
CBLAS_LAYOUT::CblasColMajor => {
618+
a.strides()[1].max(m as isize) as blas_index
619+
}
616620
};
617621

618622
let x_stride = x.strides()[0] as blas_index;

0 commit comments

Comments
 (0)