Skip to content

Commit e8b900d

Browse files
committed
make otimes AD planar-compatible
1 parent f9847a1 commit e8b900d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ext/TensorKitChainRulesCoreExt.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
118118
pB = (allind(B), ())
119119
dA = zerovector(A,
120120
promote_contract(scalartype(ΔC), scalartype(B)))
121-
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
121+
dA = planarcontract!(dA, ΔC, pΔC, :N, B, pB, :C, ipA, One(), Zero())
122122
return projectA(dA)
123123
end
124124
dB_ = @thunk begin
125125
ipB = (codomainind(B), domainind(B))
126126
pA = ((), allind(A))
127127
dB = zerovector(B,
128128
promote_contract(scalartype(ΔC), scalartype(A)))
129-
dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
129+
dB = planarcontract!(dB, A, pA, :C, ΔC, pΔC, :N, ipB, One(), Zero())
130130
return projectB(dB)
131131
end
132132
return NoTangent(), dA_, dB_

0 commit comments

Comments
 (0)