@@ -283,12 +283,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283
283
end
284
284
285
285
# TensorMap multiplication
286
- function LinearAlgebra. mul! (tC:: AbstractTensorMap ,
287
- tA:: AbstractTensorMap ,
288
- tB:: AbstractTensorMap , α= true , β= false )
286
+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
287
+ tB:: AbstractTensorMap ,
288
+ α:: Number , β:: Number ,
289
+ backend:: AbstractBackend = TO. DefaultBackend ())
290
+ if backend isa TO. DefaultBackend
291
+ newbackend = TO. select_backend (mul!, tC, tA, tB)
292
+ return mul! (tC, tA, tB, α, β, newbackend)
293
+ elseif backend isa TO. NoBackend # error for missing backend
294
+ TC = typeof (tC)
295
+ TA = typeof (tA)
296
+ TB = typeof (tB)
297
+ throw (ArgumentError (" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
298
+ else # error for unknown backend
299
+ TC = typeof (tC)
300
+ TA = typeof (tA)
301
+ TB = typeof (tB)
302
+ throw (ArgumentError (" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
303
+ end
304
+ end
305
+
306
+ function TO. select_backend (:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
307
+ B:: AbstractTensorMap )
308
+ return TensorKitBackend ()
309
+ end
310
+
311
+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
312
+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
313
+ backend:: TensorKitBackend )
289
314
compose (space (tA), space (tB)) == space (tC) ||
290
315
throw (SpaceMismatch (lazy " $(space(tC)) ≠ $(space(tA)) * $(space(tB))" ))
291
316
317
+ scheduler = backend. blockscheduler
318
+ if isnothing (scheduler)
319
+ return sequential_mul! (tC, tA, tB, α, β)
320
+ else
321
+ return threaded_mul! (tC, tA, tB, α, β, scheduler)
322
+ end
323
+ end
324
+
325
+ function sequential_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
326
+ tB:: AbstractTensorMap , α:: Number , β:: Number )
292
327
iterC = blocks (tC)
293
328
iterA = blocks (tA)
294
329
iterB = blocks (tB)
@@ -310,13 +345,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
310
345
elseif cB < cC
311
346
nextB = iterate (iterB, stateB)
312
347
else
313
- if β != one (β)
348
+ if ! isone (β)
314
349
rmul! (C, β)
315
350
end
316
351
nextC = iterate (iterC, stateC)
317
352
end
318
353
else
319
- if β != one (β)
354
+ if ! isone (β)
320
355
rmul! (C, β)
321
356
end
322
357
nextC = iterate (iterC, stateC)
@@ -325,7 +360,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325
360
return tC
326
361
end
327
362
328
- # TODO : consider spawning threads for different blocks, support backends
363
+ function threaded_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
364
+ α:: Number , β:: Number , scheduler:: Scheduler )
365
+ # obtain cached data before multithreading
366
+ bCs, bAs, bBs = blocks (tC), blocks (tA), blocks (tB)
367
+
368
+ tforeach (blocksectors (tC); scheduler) do c
369
+ if haskey (bAs, c) # then also bBs should have it
370
+ mul! (bCs[c], bAs[c], bBs[c], α, β)
371
+ elseif ! isone (β)
372
+ scale! (bCs[c], β)
373
+ end
374
+ end
375
+
376
+ return tC
377
+ end
329
378
330
379
# TensorMap inverse
331
380
function Base. inv (t:: AbstractTensorMap )
0 commit comments