Skip to content

Commit 0fe046d

Browse files
committed
Use matrixmultiply cgemm and improve tests
Refactor tests to use generics in a bit smarter way. We test both .dot() and general_mat_mul. Tests now use relative accuracy more explicitly (this works better with generics, instead of using approx).
1 parent 99a5cb1 commit 0fe046d

File tree

5 files changed

+226
-111
lines changed

5 files changed

+226
-111
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ approx-0_5 = { package = "approx", version = "0.5", optional = true , default-fe
4040
cblas-sys = { version = "0.1.4", optional = true, default-features = false }
4141
libc = { version = "0.2.82", optional = true }
4242

43-
matrixmultiply = { version = "0.3.0", default-features = false}
43+
matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] }
4444

4545
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
4646
rawpointer = { version = "0.2" }

benches/gemv.rs renamed to benches/gemv_gemm.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,13 @@
99
extern crate test;
1010
use test::Bencher;
1111

12+
use num_complex::Complex;
13+
use num_traits::{Float, One, Zero};
14+
1215
use ndarray::prelude::*;
1316

17+
use ndarray::LinalgScalar;
18+
use ndarray::linalg::general_mat_mul;
1419
use ndarray::linalg::general_mat_vec_mul;
1520

1621
#[bench]
@@ -45,3 +50,27 @@ fn gemv_64_32(bench: &mut Bencher) {
4550
general_mat_vec_mul(1.0, &a, &x, 1.0, &mut y);
4651
});
4752
}
53+
54+
#[bench]
55+
fn cgemm_100(bench: &mut Bencher) {
56+
cgemm_bench::<f32>(100, bench);
57+
}
58+
59+
#[bench]
60+
fn zgemm_100(bench: &mut Bencher) {
61+
cgemm_bench::<f64>(100, bench);
62+
}
63+
64+
fn cgemm_bench<A>(size: usize, bench: &mut Bencher)
65+
where
66+
A: LinalgScalar + Float,
67+
{
68+
let (m, k, n) = (size, size, size);
69+
let a = Array::<Complex<A>, _>::zeros((m, k));
70+
71+
let x = Array::zeros((k, n));
72+
let mut y = Array::zeros((m, n));
73+
bench.iter(|| {
74+
general_mat_mul(Complex::one(), &a, &x, Complex::zero(), &mut y);
75+
});
76+
}

src/linalg/impl_linalg.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ use std::any::TypeId;
1818
use std::mem::MaybeUninit;
1919
use alloc::vec::Vec;
2020

21+
use num_complex::Complex;
22+
use num_complex::{Complex32 as c32, Complex64 as c64};
23+
2124
#[cfg(feature = "blas")]
2225
use libc::c_int;
2326
#[cfg(feature = "blas")]
@@ -30,9 +33,6 @@ use cblas_sys as blas_sys;
3033
#[cfg(feature = "blas")]
3134
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
3235

33-
#[cfg(feature = "blas")]
34-
use num_complex::{Complex32 as c32, Complex64 as c64};
35-
3636
/// len of vector before we use blas
3737
#[cfg(feature = "blas")]
3838
const DOT_BLAS_CUTOFF: usize = 32;
@@ -505,7 +505,7 @@ fn mat_mul_general<A>(
505505
let (rsc, csc) = (c.strides()[0], c.strides()[1]);
506506
if same_type::<A, f32>() {
507507
unsafe {
508-
::matrixmultiply::sgemm(
508+
matrixmultiply::sgemm(
509509
m,
510510
k,
511511
n,
@@ -524,7 +524,7 @@ fn mat_mul_general<A>(
524524
}
525525
} else if same_type::<A, f64>() {
526526
unsafe {
527-
::matrixmultiply::dgemm(
527+
matrixmultiply::dgemm(
528528
m,
529529
k,
530530
n,
@@ -541,6 +541,48 @@ fn mat_mul_general<A>(
541541
csc,
542542
);
543543
}
544+
} else if same_type::<A, c32>() {
545+
unsafe {
546+
matrixmultiply::cgemm(
547+
matrixmultiply::CGemmOption::Standard,
548+
matrixmultiply::CGemmOption::Standard,
549+
m,
550+
k,
551+
n,
552+
complex_array(cast_as(&alpha)),
553+
ap as *const _,
554+
lhs.strides()[0],
555+
lhs.strides()[1],
556+
bp as *const _,
557+
rhs.strides()[0],
558+
rhs.strides()[1],
559+
complex_array(cast_as(&beta)),
560+
cp as *mut _,
561+
rsc,
562+
csc,
563+
);
564+
}
565+
} else if same_type::<A, c64>() {
566+
unsafe {
567+
matrixmultiply::zgemm(
568+
matrixmultiply::CGemmOption::Standard,
569+
matrixmultiply::CGemmOption::Standard,
570+
m,
571+
k,
572+
n,
573+
complex_array(cast_as(&alpha)),
574+
ap as *const _,
575+
lhs.strides()[0],
576+
lhs.strides()[1],
577+
bp as *const _,
578+
rhs.strides()[0],
579+
rhs.strides()[1],
580+
complex_array(cast_as(&beta)),
581+
cp as *mut _,
582+
rsc,
583+
csc,
584+
);
585+
}
544586
} else {
545587
// It's a no-op if `c` has zero length.
546588
if c.is_empty() {
@@ -768,10 +810,17 @@ fn same_type<A: 'static, B: 'static>() -> bool {
768810
//
769811
// **Panics** if `A` and `B` are not the same type
770812
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
771-
assert!(same_type::<A, B>());
813+
assert!(same_type::<A, B>(), "expect type {} and {} to match",
814+
std::any::type_name::<A>(), std::any::type_name::<B>());
772815
unsafe { ::std::ptr::read(a as *const _ as *const B) }
773816
}
774817

818+
/// Return the complex in the form of an array [re, im]
819+
#[inline]
820+
fn complex_array<A: 'static + Copy>(z: Complex<A>) -> [A; 2] {
821+
[z.re, z.im]
822+
}
823+
775824
#[cfg(feature = "blas")]
776825
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
777826
where

xtest-numeric/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ name = "numeric-tests"
33
version = "0.1.0"
44
authors = ["bluss"]
55
publish = false
6+
edition = "2018"
67

78
[dependencies]
89
approx = "0.4"
@@ -17,6 +18,10 @@ openblas-src = { optional = true, version = "0.10", default-features = false, fe
1718
version = "0.8.0"
1819
features = ["small_rng"]
1920

21+
[dev-dependencies]
22+
num-traits = { version = "0.2.14", default-features = false }
23+
num-complex = { version = "0.4", default-features = false }
24+
2025
[lib]
2126
test = false
2227

0 commit comments

Comments
 (0)