@@ -14,8 +14,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
14
14
dC = @thunk projectC (scale (ΔC, conj (β)))
15
15
dA = @thunk let
16
16
ipA = invperm (linearize (pA))
17
- _dA = zerovector (A, promote_add (ΔC, α))
18
- _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
17
+ pdA = _repartition (ipA, A)
18
+ TA = promote_add (ΔC, α)
19
+ # TODO : allocator
20
+ _dA = tensoralloc_add (TA, ΔC, pdA, conjA, Val (false ))
21
+ _dA = tensoradd! (_dA, ΔC, pdA, conjA, conjA ? α : conj (α), Zero (), ba... )
19
22
return projectA (_dA)
20
23
end
21
24
dα = @thunk let
@@ -55,34 +58,37 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
55
58
function pullback (ΔC′)
56
59
ΔC = unthunk (ΔC′)
57
60
ipAB = invperm (linearize (pAB))
58
- pΔC = (TupleTools. getindices (ipAB, trivtuple (TO. numout (pA))),
59
- TupleTools. getindices (ipAB, TO. numout (pA) .+ trivtuple (TO. numin (pB))))
61
+ pΔC = _repartition (ipAB, TO. numout (pA))
60
62
61
63
dC = @thunk projectC (scale (ΔC, conj (β)))
62
64
dA = @thunk let
63
- ipA = (invperm (linearize (pA)), () )
65
+ ipA = _repartition (invperm (linearize (pA)), A )
64
66
conjΔC = conjA
65
67
conjB′ = conjA ? conjB : ! conjB
66
- _dA = zerovector (A,
67
- promote_contract ( scalartype (ΔC), scalartype (B), scalartype (α)))
68
+ TA = promote_contract ( scalartype (ΔC), scalartype (B), scalartype (α))
69
+ # TODO : allocator
68
70
tB = twist (B,
69
71
TupleTools. vcat (filter (x -> ! isdual (space (B, x)), pB[1 ]),
70
72
filter (x -> isdual (space (B, x)), pB[2 ])))
73
+ _dA = tensoralloc_contract (TA, ΔC, pΔC, conjΔC, tB, reverse (pB), conjB′, ipA,
74
+ Val (false ))
71
75
_dA = tensorcontract! (_dA,
72
76
ΔC, pΔC, conjΔC,
73
77
tB, reverse (pB), conjB′, ipA,
74
78
conjA ? α : conj (α), Zero (), ba... )
75
79
return projectA (_dA)
76
80
end
77
81
dB = @thunk let
78
- ipB = (invperm (linearize (pB)), () )
82
+ ipB = _repartition (invperm (linearize (pB)), B )
79
83
conjΔC = conjB
80
84
conjA′ = conjB ? conjA : ! conjA
81
- _dB = zerovector (B,
82
- promote_contract ( scalartype (ΔC), scalartype (A), scalartype (α)))
85
+ TB = promote_contract ( scalartype (ΔC), scalartype (A), scalartype (α))
86
+ # TODO : allocator
83
87
tA = twist (A,
84
88
TupleTools. vcat (filter (x -> isdual (space (A, x)), pA[1 ]),
85
89
filter (x -> ! isdual (space (A, x)), pA[2 ])))
90
+ _dB = tensoralloc_contract (TB, tA, reverse (pA), conjA′, ΔC, pΔC, conjΔC, ipB,
91
+ Val (false ))
86
92
_dB = tensorcontract! (_dB,
87
93
tA, reverse (pA), conjA′,
88
94
ΔC, pΔC, conjΔC, ipB,
@@ -121,12 +127,15 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
121
127
dC = @thunk projectC (scale (ΔC, conj (β)))
122
128
dA = @thunk let
123
129
ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
130
+ pdA = _repartition (ip, A)
124
131
E = one! (TO. tensoralloc_add (scalartype (A), A, q, conjA))
125
132
twist! (E, filter (x -> ! isdual (space (E, x)), codomainind (E)))
126
- _dA = zerovector (A, promote_scale (ΔC, α))
127
- _dA = tensorproduct! (_dA, ΔC,
128
- (trivtuple (TO. numind (p)), ()), conjA, E,
129
- ((), trivtuple (TO. numind (q))), conjA, (ip, ()),
133
+ pE = ((), trivtuple (TO. numind (q)))
134
+ pΔC = (trivtuple (TO. numind (p)), ())
135
+ TA = promote_scale (ΔC, α)
136
+ # TODO : allocator
137
+ _dA = tensoralloc_contract (TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val (false ))
138
+ _dA = tensorproduct! (_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA,
130
139
conjA ? α : conj (α), Zero (), ba... )
131
140
return projectA (_dA)
132
141
end
0 commit comments