Skip to content

Commit 9c1cd18

Browse files
Michael Abbottmcabbott
Michael Abbott
authored andcommitted
add simple 3-arg and 4-arg * methods
1 parent c1620f0 commit 9c1cd18

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/matrix_multiply.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1515
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
1616
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
1717

18+
# Avoid LinearAlgebra._quad_matmul's order calculation on equal sizes
19+
@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}) where {N} = (A*B)*C
20+
@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}, D::StaticMatrix{N,N}) where {N} = ((A*B)*C)*D
21+
1822
"""
1923
mul_result_structure(a::Type, b::Type)
2024
2125
Get a structure wrapper that should be applied to the result of multiplication of matrices
22-
of given types (a*b).
26+
of given types (a*b).
2327
"""
2428
function mul_result_structure(a, b)
2529
return identity
@@ -114,7 +118,6 @@ end
114118
b::Union{Transpose{Tb, <:StaticVector}, Adjoint{Tb, <:StaticVector}}) where {sa, sb, Ta, Tb}
115119
newsize = (sa[1], sb[2])
116120
exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]]
117-
118121
return quote
119122
@_inline_meta
120123
T = promote_op(*, Ta, Tb)
@@ -209,7 +212,7 @@ end
209212
while m < M
210213
mu = min(M, m + M_r)
211214
mrange = m+1:mu
212-
215+
213216
atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange]
214217
exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange]
215218
atemps_loop_init = [:($(atemps[k1]) = a[$(k1-sa[1]) + $(sa[1])*j]) for k1 = mrange]

test/matrix_multiply.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ mul_wrappers = [
173173
@test m*transpose(n) === @SMatrix [8 14; 18 32]
174174
@test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28]
175175

176+
@test @inferred(m*n*m) === @SMatrix [49 72; 109 160]
177+
@test @inferred(m*n*m*n) === @SMatrix [386 507; 858 1127]
178+
176179
# check different sizes because there are multiple implementations for matrices of different sizes
177180
for (mm, nn) in [
178181
(m, n),

0 commit comments

Comments
 (0)