Skip to content

Commit fbea544

Browse files
committed
widen to StaticMatMulLike
1 parent 9c1cd18 commit fbea544

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/matrix_multiply.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1616
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
1717

1818
# Avoid LinearAlgebra._quad_matmul's order calculation on equal sizes
19-
@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}) where {N} = (A*B)*C
20-
@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}, D::StaticMatrix{N,N}) where {N} = ((A*B)*C)*D
19+
@inline *(A::StaticMatMulLike{N,N}, B::StaticMatMulLike{N,N}, C::StaticMatMulLike{N,N}) where {N} = (A*B)*C
20+
@inline *(A::StaticMatMulLike{N,N}, B::StaticMatMulLike{N,N}, C::StaticMatMulLike{N,N}, D::StaticMatMulLike{N,N}) where {N} = ((A*B)*C)*D
2121

2222
"""
2323
mul_result_structure(a::Type, b::Type)

test/matrix_multiply.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,11 @@ mul_wrappers = [
173173
@test m*transpose(n) === @SMatrix [8 14; 18 32]
174174
@test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28]
175175

176+
# 3- and 4-arg *
176177
@test @inferred(m*n*m) === @SMatrix [49 72; 109 160]
177178
@test @inferred(m*n*m*n) === @SMatrix [386 507; 858 1127]
179+
@test @inferred(m*n'*UpperTriangular(m)) === @SMatrix [8 72; 18 164]
180+
@test @inferred(Diagonal(m)*n*m'*transpose(n)) === @SMatrix [70 122; 496 864]
178181

179182
# check different sizes because there are multiple implementations for matrices of different sizes
180183
for (mm, nn) in [

0 commit comments

Comments
 (0)