Skip to content

Commit 450ff9b

Browse files
lkdvos and Jutho authored
Fix rrules of TensorOperations with DiagonalTensorMap (#210)
* Rewrite AD for TensorOperations in terms of `similar` instead of `zerovector`
* Add testcase
* Add links
* Rewrite in terms of tensoralloc
* Be consistent with indextuple lengths
* try fix (without testing)
* proper fix (hopefully)

---------

Co-authored-by: Jutho <[email protected]>
Co-authored-by: Jutho <[email protected]>
1 parent a264735 commit 450ff9b

File tree

4 files changed

+46
-20
lines changed

4 files changed

+46
-20
lines changed

ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl

+1-1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ using LinearAlgebra
88
using TupleTools
99

1010
import TensorOperations as TO
11-
using TensorOperations: promote_contract
11+
using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract
1212
using VectorInterface: promote_scale, promote_add
1313

1414
include("utility.jl")

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

+23-14
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
1414
dC = @thunk projectC(scale(ΔC, conj(β)))
1515
dA = @thunk let
1616
ipA = invperm(linearize(pA))
17-
_dA = zerovector(A, promote_add(ΔC, α))
18-
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
17+
pdA = _repartition(ipA, A)
18+
TA = promote_add(ΔC, α)
19+
# TODO: allocator
20+
_dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false))
21+
_dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...)
1922
return projectA(_dA)
2023
end
2124
dα = @thunk let
@@ -55,34 +58,37 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
5558
function pullback(ΔC′)
5659
ΔC = unthunk(ΔC′)
5760
ipAB = invperm(linearize(pAB))
58-
pΔC = (TupleTools.getindices(ipAB, trivtuple(TO.numout(pA))),
59-
TupleTools.getindices(ipAB, TO.numout(pA) .+ trivtuple(TO.numin(pB))))
61+
pΔC = _repartition(ipAB, TO.numout(pA))
6062

6163
dC = @thunk projectC(scale(ΔC, conj(β)))
6264
dA = @thunk let
63-
ipA = (invperm(linearize(pA)), ())
65+
ipA = _repartition(invperm(linearize(pA)), A)
6466
conjΔC = conjA
6567
conjB′ = conjA ? conjB : !conjB
66-
_dA = zerovector(A,
67-
promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)))
68+
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
69+
# TODO: allocator
6870
tB = twist(B,
6971
TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]),
7072
filter(x -> isdual(space(B, x)), pB[2])))
73+
_dA = tensoralloc_contract(TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA,
74+
Val(false))
7175
_dA = tensorcontract!(_dA,
7276
ΔC, pΔC, conjΔC,
7377
tB, reverse(pB), conjB′, ipA,
7478
conjA ? α : conj(α), Zero(), ba...)
7579
return projectA(_dA)
7680
end
7781
dB = @thunk let
78-
ipB = (invperm(linearize(pB)), ())
82+
ipB = _repartition(invperm(linearize(pB)), B)
7983
conjΔC = conjB
8084
conjA′ = conjB ? conjA : !conjA
81-
_dB = zerovector(B,
82-
promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)))
85+
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
86+
# TODO: allocator
8387
tA = twist(A,
8488
TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]),
8589
filter(x -> !isdual(space(A, x)), pA[2])))
90+
_dB = tensoralloc_contract(TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB,
91+
Val(false))
8692
_dB = tensorcontract!(_dB,
8793
tA, reverse(pA), conjA′,
8894
ΔC, pΔC, conjΔC, ipB,
@@ -121,12 +127,15 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
121127
dC = @thunk projectC(scale(ΔC, conj(β)))
122128
dA = @thunk let
123129
ip = invperm((linearize(p)..., q[1]..., q[2]...))
130+
pdA = _repartition(ip, A)
124131
E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA))
125132
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
126-
_dA = zerovector(A, promote_scale(ΔC, α))
127-
_dA = tensorproduct!(_dA, ΔC,
128-
(trivtuple(TO.numind(p)), ()), conjA, E,
129-
((), trivtuple(TO.numind(q))), conjA, (ip, ()),
133+
pE = ((), trivtuple(TO.numind(q)))
134+
pΔC = (trivtuple(TO.numind(p)), ())
135+
TA = promote_scale(ΔC, α)
136+
# TODO: allocator
137+
_dA = tensoralloc_contract(TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val(false))
138+
_dA = tensorproduct!(_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA,
130139
conjA ? α : conj(α), Zero(), ba...)
131140
return projectA(_dA)
132141
end

ext/TensorKitChainRulesCoreExt/utility.jl

+8-5
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,21 @@
22
# -------
33
trivtuple(N) = ntuple(identity, N)
44

5-
function _repartition(p::IndexTuple, N₁::Int)
5+
Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
66
length(p) >= N₁ ||
77
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
8-
return p[1:N₁], p[(N₁ + 1):end]
8+
return TupleTools.getindices(p, trivtuple(N₁)),
9+
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
10+
end
11+
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
12+
return _repartition(linearize(p), N₁)
913
end
10-
_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁)
1114
function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
1215
return _repartition(p, N₁)
1316
end
1417
function _repartition(p::Union{IndexTuple,Index2Tuple},
15-
::AbstractTensorMap{<:Any,N₁}) where {N₁}
16-
return _repartition(p, N₁)
18+
t::AbstractTensorMap)
19+
return _repartition(p, TensorKit.numout(t))
1720
end
1821

1922
TensorKit.block(t::ZeroTangent, c::Sector) = t

test/bugfixes.jl

+14
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@
4444
tensorfree!(t2)
4545
end
4646

47+
# https://github.com/Jutho/TensorKit.jl/issues/201
4748
@testset "Issue #201" begin
4849
function f(A::AbstractTensorMap)
4950
U, S, V, = tsvd(A)
@@ -71,4 +72,17 @@
7172
grad4, = Zygote.gradient(g, convert(Array, B₀))
7273
@test convert(Array, grad3) ≈ grad4
7374
end
75+
76+
# https://github.com/Jutho/TensorKit.jl/issues/209
77+
@testset "Issue #209" begin
78+
function f(T, D)
79+
@tensor T[1, 4, 1, 3] * D[3, 4]
80+
end
81+
V = Z2Space(2, 2)
82+
D = DiagonalTensorMap(randn(4), V)
83+
T = randn(V ⊗ V ← V ⊗ V)
84+
g1, = Zygote.gradient(f, T, D)
85+
g2, = Zygote.gradient(f, T, TensorMap(D))
86+
@test g1 ≈ g2
87+
end
7488
end

0 commit comments

Comments
 (0)