Skip to content

Commit 10c001d

Browse files
authored
Add in-place matrix multiply-add (#738)
Add support for the new 5-argument in-place matrix multiplication, mul!(C,A,B,α,β), introduced in Julia v1.3. In TrajectoryOptimization.jl, normal SArrays work great, but we found that compilation times went through the roof when we tried larger problems. To my knowledge, SizedArray with 5-arg mul! is the only solution that works nearly as fast as SArrays at small sizes (thanks to loop unrolling!), scales well to large arrays (thanks to BLAS), and doesn't incur any memory allocations.
1 parent e6d935e commit 10c001d

7 files changed

+516
-208
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ julia = "1"
1313
[extras]
1414
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1515
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
16+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
1617

1718
[targets]
18-
test = ["InteractiveUtils", "Test"]
19+
test = ["InteractiveUtils", "Test", "BenchmarkTools"]

benchmark/bench_matrix_ops.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,50 @@ for f in [*, \]
4141
end
4242
end
4343

44+
# Multiply-add
45+
function benchmark_matmul(s,N1,N2,ArrayType)
46+
if ArrayType <: MArray
47+
Mat = MMatrix
48+
A = rand(Mat{N1,N2})
49+
B = rand(Mat{N2,N2})
50+
C = rand(Mat{N1,N2})
51+
label = "MMatrix"
52+
elseif ArrayType <: SizedArray
53+
Mat = SizedMatrix
54+
A = rand(Mat{N1,N2})
55+
B = rand(Mat{N2,N2})
56+
C = rand(Mat{N1,N2})
57+
label = "SizedMatrix"
58+
elseif ArrayType <: Array
59+
A = rand(N1,N2)
60+
B = rand(N2,N2)
61+
C = rand(N1,N2)
62+
label = "Matrix"
63+
end
64+
α,β = 1.0, 1.0
65+
s1 = s["mul!(C,A,B)"][string(N1, pad=2) * string(N2, pad=2)] = BenchmarkGroup()
66+
s2 = s["mul!(C,A,B,α,β)"][string(N1, pad=2) * string(N2, pad=2)] = BenchmarkGroup()
67+
s3 = s["mul!(B,A',C)"][string(N1, pad=2) * string(N2, pad=2)] = BenchmarkGroup()
68+
s4 = s["mul!(B,A',C,α,β)"][string(N1, pad=2) * string(N2, pad=2)] = BenchmarkGroup()
69+
70+
s1[label] = @benchmarkable mul!($C,$A,$B)
71+
s2[label] = @benchmarkable mul!($C,$A,$B,$α,$β)
72+
s3[label] = @benchmarkable mul!($B,Transpose($A),$C)
73+
s4[label] = @benchmarkable mul!($B,Transpose($A),$C,$α,$β)
74+
end
75+
76+
begin
77+
suite["mul!(C,A,B)"] = BenchmarkGroup(["inplace", "multiply-add"])
78+
suite["mul!(C,A,B,α,β)"] = BenchmarkGroup(["inplace", "multiply-add"])
79+
suite["mul!(B,A',C)"] = BenchmarkGroup(["inplace", "multiply-add"])
80+
suite["mul!(B,A',C,α,β)"] = BenchmarkGroup(["inplace", "multiply-add"])
81+
for N in matrix_sizes
82+
for Mat in (MMatrix, SizedMatrix, Matrix)
83+
benchmark_matmul(suite, N+1, N, Mat)
84+
end
85+
end
86+
end
87+
88+
4489
end # module
4590
BenchMatrixOps.suite

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ include("mapreduce.jl")
126126
include("sort.jl")
127127
include("arraymath.jl")
128128
include("linalg.jl")
129+
include("matrix_multiply_add.jl")
129130
include("matrix_multiply.jl")
130131
include("det.jl")
131132
include("inv.jl")

src/matrix_multiply.jl

Lines changed: 6 additions & 207 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1313
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
1414
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
1515

16-
@inline mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVector) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
17-
@inline mul!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticMatrix) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
18-
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::StaticMatrix) = mul!(dest, reshape(A, Size(Size(A)[1], 1)), B)
19-
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
20-
@inline mul!(dest::StaticVecOrMat, A::StaticVector, B::Adjoint{<:Any, <:StaticVector}) = _mul!(Size(dest), dest, Size(A), Size(B), A, B)
21-
#@inline *{TA<:LinearAlgebra.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
22-
23-
2416

