Skip to content

Commit 84cc038

Browse files
emmatypingbluss
authored andcommitted
Add complex matmul test
1 parent 7bf4b62 commit 84cc038

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

src/linalg/impl_linalg.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ fn mat_mul_impl<A>(
416416
rhs_trans = CblasTrans;
417417
}
418418

419-
macro_rules! cast_ty {
419+
macro_rules! gemm_scalar_cast {
420420
(f32, $var:ident) => {
421421
cast_as(&$var)
422422
};
@@ -460,17 +460,17 @@ fn mat_mul_impl<A>(
460460
CblasRowMajor,
461461
lhs_trans,
462462
rhs_trans,
463-
m as blas_index, // m, rows of Op(a)
464-
n as blas_index, // n, cols of Op(b)
465-
k as blas_index, // k, cols of Op(a)
466-
cast_ty!($ty, alpha), // alpha
467-
lhs_.ptr.as_ptr() as *const _, // a
468-
lhs_stride, // lda
469-
rhs_.ptr.as_ptr() as *const _, // b
470-
rhs_stride, // ldb
471-
cast_ty!($ty, beta), // beta
472-
c_.ptr.as_ptr() as *mut _, // c
473-
c_stride, // ldc
463+
m as blas_index, // m, rows of Op(a)
464+
n as blas_index, // n, cols of Op(b)
465+
k as blas_index, // k, cols of Op(a)
466+
gemm_scalar_cast!($ty, alpha), // alpha
467+
lhs_.ptr.as_ptr() as *const _, // a
468+
lhs_stride, // lda
469+
rhs_.ptr.as_ptr() as *const _, // b
470+
rhs_stride, // ldb
471+
gemm_scalar_cast!($ty, beta), // beta
472+
c_.ptr.as_ptr() as *mut _, // c
473+
c_stride, // ldc
474474
);
475475
}
476476
return;

xtest-blas/tests/oper.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,31 @@ fn gemm_c32_1_f() {
313313
);
314314
}
315315

316+
#[test]
317+
fn gemm_c64_actually_complex() {
318+
let mut a = range_mat_complex64(4,4);
319+
a = a.map(|&i| if i.re > 8. { i.conj() } else { i });
320+
let mut b = range_mat_complex64(4,6);
321+
b = b.map(|&i| if i.re > 4. { i.conj() } else {i});
322+
let mut y = range_mat_complex64(4,6);
323+
let alpha = Complex64::new(0., 1.0);
324+
let beta = Complex64::new(1.0, 1.0);
325+
let answer = alpha * reference_mat_mul(&a, &b) + beta * &y;
326+
general_mat_mul(
327+
alpha.clone(),
328+
&a,
329+
&b,
330+
beta.clone(),
331+
&mut y,
332+
);
333+
assert_relative_eq!(
334+
y.mapv(|i| i.norm_sqr()),
335+
answer.mapv(|i| i.norm_sqr()),
336+
epsilon = 1e-12,
337+
max_relative = 1e-7
338+
);
339+
}
340+
316341
#[test]
317342
fn gen_mat_vec_mul() {
318343
let alpha = -2.3;

0 commit comments

Comments
 (0)