@@ -13,14 +13,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
13
13
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Adjoint{<:Any,<:StaticVector} ) where {N} = vec (A) * B
14
14
@inline * (A:: StaticArray{Tuple{N,1},<:Any,2} , B:: Transpose{<:Any,<:StaticVector} ) where {N} = vec (A) * B
15
15
16
- @inline mul! (dest:: StaticVecOrMat , A:: StaticMatrix , B:: StaticVector ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
17
- @inline mul! (dest:: StaticVecOrMat , A:: StaticMatrix , B:: StaticMatrix ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
18
- @inline mul! (dest:: StaticVecOrMat , A:: StaticVector , B:: StaticMatrix ) = mul! (dest, reshape (A, Size (Size (A)[1 ], 1 )), B)
19
- @inline mul! (dest:: StaticVecOrMat , A:: StaticVector , B:: Transpose{<:Any, <:StaticVector} ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
20
- @inline mul! (dest:: StaticVecOrMat , A:: StaticVector , B:: Adjoint{<:Any, <:StaticVector} ) = _mul! (Size (dest), dest, Size (A), Size (B), A, B)
21
- # @inline *{TA<:LinearAlgebra.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
22
-
23
-
24
16
25
17
# Implementations
26
18
97
89
98
90
# Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
99
91
if sa[1 ]* sa[2 ]* sb[2 ] >= 14 * 14 * 14
92
+ Sa = TSize {size(S),false} ()
93
+ Sb = TSize {sa,false} ()
94
+ Sc = TSize {sb,false} ()
95
+ _add = MulAddMul (true ,false )
100
96
return quote
101
97
@_inline_meta
102
98
C = similar (a, T, $ S)
103
- mul_blas! ($ S , C, Sa, Sb, a, b)
99
+ mul_blas! ($ Sa , C, $ Sa, $ Sb, a, b, $ _add )
104
100
return C
105
101
end
106
102
elseif sa[1 ]* sa[2 ]* sb[2 ] < 8 * 8 * 8
177
173
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster
178
174
tmp_type_in = :(SVector{$ (sb[1 ]), T})
179
175
tmp_type_out = :(SVector{$ (sa[1 ]), T})
180
- vect_exprs = [:($ (Symbol (" tmp_$k2 " )):: $tmp_type_out = partly_unrolled_multiply (Size (a), Size ($ (sb[1 ])), a,
176
+ vect_exprs = [:($ (Symbol (" tmp_$k2 " )):: $tmp_type_out = partly_unrolled_multiply (TSize (a), TSize ($ (sb[1 ])), a,
181
177
$ (Expr (:call , tmp_type_in, [Expr (:ref , :b , LinearIndices (sb)[i, k2]) for i = 1 : sb[1 ]]. .. ))):: $tmp_type_out )
182
178
for k2 = 1 : sb[2 ]]
183
179
@@ -193,201 +189,4 @@ end
193
189
end
194
190
end
195
191
196
- @generated function partly_unrolled_multiply (:: Size{sa} , :: Size{sb} , a:: StaticMatrix{<:Any, <:Any, Ta} , b:: StaticArray{<:Tuple, Tb} ) where {sa, sb, Ta, Tb}
197
- if sa[2 ] != sb[1 ]
198
- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb " ))
199
- end
200
-
201
- if sa[2 ] != 0
202
- exprs = [reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k, j])]* b[$ j]) for j = 1 : sa[2 ]]) for k = 1 : sa[1 ]]
203
- else
204
- exprs = [:(zero (promote_op (matprod,Ta,Tb))) for k = 1 : sa[1 ]]
205
- end
206
-
207
- return quote
208
- $ (Expr (:meta ,:noinline ))
209
- @inbounds return SVector (tuple ($ (exprs... )))
210
- end
211
- end
212
-
213
- # TODO aliasing problems if c === b?
214
- @generated function _mul! (:: Size{sc} , c:: StaticVector , :: Size{sa} , :: Size{sb} , a:: StaticMatrix , b:: StaticVector ) where {sa, sb, sc}
215
- if sb[1 ] != sa[2 ] || sc[1 ] != sa[1 ]
216
- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
217
- end
218
-
219
- if sa[2 ] != 0
220
- exprs = [:(c[$ k] = $ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k, j])]* b[$ j]) for j = 1 : sa[2 ]]))) for k = 1 : sa[1 ]]
221
- else
222
- exprs = [:(c[$ k] = zero (eltype (c))) for k = 1 : sa[1 ]]
223
- end
224
-
225
- return quote
226
- @_inline_meta
227
- @inbounds $ (Expr (:block , exprs... ))
228
- return c
229
- end
230
- end
231
-
232
- @generated function _mul! (:: Size{sc} , c:: StaticMatrix , :: Size{sa} , :: Size{sb} , a:: StaticVector ,
233
- b:: Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}} ) where {sa, sb, sc}
234
- if sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
235
- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
236
- end
237
-
238
- exprs = [:(c[$ (LinearIndices (sc)[i, j])] = a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
239
-
240
- return quote
241
- @_inline_meta
242
- @inbounds $ (Expr (:block , exprs... ))
243
- return c
244
- end
245
- end
246
-
247
- @generated function _mul! (Sc:: Size{sc} , c:: StaticMatrix{<:Any, <:Any, Tc} , Sa:: Size{sa} , Sb:: Size{sb} , a:: StaticMatrix{<:Any, <:Any, Ta} , b:: StaticMatrix{<:Any, <:Any, Tb} ) where {sa, sb, sc, Ta, Tb, Tc}
248
- can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat
249
-
250
- if can_blas
251
- if sa[1 ] * sa[2 ] * sb[2 ] < 4 * 4 * 4
252
- return quote
253
- @_inline_meta
254
- mul_unrolled! (Sc, c, Sa, Sb, a, b)
255
- return c
256
- end
257
- elseif sa[1 ] * sa[2 ] * sb[2 ] < 14 * 14 * 14 # Something seems broken for this one with large matrices (becomes allocating)
258
- return quote
259
- @_inline_meta
260
- mul_unrolled_chunks! (Sc, c, Sa, Sb, a, b)
261
- return c
262
- end
263
- else
264
- return quote
265
- @_inline_meta
266
- mul_blas! (Sc, c, Sa, Sb, a, b)
267
- return c
268
- end
269
- end
270
- else
271
- if sa[1 ] * sa[2 ] * sb[2 ] < 4 * 4 * 4
272
- return quote
273
- @_inline_meta
274
- mul_unrolled! (Sc, c, Sa, Sb, a, b)
275
- return c
276
- end
277
- else
278
- return quote
279
- @_inline_meta
280
- mul_unrolled_chunks! (Sc, c, Sa, Sb, a, b)
281
- return c
282
- end
283
- end
284
- end
285
- end
286
-
287
-
288
- @generated function mul_blas! (:: Size{s} , c:: StaticMatrix{<:Any, <:Any, T} , :: Size{sa} , :: Size{sb} , a:: StaticMatrix{<:Any, <:Any, T} , b:: StaticMatrix{<:Any, <:Any, T} ) where {s,sa,sb, T <: BlasFloat }
289
- if sb[1 ] != sa[2 ] || sa[1 ] != s[1 ] || sb[2 ] != s[2 ]
290
- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $s " ))
291
- end
292
-
293
- if sa[1 ] > 0 && sa[2 ] > 0 && sb[2 ] > 0
294
- # This code adapted from `gemm!()` in base/linalg/blas.jl
295
-
296
- if T == Float64
297
- gemm = :dgemm_
298
- elseif T == Float32
299
- gemm = :sgemm_
300
- elseif T == Complex{Float64}
301
- gemm = :zgemm_
302
- else # T == Complex{Float32}
303
- gemm = :cgemm_
304
- end
305
-
306
- blascall = quote
307
- ccall ((LinearAlgebra. BLAS. @blasfunc ($ gemm), LinearAlgebra. BLAS. libblas), Nothing,
308
- (Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra. BLAS. BlasInt}, Ref{LinearAlgebra. BLAS. BlasInt},
309
- Ref{LinearAlgebra. BLAS. BlasInt}, Ref{$ T}, Ptr{$ T}, Ref{LinearAlgebra. BLAS. BlasInt},
310
- Ptr{$ T}, Ref{LinearAlgebra. BLAS. BlasInt}, Ref{$ T}, Ptr{$ T},
311
- Ref{LinearAlgebra. BLAS. BlasInt}),
312
- transA, transB, m, n,
313
- ka, alpha, a, strideA,
314
- b, strideB, beta, c,
315
- strideC)
316
- end
317
-
318
- return quote
319
- alpha = one (T)
320
- beta = zero (T)
321
- transA = ' N'
322
- transB = ' N'
323
- m = $ (sa[1 ])
324
- ka = $ (sa[2 ])
325
- kb = $ (sb[1 ])
326
- n = $ (sb[2 ])
327
- strideA = $ (sa[1 ])
328
- strideB = $ (sb[1 ])
329
- strideC = $ (s[1 ])
330
-
331
- $ blascall
332
-
333
- return c
334
- end
335
- else
336
- throw (DimensionMismatch (" Cannot call BLAS gemm with zero-dimension arrays, attempted $sa * $sb -> $s ." ))
337
- end
338
- end
339
-
340
-
341
- @generated function mul_unrolled! (:: Size{sc} , c:: StaticMatrix , :: Size{sa} , :: Size{sb} , a:: StaticMatrix , b:: StaticMatrix ) where {sa, sb, sc}
342
- if sb[1 ] != sa[2 ] || sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
343
- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
344
- end
345
-
346
- if sa[2 ] != 0
347
- exprs = [:(c[$ (LinearIndices (sc)[k1, k2])] = $ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(a[$ (LinearIndices (sa)[k1, j])]* b[$ (LinearIndices (sb)[j, k2])]) for j = 1 : sa[2 ]]))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
348
- else
349
- exprs = [:(c[$ (LinearIndices (sc)[k1, k2])] = zero (eltype (c))) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
350
- end
351
-
352
- return quote
353
- @_inline_meta
354
- @inbounds $ (Expr (:block , exprs... ))
355
- end
356
- end
357
-
358
- @generated function mul_unrolled_chunks! (:: Size{sc} , c:: StaticMatrix , :: Size{sa} , :: Size{sb} , a:: StaticMatrix , b:: StaticMatrix ) where {sa, sb, sc}
359
- if sb[1 ] != sa[2 ] || sa[1 ] != sc[1 ] || sb[2 ] != sc[2 ]
360
- throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
361
- end
362
-
363
- # vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]]
364
-
365
- # Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster
366
- tmp_type = SVector{sb[1 ], eltype (c)}
367
- vect_exprs = [:($ (Symbol (" tmp_$k2 " )) = partly_unrolled_multiply ($ (Size (sa)), $ (Size (sb[1 ])), a, $ (Expr (:call , tmp_type, [Expr (:ref , :b , LinearIndices (sb)[i, k2]) for i = 1 : sb[1 ]]. .. )))) for k2 = 1 : sb[2 ]]
368
-
369
- exprs = [:(c[$ (LinearIndices (sc)[k1, k2])] = $ (Symbol (" tmp_$k2 " ))[$ k1]) for k1 = 1 : sa[1 ], k2 = 1 : sb[2 ]]
370
-
371
- return quote
372
- @_inline_meta
373
- @inbounds $ (Expr (:block , vect_exprs... ))
374
- @inbounds $ (Expr (:block , exprs... ))
375
- end
376
- end
377
-
378
- # function mul_blas(a, b, c, A, B)
379
- # q
380
- # end
381
-
382
- # The idea here is to get pointers to stack variables and call BLAS.
383
- # This saves an aweful lot of time compared to copying SArray's to Ref{SArray{...}}
384
- # and using BLAS should be fastest for (very) large SArrays
385
-
386
- # Here is an LLVM function that gets the pointer to its input, %x
387
- # After this we would make the ccall above.
388
192
#
389
- # define i8* @f(i32 %x) #0 {
390
- # %1 = alloca i32, align 4
391
- # store i32 %x, i32* %1, align 4
392
- # ret i32* %1
393
- # }
0 commit comments