Skip to content

Commit 0172657

Browse files
authored
Merge pull request #1106 from ethanhs/complexmatmul
Complex dot()
2 parents 1c685ef + 84cc038 commit 0172657

File tree

3 files changed

+135
-21
lines changed

3 files changed

+135
-21
lines changed

src/linalg/impl_linalg.rs

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,32 @@
77
// except according to those terms.
88

99
use crate::imp_prelude::*;
10-
use crate::numeric_util;
10+
1111
#[cfg(feature = "blas")]
1212
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
13+
use crate::numeric_util;
1314

1415
use crate::{LinalgScalar, Zip};
1516

1617
use std::any::TypeId;
1718
use std::mem::MaybeUninit;
1819
use alloc::vec::Vec;
1920

21+
#[cfg(feature = "blas")]
22+
use libc::c_int;
2023
#[cfg(feature = "blas")]
2124
use std::cmp;
2225
#[cfg(feature = "blas")]
2326
use std::mem::swap;
24-
#[cfg(feature = "blas")]
25-
use libc::c_int;
2627

2728
#[cfg(feature = "blas")]
2829
use cblas_sys as blas_sys;
2930
#[cfg(feature = "blas")]
3031
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
3132

33+
#[cfg(feature = "blas")]
34+
use num_complex::{Complex32 as c32, Complex64 as c64};
35+
3236
/// len of vector before we use blas
3337
#[cfg(feature = "blas")]
3438
const DOT_BLAS_CUTOFF: usize = 32;
@@ -377,7 +381,12 @@ fn mat_mul_impl<A>(
377381
// size cutoff for using BLAS
378382
let cut = GEMM_BLAS_CUTOFF;
379383
let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
380-
if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
384+
if !(m > cut || n > cut || a > cut)
385+
|| !(same_type::<A, f32>()
386+
|| same_type::<A, f64>()
387+
|| same_type::<A, c32>()
388+
|| same_type::<A, c64>())
389+
{
381390
return mat_mul_general(alpha, lhs, rhs, beta, c);
382391
}
383392
{
@@ -407,8 +416,23 @@ fn mat_mul_impl<A>(
407416
rhs_trans = CblasTrans;
408417
}
409418

419+
macro_rules! gemm_scalar_cast {
420+
(f32, $var:ident) => {
421+
cast_as(&$var)
422+
};
423+
(f64, $var:ident) => {
424+
cast_as(&$var)
425+
};
426+
(c32, $var:ident) => {
427+
&$var as *const A as *const _
428+
};
429+
(c64, $var:ident) => {
430+
&$var as *const A as *const _
431+
};
432+
}
433+
410434
macro_rules! gemm {
411-
($ty:ty, $gemm:ident) => {
435+
($ty:tt, $gemm:ident) => {
412436
if blas_row_major_2d::<$ty, _>(&lhs_)
413437
&& blas_row_major_2d::<$ty, _>(&rhs_)
414438
&& blas_row_major_2d::<$ty, _>(&c_)
@@ -428,25 +452,25 @@ fn mat_mul_impl<A>(
428452
let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
429453
let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
430454
let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
431-
455+
432456
// gemm is C ← αA^Op B^Op + βC
433457
// Where Op is notrans/trans/conjtrans
434458
unsafe {
435459
blas_sys::$gemm(
436460
CblasRowMajor,
437461
lhs_trans,
438462
rhs_trans,
439-
m as blas_index, // m, rows of Op(a)
440-
n as blas_index, // n, cols of Op(b)
441-
k as blas_index, // k, cols of Op(a)
442-
cast_as(&alpha), // alpha
443-
lhs_.ptr.as_ptr() as *const _, // a
444-
lhs_stride, // lda
445-
rhs_.ptr.as_ptr() as *const _, // b
446-
rhs_stride, // ldb
447-
cast_as(&beta), // beta
448-
c_.ptr.as_ptr() as *mut _, // c
449-
c_stride, // ldc
463+
m as blas_index, // m, rows of Op(a)
464+
n as blas_index, // n, cols of Op(b)
465+
k as blas_index, // k, cols of Op(a)
466+
gemm_scalar_cast!($ty, alpha), // alpha
467+
lhs_.ptr.as_ptr() as *const _, // a
468+
lhs_stride, // lda
469+
rhs_.ptr.as_ptr() as *const _, // b
470+
rhs_stride, // ldb
471+
gemm_scalar_cast!($ty, beta), // beta
472+
c_.ptr.as_ptr() as *mut _, // c
473+
c_stride, // ldc
450474
);
451475
}
452476
return;
@@ -455,6 +479,9 @@ fn mat_mul_impl<A>(
455479
}
456480
gemm!(f32, cblas_sgemm);
457481
gemm!(f64, cblas_dgemm);
482+
483+
gemm!(c32, cblas_cgemm);
484+
gemm!(c64, cblas_zgemm);
458485
}
459486
mat_mul_general(alpha, lhs, rhs, beta, c)
460487
}
@@ -603,9 +630,7 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
603630
S3: DataMut<Elem = A>,
604631
A: LinalgScalar,
605632
{
606-
unsafe {
607-
general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
608-
}
633+
unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
609634
}
610635

611636
/// General matrix-vector multiplication

xtest-blas/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ test = false
1111
approx = "0.4"
1212
defmac = "0.2"
1313
num-traits = "0.2"
14+
num-complex = { version = "0.4", default-features = false }
1415

1516
[dependencies]
1617
ndarray = { path = "../", features = ["approx", "blas"] }

xtest-blas/tests/oper.rs

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
extern crate approx;
2+
extern crate blas_src;
23
extern crate defmac;
34
extern crate ndarray;
5+
extern crate num_complex;
46
extern crate num_traits;
5-
extern crate blas_src;
67

78
use ndarray::prelude::*;
89

@@ -12,6 +13,8 @@ use ndarray::{Data, Ix, LinalgScalar};
1213

1314
use approx::assert_relative_eq;
1415
use defmac::defmac;
16+
use num_complex::Complex32;
17+
use num_complex::Complex64;
1518

1619
#[test]
1720
fn mat_vec_product_1d() {
@@ -52,6 +55,20 @@ fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
5255
.unwrap()
5356
}
5457

58+
fn range_mat_complex(m: Ix, n: Ix) -> Array2<Complex32> {
59+
Array::linspace(0., (m * n) as f32 - 1., m * n)
60+
.into_shape((m, n))
61+
.unwrap()
62+
.map(|&f| Complex32::new(f, 0.))
63+
}
64+
65+
fn range_mat_complex64(m: Ix, n: Ix) -> Array2<Complex64> {
66+
Array::linspace(0., (m * n) as f64 - 1., m * n)
67+
.into_shape((m, n))
68+
.unwrap()
69+
.map(|&f| Complex64::new(f, 0.))
70+
}
71+
5572
fn range1_mat64(m: Ix) -> Array1<f64> {
5673
Array::linspace(0., m as f64 - 1., m)
5774
}
@@ -250,6 +267,77 @@ fn gemm_64_1_f() {
250267
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
251268
}
252269

270+
#[test]
271+
fn gemm_c64_1_f() {
272+
let a = range_mat_complex64(64, 64).reversed_axes();
273+
let (m, n) = a.dim();
274+
// m x n times n x 1 == m x 1
275+
let x = range_mat_complex64(n, 1);
276+
let mut y = range_mat_complex64(m, 1);
277+
let answer = reference_mat_mul(&a, &x) + &y;
278+
general_mat_mul(
279+
Complex64::new(1.0, 0.),
280+
&a,
281+
&x,
282+
Complex64::new(1.0, 0.),
283+
&mut y,
284+
);
285+
assert_relative_eq!(
286+
y.mapv(|i| i.norm_sqr()),
287+
answer.mapv(|i| i.norm_sqr()),
288+
epsilon = 1e-12,
289+
max_relative = 1e-7
290+
);
291+
}
292+
293+
#[test]
294+
fn gemm_c32_1_f() {
295+
let a = range_mat_complex(64, 64).reversed_axes();
296+
let (m, n) = a.dim();
297+
// m x n times n x 1 == m x 1
298+
let x = range_mat_complex(n, 1);
299+
let mut y = range_mat_complex(m, 1);
300+
let answer = reference_mat_mul(&a, &x) + &y;
301+
general_mat_mul(
302+
Complex32::new(1.0, 0.),
303+
&a,
304+
&x,
305+
Complex32::new(1.0, 0.),
306+
&mut y,
307+
);
308+
assert_relative_eq!(
309+
y.mapv(|i| i.norm_sqr()),
310+
answer.mapv(|i| i.norm_sqr()),
311+
epsilon = 1e-12,
312+
max_relative = 1e-7
313+
);
314+
}
315+
316+
#[test]
317+
fn gemm_c64_actually_complex() {
318+
let mut a = range_mat_complex64(4,4);
319+
a = a.map(|&i| if i.re > 8. { i.conj() } else { i });
320+
let mut b = range_mat_complex64(4,6);
321+
b = b.map(|&i| if i.re > 4. { i.conj() } else {i});
322+
let mut y = range_mat_complex64(4,6);
323+
let alpha = Complex64::new(0., 1.0);
324+
let beta = Complex64::new(1.0, 1.0);
325+
let answer = alpha * reference_mat_mul(&a, &b) + beta * &y;
326+
general_mat_mul(
327+
alpha.clone(),
328+
&a,
329+
&b,
330+
beta.clone(),
331+
&mut y,
332+
);
333+
assert_relative_eq!(
334+
y.mapv(|i| i.norm_sqr()),
335+
answer.mapv(|i| i.norm_sqr()),
336+
epsilon = 1e-12,
337+
max_relative = 1e-7
338+
);
339+
}
340+
253341
#[test]
254342
fn gen_mat_vec_mul() {
255343
let alpha = -2.3;

0 commit comments

Comments
 (0)