66// option. This file may not be copied, modified, or distributed
77// except according to those terms.
88
9+ use rayon;
910use libnum:: Zero ;
1011use itertools:: free:: enumerate;
1112
@@ -413,6 +414,8 @@ fn mat_mul_impl<A>(alpha: A,
413414 mat_mul_general ( alpha, lhs, rhs, beta, c)
414415}
415416
/// Size threshold for the recursive splitting done in `mat_mul_general`:
/// while the row count `m` of the left operand exceeds this value the product
/// is halved along axis 0, otherwise while the column count `n` of the right
/// operand exceeds it the product is halved along axis 1; the two halves are
/// handed to `rayon::join`.
417+ const SPLIT : usize = 64 ;
418+
416419/// C ← α A B + β C
417420fn mat_mul_general < A > ( alpha : A ,
418421 lhs : & ArrayView < A , ( Ix , Ix ) > ,
@@ -421,7 +424,27 @@ fn mat_mul_general<A>(alpha: A,
421424 c : & mut ArrayViewMut < A , ( Ix , Ix ) > )
422425 where A : LinalgScalar ,
423426{
424- let ( ( m, k) , ( _, n) ) = ( lhs. dim , rhs. dim ) ;
427+ let ( ( m, k) , ( k2, n) ) = ( lhs. dim , rhs. dim ) ;
428+
429+ debug_assert_eq ! ( k, k2) ;
430+ if m > SPLIT {
431+ // [ A0 ] B = [ C0 ]
432+ // [ A1 ] [ C1 ]
433+ let mid = m / 2 ;
434+ let ( a0, a1) = lhs. split_at ( Axis ( 0 ) , mid) ;
435+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 0 ) , mid) ;
436+ rayon:: join ( move || mat_mul_general ( alpha, & a0, rhs, beta, & mut c0) ,
437+ move || mat_mul_general ( alpha, & a1, rhs, beta, & mut c1) ) ;
438+ return ;
439+ } else if n > SPLIT {
440+ // A [ B0 B1 ] = [ C0 C1 ]
441+ let mid = n / 2 ;
442+ let ( b0, b1) = rhs. split_at ( Axis ( 1 ) , mid) ;
443+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 1 ) , mid) ;
444+ rayon:: join ( move || mat_mul_general ( alpha, lhs, & b0, beta, & mut c0) ,
445+ move || mat_mul_general ( alpha, lhs, & b1, beta, & mut c1) ) ;
446+ return ;
447+ }
425448
426449 // common parameters for gemm
427450 let ap = lhs. as_ptr ( ) ;
0 commit comments