Skip to content

Commit cc6b4ea

Browse files
committed
Add scheduler support in mul!
1 parent 037a2a0 commit cc6b4ea

File tree

3 files changed

+60
-8
lines changed

3 files changed

+60
-8
lines changed

src/tensors/blockiterator.jl

+2
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
1313
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
1414
Base.length(iter::BlockIterator) = length(iter.structure)
1515
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)
16+
17+
Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)

src/tensors/linalg.jl

+55-6
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283283
end
284284

285285
# TensorMap multiplication
286-
function LinearAlgebra.mul!(tC::AbstractTensorMap,
287-
tA::AbstractTensorMap,
288-
tB::AbstractTensorMap, α=true, β=false)
286+
function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
287+
tB::AbstractTensorMap,
288+
α::Number, β::Number,
289+
backend::AbstractBackend=TO.DefaultBackend())
290+
if backend isa TO.DefaultBackend
291+
newbackend = TO.select_backend(mul!, tC, tA, tB)
292+
return mul!(tC, tA, tB, α, β, newbackend)
293+
elseif backend isa TO.NoBackend # error for missing backend
294+
TC = typeof(tC)
295+
TA = typeof(tA)
296+
TB = typeof(tB)
297+
throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB"))
298+
else # error for unknown backend
299+
TC = typeof(tC)
300+
TA = typeof(tA)
301+
TB = typeof(tB)
302+
throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB"))
303+
end
304+
end
305+
306+
function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap,
307+
B::AbstractTensorMap)
308+
return TensorKitBackend()
309+
end
310+
311+
function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
312+
tB::AbstractTensorMap, α::Number, β::Number,
313+
backend::TensorKitBackend)
289314
compose(space(tA), space(tB)) == space(tC) ||
290315
throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
291316

317+
scheduler = backend.blockscheduler
318+
if isnothing(scheduler)
319+
return sequential_mul!(tC, tA, tB, α, β)
320+
else
321+
return threaded_mul!(tC, tA, tB, α, β, scheduler)
322+
end
323+
end
324+
325+
function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
326+
tB::AbstractTensorMap, α::Number, β::Number)
292327
iterC = blocks(tC)
293328
iterA = blocks(tA)
294329
iterB = blocks(tB)
@@ -310,13 +345,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
310345
elseif cB < cC
311346
nextB = iterate(iterB, stateB)
312347
else
313-
if β != one(β)
348+
if !isone(β)
314349
rmul!(C, β)
315350
end
316351
nextC = iterate(iterC, stateC)
317352
end
318353
else
319-
if β != one(β)
354+
if !isone(β)
320355
rmul!(C, β)
321356
end
322357
nextC = iterate(iterC, stateC)
@@ -325,7 +360,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325360
return tC
326361
end
327362

328-
# TODO: consider spawning threads for different blocks, support backends
363+
function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap,
364+
α::Number, β::Number, scheduler::Scheduler)
365+
# obtain cached data before multithreading
366+
bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB)
367+
368+
tforeach(blocksectors(tC); scheduler) do c
369+
if haskey(bAs, c) # then also bBs should have it
370+
mul!(bCs[c], bAs[c], bBs[c], α, β)
371+
elseif !isone(β)
372+
scale!(bCs[c], β)
373+
end
374+
end
375+
376+
return tC
377+
end
329378

330379
# TensorMap inverse
331380
function Base.inv(t::AbstractTensorMap)

src/tensors/tensoroperations.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,10 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
149149
# TODO: figure out a name
150150
# TODO: what should be the default scheduler?
151151
# TODO: should we allow a separate scheduler for "blocks" and "subblocks"
152-
@kwdef struct TensorKitBackend{B<:AbstractBackend,S<:Scheduler} <: AbstractBackend
152+
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
153153
arraybackend::B = TO.DefaultBackend()
154-
scheduler::S = SerialScheduler()
154+
blockscheduler::BS = SerialScheduler()
155+
subblockscheduler::SBS = SerialScheduler()
155156
end
156157

157158
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,

0 commit comments

Comments
 (0)