Skip to content

Commit fecec86

Browse files
committed
TEST: Make sure all kernels are tested
These tests are rudimentary so far, but they cover all the possible kernels (avx, sse2, fallback) we have so far.
1 parent e85d96e commit fecec86

File tree

2 files changed

+114
-19
lines changed

2 files changed

+114
-19
lines changed

src/dgemm_kernel.rs

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,82 @@ unsafe fn at(ptr: *const T, i: usize) -> T {
132132
*ptr.offset(i as isize)
133133
}
134134

135-
#[test]
136-
fn test_gemm_kernel() {
137-
let mut a = [1.; 32];
138-
let mut b = [0.; 16];
139-
for (i, x) in a.iter_mut().enumerate() {
140-
*x = i as f64;
135+
#[cfg(test)]
136+
mod tests {
137+
use super::*;
138+
use aligned_alloc::Alloc;
139+
140+
fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T> where T: Copy
141+
{
142+
unsafe {
143+
Alloc::new(n, Gemm::align_to()).init_with(elt)
144+
}
141145
}
142-
for i in 0..4 {
143-
b[i + i * 4] = 1.;
146+
147+
use super::T;
148+
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);
149+
150+
fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
151+
const K: usize = 4;
152+
let mut a = aligned_alloc(1., MR * K);
153+
let mut b = aligned_alloc(0., NR * K);
154+
for (i, x) in a.iter_mut().enumerate() {
155+
*x = i as _;
156+
}
157+
158+
for i in 0..K {
159+
b[i + i * NR] = 1.;
160+
}
161+
let mut c = [0.; MR * NR];
162+
unsafe {
163+
kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
164+
// col major C
165+
}
166+
assert_eq!(&a[..], &c[..a.len()]);
144167
}
145-
let mut c = [0.; 32];
146-
unsafe {
147-
kernel(4, 1., &a[0], &b[0],
148-
0., &mut c[0], 1, 8);
149-
// transposed C so that results line up
168+
169+
#[test]
170+
fn test_native_kernel() {
171+
test_a_kernel("kernel", kernel);
150172
}
151-
assert_eq!(&a, &c);
152-
}
153173

174+
#[test]
175+
fn test_kernel_fallback_impl() {
176+
test_a_kernel("kernel", kernel_fallback_impl);
177+
}
178+
179+
#[test]
180+
fn test_loop_m_n() {
181+
let mut m = [[0; NR]; MR];
182+
loop_m!(i, loop_n!(j, m[i][j] += 1));
183+
for arr in &m[..] {
184+
for elt in &arr[..] {
185+
assert_eq!(*elt, 1);
186+
}
187+
}
188+
}
189+
190+
mod test_arch_kernels {
191+
use super::test_a_kernel;
192+
macro_rules! test_arch_kernels_x86 {
193+
($($feature_name:tt, $function_name:ident),*) => {
194+
$(
195+
#[test]
196+
fn $function_name() {
197+
if is_x86_feature_detected_!($feature_name) {
198+
test_a_kernel(stringify!($function_name), super::super::$function_name);
199+
} else {
200+
println!("Skipping, host does not have feature: {:?}", $feature_name);
201+
}
202+
}
203+
)*
204+
}
205+
}
206+
207+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
208+
test_arch_kernels_x86! {
209+
"avx", kernel_target_avx,
210+
"sse2", kernel_target_sse2
211+
}
212+
}
213+
}

src/sgemm_kernel.rs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,27 +395,38 @@ mod tests {
395395
}
396396
}
397397

398+
use super::T;
399+
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);
398400

399-
#[test]
400-
fn test_gemm_kernel() {
401+
fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
401402
const K: usize = 4;
402403
let mut a = aligned_alloc(1., MR * K);
403404
let mut b = aligned_alloc(0., NR * K);
404405
for (i, x) in a.iter_mut().enumerate() {
405-
*x = i as f32;
406+
*x = i as _;
406407
}
407408

408409
for i in 0..K {
409410
b[i + i * NR] = 1.;
410411
}
411412
let mut c = [0.; MR * NR];
412413
unsafe {
413-
kernel(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
414+
kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
414415
// col major C
415416
}
416417
assert_eq!(&a[..], &c[..a.len()]);
417418
}
418419

420+
#[test]
421+
fn test_native_kernel() {
422+
test_a_kernel("kernel", kernel);
423+
}
424+
425+
#[test]
426+
fn test_kernel_fallback_impl() {
427+
test_a_kernel("kernel", kernel_fallback_impl);
428+
}
429+
419430
#[test]
420431
fn test_loop_m_n() {
421432
let mut m = [[0; NR]; MR];
@@ -426,4 +437,28 @@ mod tests {
426437
}
427438
}
428439
}
440+
441+
mod test_arch_kernels {
442+
use super::test_a_kernel;
443+
macro_rules! test_arch_kernels_x86 {
444+
($($feature_name:tt, $function_name:ident),*) => {
445+
$(
446+
#[test]
447+
fn $function_name() {
448+
if is_x86_feature_detected_!($feature_name) {
449+
test_a_kernel(stringify!($function_name), super::super::$function_name);
450+
} else {
451+
println!("Skipping, host does not have feature: {:?}", $feature_name);
452+
}
453+
}
454+
)*
455+
}
456+
}
457+
458+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
459+
test_arch_kernels_x86! {
460+
"avx", kernel_target_avx,
461+
"sse2", kernel_target_sse2
462+
}
463+
}
429464
}

0 commit comments

Comments
 (0)