Skip to content

Commit 189b140

Browse files
committed
Add scheduler support in mul!
1 parent c5e6d1e commit 189b140

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

src/tensors/blockiterator.jl

+2
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
1313
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
1414
Base.length(iter::BlockIterator) = length(iter.structure)
1515
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)
16+
17+
Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)

src/tensors/linalg.jl

+52-4
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,43 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283283
end
284284

285285
# TensorMap multiplication
286-
function LinearAlgebra.mul!(tC::AbstractTensorMap,
287-
tA::AbstractTensorMap,
288-
tB::AbstractTensorMap, α=true, β=false)
286+
function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
287+
tB::AbstractTensorMap,
288+
α::Number, β::Number,
289+
backend=TO.DefaultBackend())
290+
if backend isa TO.DefaultBackend
291+
newbackend = TO.select_backend(mul!, tC, tA, tB)
292+
return mul!(tC, tA, tB, α, β, newbackend)
293+
elseif backend isa TO.NoBackend # error for missing backend
294+
TC = typeof(tC)
295+
TA = typeof(tA)
296+
TB = typeof(tB)
297+
throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB"))
298+
else # error for unknown backend
299+
TC = typeof(tC)
300+
TA = typeof(tA)
301+
TB = typeof(tB)
302+
throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB"))
303+
end
304+
end
305+
306+
function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap,
307+
B::AbstractTensorMap)
308+
return SerialScheduler()
309+
end
310+
311+
function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
312+
tB::AbstractTensorMap, α::Number, β::Number,
313+
scheduler::Union{Nothing,Scheduler})
314+
if isnothing(scheduler)
315+
return sequential_mul!(tC, tA, tB, α, β)
316+
else
317+
return threaded_mul!(tC, tA, tB, α, β, scheduler)
318+
end
319+
end
320+
321+
function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
322+
tB::AbstractTensorMap, α::Number, β::Number)
289323
compose(space(tA), space(tB)) == space(tC) ||
290324
throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
291325

@@ -325,7 +359,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325359
return tC
326360
end
327361

328-
# TODO: consider spawning threads for different blocks, support backends
362+
function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap,
363+
α::Number, β::Number, scheduler::Scheduler)
364+
# obtain cached data before multithreading
365+
bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB)
366+
367+
tforeach(blocksectors(tC); scheduler) do c
368+
if haskey(bAs, c) # then also bBs should have it
369+
mul!(bCs[c], bAs[c], bBs[c], α, β)
370+
elseif !isone(β)
371+
scale!(bCs[c], β)
372+
end
373+
end
374+
375+
return tC
376+
end
329377

330378
# TensorMap inverse
331379
function Base.inv(t::AbstractTensorMap)

0 commit comments

Comments
 (0)