@@ -283,9 +283,43 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283
283
end
284
284
285
285
# TensorMap multiplication
286
- function LinearAlgebra. mul! (tC:: AbstractTensorMap ,
287
- tA:: AbstractTensorMap ,
288
- tB:: AbstractTensorMap , α= true , β= false )
286
+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
287
+ tB:: AbstractTensorMap ,
288
+ α:: Number , β:: Number ,
289
+ backend= TO. DefaultBackend ())
290
+ if backend isa TO. DefaultBackend
291
+ newbackend = TO. select_backend (mul!, tC, tA, tB)
292
+ return mul! (tC, tA, tB, α, β, newbackend)
293
+ elseif backend isa TO. NoBackend # error for missing backend
294
+ TC = typeof (tC)
295
+ TA = typeof (tA)
296
+ TB = typeof (tB)
297
+ throw (ArgumentError (" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
298
+ else # error for unknown backend
299
+ TC = typeof (tC)
300
+ TA = typeof (tA)
301
+ TB = typeof (tB)
302
+ throw (ArgumentError (" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
303
+ end
304
+ end
305
+
306
+ function TO. select_backend (:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
307
+ B:: AbstractTensorMap )
308
+ return SerialScheduler ()
309
+ end
310
+
311
+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
312
+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
313
+ scheduler:: Union{Nothing,Scheduler} )
314
+ if isnothing (scheduler)
315
+ return sequential_mul! (tC, tA, tB, α, β)
316
+ else
317
+ return threaded_mul! (tC, tA, tB, α, β, scheduler)
318
+ end
319
+ end
320
+
321
+ function sequential_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
322
+ tB:: AbstractTensorMap , α:: Number , β:: Number )
289
323
compose (space (tA), space (tB)) == space (tC) ||
290
324
throw (SpaceMismatch (lazy " $(space(tC)) ≠ $(space(tA)) * $(space(tB))" ))
291
325
@@ -325,7 +359,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325
359
return tC
326
360
end
327
361
328
- # TODO : consider spawning threads for different blocks, support backends
362
+ function threaded_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
363
+ α:: Number , β:: Number , scheduler:: Scheduler )
364
+ # obtain cached data before multithreading
365
+ bCs, bAs, bBs = blocks (tC), blocks (tA), blocks (tB)
366
+
367
+ tforeach (blocksectors (tC); scheduler) do c
368
+ if haskey (bAs, c) # then also bBs should have it
369
+ mul! (bCs[c], bAs[c], bBs[c], α, β)
370
+ elseif ! isone (β)
371
+ scale! (bCs[c], β)
372
+ end
373
+ end
374
+
375
+ return tC
376
+ end
329
377
330
378
# TensorMap inverse
331
379
function Base. inv (t:: AbstractTensorMap )
0 commit comments