Skip to content

Commit 6578475

Browse files
authored
Add rrule for DiagonalTensorMap and include tests (#234)
* Add `rrule` for `DiagonalTensorMap(::AbstractTensorMap)` * include tests
1 parent 8e3af86 commit 6578475

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

ext/TensorKitChainRulesCoreExt/constructors.jl

+12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, ar
4747
return D, DiagonalTensorMap_pullback
4848
end
4949

50+
function ChainRulesCore.rrule(::Type{DiagonalTensorMap}, t::AbstractTensorMap)
51+
d = DiagonalTensorMap(t)
52+
function DiagonalTensorMap_pullback(Δd_)
53+
Δt = similar(t) # no projector needed
54+
for (c, b) in blocks(unthunk(Δd_))
55+
copy!(block(Δt, c), Diagonal(b))
56+
end
57+
return NoTangent(), Δt
58+
end
59+
return d, DiagonalTensorMap_pullback
60+
end
61+
5062
function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Symbol)
5163
if prop === :data
5264
function getdata_pullback(Δdata)

test/ad.jl

+3
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
196196
test_rrule(DiagonalTensorMap, D.data, D.domain)
197197
test_rrule(Base.getproperty, D, :data)
198198
test_rrule(Base.getproperty, D1, :data)
199+
200+
test_rrule(DiagonalTensorMap, rand!(T1))
201+
test_rrule(DiagonalTensorMap, randn!(T))
199202
end
200203
end
201204

0 commit comments

Comments
 (0)