Skip to content

Commit 4e6c064

Browse files
author
Andy Ferris
committed
Added size-inferable matrix-vector products
Can now do a fast StaticMatrix * AbstractVector and related `A_mul_B!` calls.
1 parent f1c2b3e commit 4e6c064

File tree

2 files changed

+98
-12
lines changed

2 files changed

+98
-12
lines changed

src/matrix_multiply.jl

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ end
9999

100100
s = (sA[1],)
101101
T = promote_op(matprod, TA, Tb)
102+
#println(T)
102103

103104
if sb[1] != sA[2]
104105
error("Dimension mismatch")
@@ -130,6 +131,37 @@ end
130131
end
131132
end
132133

134+
# This happens to be size-inferrable from A
135+
# Matrix-vector product of a statically sized matrix `A` with an arbitrary
# `AbstractVector` `b`.
#
# The length of the result is `size(A, 1)`, so the output type is determined by `A`
# (and the promoted element type) alone. The length of `b` is not part of its type and
# is therefore checked at run time; a mismatch raises `error("Dimension mismatch")`.
@generated function *(A::StaticMatrix, b::AbstractVector)
    TA = eltype(A)
    Tb = eltype(b)
    sA = size(A)

    s = (sA[1],)
    T = promote_op(matprod, TA, Tb)

    # `similar_type(A, s)` keeps the element type of `A`, so it may only be used when the
    # product element type is `TA`. (Comparing against `Tb` here would build a container
    # with eltype `TA` whenever `T == Tb != TA`, e.g. an Int matrix times a Float64 vector.)
    if T == TA
        newtype = similar_type(A, s)
    else
        newtype = similar_type(A, T, s)
    end

    if sA[2] != 0
        # One fully unrolled dot product per row of `A`, using linear indices into `A`.
        exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]) for k = 1:sA[1]]
    else
        # `A` has no columns: every entry of the result is an empty sum.
        exprs = [zero(T) for k = 1:sA[1]]
    end

    return quote
        $(Expr(:meta,:inline))
        if length(b) != $(sA[2])
            error("Dimension mismatch")
        end
        @inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
    end
end
164+
133165
@generated function *(a::StaticVector, B::StaticMatrix)
134166
Ta = eltype(a)
135167
TB = eltype(B)
@@ -402,7 +434,7 @@ end
402434

403435
# The idea here is to get pointers to stack variables and call BLAS.
404436
# This saves an awful lot of time compared to copying SArrays to Ref{SArray{...}}
405-
# and should be fastest for (very) large SArrays
437+
# and using BLAS should be fastest for (very) large SArrays
406438

407439
# Here is an LLVM function that gets the pointer to its input, %x
408440
# After this we would make the ccall above.
@@ -413,6 +445,48 @@ end
413445
# ret i32* %1
414446
# }
415447

448+
# In-place matrix-vector product: stores `A*b` into `c`, entry by entry.
#
# All three arguments are statically sized, so the shape check is performed while the
# code is being generated and the emitted body is just the unrolled assignments.
@generated function A_mul_B!(c::StaticVector, A::StaticMatrix, b::StaticVector)
    size_A = size(A)
    size_b = size(b)
    size_c = size(c)
    Tc = eltype(c)

    if !(size_b[1] == size_A[2] && size_c[1] == size_A[1])
        error("Dimension mismatch")
    end

    # Combine two term expressions into the call `+(lhs, rhs)`.
    add = (lhs, rhs) -> :(+($lhs, $rhs))
    # Products `A[row, col] * b[col]` for one row, with `A` addressed by linear index.
    row_terms = row -> [:(A[$(sub2ind(size_A, row, col))]*b[$col]) for col = 1:size_A[2]]

    if size_A[2] == 0
        # No columns to sum over: every entry of `c` becomes zero.
        stores = [:(c[$row] = $(zero(Tc))) for row = 1:size_A[1]]
    else
        stores = [:(c[$row] = $(reduce(add, row_terms(row)))) for row = 1:size_A[1]]
    end

    unrolled = Expr(:block, stores...)
    return quote
        $(Expr(:meta,:inline))
        @inbounds $unrolled
    end
