@@ -3,49 +3,71 @@ module LinearOperatorsChainRulesCoreExt
3
3
using LinearOperators
4
4
isdefined (Base, :get_extension ) ? (import ChainRulesCore) : (import .. ChainRulesCore)
5
5
6
- function ChainRulesCore. frule ((_, Δx, _), :: typeof (* ), op:: AbstractLinearOperator{T} , x:: AbstractVector{S} ) where {T, S}
7
- y = op* x
8
- Δy = op* Δx
6
+ function ChainRulesCore. frule (
7
+ (_, Δx, _),
8
+ :: typeof (* ),
9
+ op:: AbstractLinearOperator{T} ,
10
+ x:: AbstractVector{S} ,
11
+ ) where {T, S}
12
+ y = op * x
13
+ Δy = op * Δx
9
14
return y, Δy
10
15
end
11
- function ChainRulesCore. rrule (:: typeof (* ), op:: AbstractLinearOperator{T} , x:: AbstractVector{S} ) where {T, S}
12
- y = op* x
16
+ function ChainRulesCore. rrule (
17
+ :: typeof (* ),
18
+ op:: AbstractLinearOperator{T} ,
19
+ x:: AbstractVector{S} ,
20
+ ) where {T, S}
21
+ y = op * x
13
22
project_x = ChainRulesCore. ProjectTo (x)
14
23
function mul_pullback (ȳ)
15
- x̄ = project_x ( adjoint (op)* ChainRulesCore. unthunk (ȳ) )
16
- return ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), x̄
24
+ x̄ = project_x (adjoint (op) * ChainRulesCore. unthunk (ȳ))
25
+ return ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), x̄
17
26
end
18
27
return y, mul_pullback
19
28
end
20
29
21
- function ChainRulesCore. frule ((_, Δx, _), :: typeof (* ), x:: Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V} } , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
22
- y = x* op
23
- Δy = Δx* op
30
+ function ChainRulesCore. frule (
31
+ (_, Δx, _),
32
+ :: typeof (* ),
33
+ x:: Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V}} ,
34
+ op:: AbstractLinearOperator{T} ,
35
+ ) where {T, S, V <: AbstractVector{S} }
36
+ y = x * op
37
+ Δy = Δx * op
24
38
return y, Δy
25
39
end
26
- function ChainRulesCore. rrule (:: typeof (* ), x:: LinearOperators.Transpose{S, V} , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
27
- y = x* op
40
+ function ChainRulesCore. rrule (
41
+ :: typeof (* ),
42
+ x:: LinearOperators.Transpose{S, V} ,
43
+ op:: AbstractLinearOperator{T} ,
44
+ ) where {T, S, V <: AbstractVector{S} }
45
+ y = x * op
28
46
project_x = ChainRulesCore. ProjectTo (x)
29
47
function mul_pullback (ȳ)
30
- # needed to make sure that ȳ is recognized as Transposed
31
- # ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
32
- ȳ_ = transpose (vec (ChainRulesCore. unthunk (ȳ)))
33
- x̄ = project_x (ȳ_* adjoint (op))
34
- return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
48
+ # needed to make sure that ȳ is recognized as Transposed
49
+ # ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
50
+ ȳ_ = transpose (vec (ChainRulesCore. unthunk (ȳ)))
51
+ x̄ = project_x (ȳ_ * adjoint (op))
52
+ return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
35
53
end
36
54
return y, mul_pullback
37
55
end
38
- function ChainRulesCore. rrule (:: typeof (* ), x:: LinearOperators.Adjoint{S, V} , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
39
- y = x* op
56
+ function ChainRulesCore. rrule (
57
+ :: typeof (* ),
58
+ x:: LinearOperators.Adjoint{S, V} ,
59
+ op:: AbstractLinearOperator{T} ,
60
+ ) where {T, S, V <: AbstractVector{S} }
61
+ y = x * op
40
62
project_x = ChainRulesCore. ProjectTo (x)
41
63
function mul_pullback (ȳ)
42
- # needed to make sure that ȳ is recognized as Adjoint
43
- # ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
44
- ȳ_ = adjoint (conj .(vec (ChainRulesCore. unthunk (ȳ))))
45
- x̄ = project_x (ȳ_* adjoint (op))
46
- return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
64
+ # needed to make sure that ȳ is recognized as Adjoint
65
+ # ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
66
+ ȳ_ = adjoint (conj .(vec (ChainRulesCore. unthunk (ȳ))))
67
+ x̄ = project_x (ȳ_ * adjoint (op))
68
+ return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
47
69
end
48
70
return y, mul_pullback
49
71
end
50
72
51
- end # module
73
+ end # module
0 commit comments