2517
# Implementations
2618

@@ -97,10 +89,14 @@ end
9789

9890
# Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
9991
if sa[1]*sa[2]*sb[2] >= 14*14*14
92+
Sa = TSize{size(S),false}()
93+
Sb = TSize{sa,false}()
94+
Sc = TSize{sb,false}()
95+
_add = MulAddMul(true,false)
10096
return quote
10197
@_inline_meta
10298
C = similar(a, T, $S)
103-
mul_blas!($S, C, Sa, Sb, a, b)
99+
mul_blas!($Sa, C, $Sa, $Sb, a, b, $_add)
104100
return C
105101
end
106102
elseif sa[1]*sa[2]*sb[2] < 8*8*8
@@ -177,7 +173,7 @@ end
177173
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster
178174
tmp_type_in = :(SVector{$(sb[1]), T})
179175
tmp_type_out = :(SVector{$(sa[1]), T})
180-
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(Size(a), Size($(sb[1])), a,
176+
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(TSize(a), TSize($(sb[1])), a,
181177
$(Expr(:call, tmp_type_in, [Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)))::$tmp_type_out)
182178
for k2 = 1:sb[2]]
183179

@@ -193,201 +189,4 @@ end
193189
end
194190
end
195191

196-
@generated function partly_unrolled_multiply(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}) where {sa, sb, Ta, Tb}
197-
if sa[2] != sb[1]
198-
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
199-
end
200-
201-
if sa[2] != 0
202-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
203-
else
204-
exprs = [:(zero(promote_op(matprod,Ta,Tb))) for k = 1:sa[1]]
205-
end
206-
207-
return quote
208-
$(Expr(:meta,:noinline))
209-
@inbounds return SVector(tuple($(exprs...)))
210-
end
211-
end
212-
213-
# TODO aliasing problems if c === b?
214-
@generated function _mul!(::Size{sc}, c::StaticVector, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticVector) where {sa, sb, sc}
215-
if sb[1] != sa[2] || sc[1] != sa[1]
216-
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
217-
end
218-
219-
if sa[2] != 0
220-
exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]]
221-
else
222-
exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sa[1]]
223-
end
224-
225-
return quote
226-
@_inline_meta
227-
@inbounds $(Expr(:block, exprs...))
228-
return c
229-
end
230-
end
231-
232-
@generated function _mul!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticVector,
233-
b::Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}}) where {sa, sb, sc}
234-
if sa[1] != sc[1] || sb[2] != sc[2]
235-
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
236-
end
237-
238-
exprs = [:(c[$(LinearIndices(sc)[i, j])] = a[$i] * b[$j]) for i = 1:sa[1], j = 1:sb[2]]
239-
240-
return quote
241-
@_inline_meta
242-
@inbounds $(Expr(:block, exprs...))
243-
return c
244-
end
245-
end
246-
247-
@generated function _mul!(Sc::Size{sc}, c::StaticMatrix{<:Any, <:Any, Tc}, Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, sc, Ta, Tb, Tc}
248-
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
249-
250-
if can_blas
251-
if sa[1] * sa[2] * sb[2] < 4*4*4
252-
return quote
253-
@_inline_meta
254-
mul_unrolled!(Sc, c, Sa, Sb, a, b)
255-
return c
256-
end
257-
elseif sa[1] * sa[2] * sb[2] < 14*14*14 # Something seems broken for this one with large matrices (becomes allocating)
258-
return quote
259-
@_inline_meta
260-
mul_unrolled_chunks!(Sc, c, Sa, Sb, a, b)
261-
return c
262-
end
263-
else
264-
return quote
265-
@_inline_meta
266-
mul_blas!(Sc, c, Sa, Sb, a, b)
267-
return c
268-
end
269-
end
270-
else
271-
if sa[1] * sa[2] * sb[2] < 4*4*4
272-
return quote
273-
@_inline_meta
274-
mul_unrolled!(Sc, c, Sa, Sb, a, b)
275-
return c
276-
end
277-
else
278-
return quote
279-
@_inline_meta
280-
mul_unrolled_chunks!(Sc, c, Sa, Sb, a, b)
281-
return c
282-
end
283-
end
284-
end
285-
end
286-
287-
288-
@generated function mul_blas!(::Size{s}, c::StaticMatrix{<:Any, <:Any, T}, ::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, T}, b::StaticMatrix{<:Any, <:Any, T}) where {s,sa,sb, T <: BlasFloat}
289-
if sb[1] != sa[2] || sa[1] != s[1] || sb[2] != s[2]
290-
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $s"))
291-
end
292-
293-
if sa[1] > 0 && sa[2] > 0 && sb[2] > 0
294-
# This code adapted from `gemm!()` in base/linalg/blas.jl
295-
296-
if T == Float64
297-
gemm = :dgemm_
298-
elseif T == Float32
299-
gemm = :sgemm_
300-
elseif T == Complex{Float64}
301-
gemm = :zgemm_
302-
else # T == Complex{Float32}
303-
gemm = :cgemm_
304-
end
305-
306-
blascall = quote
307-
ccall((LinearAlgebra.BLAS.@blasfunc($gemm), LinearAlgebra.BLAS.libblas), Nothing,
308-
(Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},
309-
Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$T}, Ptr{$T}, Ref{LinearAlgebra.BLAS.BlasInt},
310-
Ptr{$T}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$T}, Ptr{$T},
311-
Ref{LinearAlgebra.BLAS.BlasInt}),
312-
transA, transB, m, n,
313-
ka, alpha, a, strideA,
314-
b, strideB, beta, c,
315-
strideC)
316-
end
317-
318-
return quote
319-
alpha = one(T)
320-
beta = zero(T)
321-
transA = 'N'
322-
transB = 'N'
323-
m = $(sa[1])
324-
ka = $(sa[2])
325-
kb = $(sb[1])
326-
n = $(sb[2])
327-
strideA = $(sa[1])
328-
strideB = $(sb[1])
329-
strideC = $(s[1])
330-
331-
$blascall
332-
333-
return c
334-
end
335-
else
336-
throw(DimensionMismatch("Cannot call BLAS gemm with zero-dimension arrays, attempted $sa * $sb -> $s."))
337-
end
338-
end
339-
340-
341-
@generated function mul_unrolled!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticMatrix) where {sa, sb, sc}
342-
if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2]
343-
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
344-
end
345-
346-
if sa[2] != 0
347-
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k1, j])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[2]]))) for k1 = 1:sa[1], k2 = 1:sb[2]]
348-
else
349-
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = zero(eltype(c))) for k1 = 1:sa[1], k2 = 1:sb[2]]
350-
end
351-
352-
return quote
353-
@_inline_meta
354-
@inbounds $(Expr(:block, exprs...))
355-
end
356-
end
357-
358-
@generated function mul_unrolled_chunks!(::Size{sc}, c::StaticMatrix, ::Size{sa}, ::Size{sb}, a::StaticMatrix, b::StaticMatrix) where {sa, sb, sc}
359-
if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2]
360-
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
361-
end
362-
363-
#vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
364-
365-
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
366-
tmp_type = SVector{sb[1], eltype(c)}
367-
vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(Size(sa)), $(Size(sb[1])), a, $(Expr(:call, tmp_type, [Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)))) for k2 = 1:sb[2]]
368-
369-
exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
370-
371-
return quote
372-
@_inline_meta
373-
@inbounds $(Expr(:block, vect_exprs...))
374-
@inbounds $(Expr(:block, exprs...))
375-
end
376-
end
377-
378-
#function mul_blas(a, b, c, A, B)
379-
#q
380-
#end
381-
382-
# The idea here is to get pointers to stack variables and call BLAS.
383-
# This saves an awful lot of time compared to copying SArrays to Ref{SArray{...}}
384-
# and using BLAS should be fastest for (very) large SArrays
385-
386-
# Here is an LLVM function that gets the pointer to its input, %x
387-
# After this we would make the ccall above.
388192
#
389-
# define i8* @f(i32 %x) #0 {
390-
# %1 = alloca i32, align 4
391-
# store i32 %x, i32* %1, align 4
392-
# ret i32* %1
393-
# }

0 commit comments

Comments
 (0)