end
469+
470+
# The unrolled code is inferrable from the size of A
471+
# In-place matrix-vector product for a statically sized `A` and arbitrary vectors:
# stores `A*b` into `c`, entry by entry.
#
# The unrolled arithmetic depends only on `size(A)`. The lengths of `b` and `c` are
# checked when the function runs, raising `error("Dimension mismatch")` on disagreement.
@generated function A_mul_B!(c::AbstractVector, A::StaticMatrix, b::AbstractVector)
    size_A = size(A)
    Tc = eltype(c)

    # Combine two term expressions into the call `+(lhs, rhs)`.
    add = (lhs, rhs) -> :(+($lhs, $rhs))
    # Products `A[row, col] * b[col]` for one row, with `A` addressed by linear index.
    row_terms = row -> [:(A[$(sub2ind(size_A, row, col))]*b[$col]) for col = 1:size_A[2]]

    if size_A[2] == 0
        # No columns to sum over: every entry of `c` becomes zero.
        stores = [:(c[$row] = $(zero(Tc))) for row = 1:size_A[1]]
    else
        stores = [:(c[$row] = $(reduce(add, row_terms(row)))) for row = 1:size_A[1]]
    end

    unrolled = Expr(:block, stores...)
    return quote
        $(Expr(:meta,:inline))
        if !(length(b) == $(size_A[2]) && length(c) == $(size_A[1]))
            error("Dimension mismatch")
        end
        @inbounds $unrolled
    end
end
489+
416490

417491
@generated function A_mul_B!(C::StaticMatrix, A::StaticMatrix, B::StaticMatrix)
418492
if isbits(C)

test/matrix_multiply.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,25 @@
44
v = @SVector [1, 2]
55
@test m*v === @SVector [5, 11]
66
# More complicated eltype inference
7-
v = @SVector [CartesianIndex((1,3)), CartesianIndex((3,1))]
8-
x = @inferred(m*v)
7+
v2 = @SVector [CartesianIndex((1,3)), CartesianIndex((3,1))]
8+
x = @inferred(m*v2)
99
@test isa(x, SVector{2,CartesianIndex{2}})
1010
@test x == @SVector [CartesianIndex((7,5)), CartesianIndex((15,13))]
1111

12-
m = @MMatrix [1 2; 3 4]
13-
v = @MVector [1, 2]
14-
@test (m*v)::MVector == @MVector [5, 11]
12+
v3 = [1, 2]
13+
@test m*v3 === @SVector [5, 11]
1514

16-
m = @SArray [1 2; 3 4]
17-
v = @SArray [1, 2]
18-
@test m*v === @SArray [5, 11]
15+
m2 = @MMatrix [1 2; 3 4]
16+
v4 = @MVector [1, 2]
17+
@test (m2*v4)::MVector == @MVector [5, 11]
1918

20-
m = @MArray [1 2; 3 4]
21-
v = @MArray [1, 2]
22-
@test (m*v)::MArray == @MArray [5, 11]
19+
m3 = @SArray [1 2; 3 4]
20+
v5 = @SArray [1, 2]
21+
@test m3*v5 === @SArray [5, 11]
22+
23+
m4 = @MArray [1 2; 3 4]
24+
v6 = @MArray [1, 2]
25+
@test (m4*v6)::MArray == @MArray [5, 11]
2326
end
2427

2528
@testset "Vector-matrix" begin
@@ -117,9 +120,18 @@
117120
end
118121

119122
@testset "A_mul_B!" begin
123+
v = @SVector [2, 4]
124+
v2 = [2, 4]
120125
m = @SMatrix [1 2; 3 4]
121126
n = @SMatrix [2 3; 4 5]
122127

128+
outvec = MVector{2,Int}()
129+
A_mul_B!(outvec, m, v)
130+
@test outvec == @MVector [10,22]
131+
outvec2 = Vector{Int}(2)
132+
A_mul_B!(outvec2, m, v2)
133+
@test outvec2 == [10,22]
134+
123135
a = MMatrix{2,2,Int,4}()
124136
A_mul_B!(a, m, n)
125137
@test a::MMatrix{2,2,Int,4} == @MMatrix [10 13; 22 29]

0 commit comments

Comments
 (0)