Skip to content

Commit 55e0273

Browse files
committed
🤖 Format .jl files
1 parent 70752d4 commit 55e0273

4 files changed

+68
-46
lines changed

ext/LinearOperatorsChainRulesCoreExt.jl

+47-25
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,71 @@ module LinearOperatorsChainRulesCoreExt
33
using LinearOperators
44
isdefined(Base, :get_extension) ? (import ChainRulesCore) : (import ..ChainRulesCore)
55

6-
function ChainRulesCore.frule((_, Δx, _), ::typeof(*), op::AbstractLinearOperator{T}, x::AbstractVector{S}) where {T, S}
7-
y = op*x
8-
Δy = op*Δx
6+
function ChainRulesCore.frule(
7+
(_, Δx, _),
8+
::typeof(*),
9+
op::AbstractLinearOperator{T},
10+
x::AbstractVector{S},
11+
) where {T, S}
12+
y = op * x
13+
Δy = op * Δx
914
return y, Δy
1015
end
11-
function ChainRulesCore.rrule(::typeof(*), op::AbstractLinearOperator{T}, x::AbstractVector{S}) where {T, S}
12-
y = op*x
16+
function ChainRulesCore.rrule(
17+
::typeof(*),
18+
op::AbstractLinearOperator{T},
19+
x::AbstractVector{S},
20+
) where {T, S}
21+
y = op * x
1322
project_x = ChainRulesCore.ProjectTo(x)
1423
function mul_pullback(ȳ)
15-
= project_x( adjoint(op)*ChainRulesCore.unthunk(ȳ) )
16-
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
24+
= project_x(adjoint(op) * ChainRulesCore.unthunk(ȳ))
25+
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
1726
end
1827
return y, mul_pullback
1928
end
2029

21-
function ChainRulesCore.frule((_, Δx, _), ::typeof(*), x::Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V} }, op::AbstractLinearOperator{T}) where {T, S, V <: AbstractVector{S}}
22-
y = x*op
23-
Δy = Δx*op
30+
function ChainRulesCore.frule(
31+
(_, Δx, _),
32+
::typeof(*),
33+
x::Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V}},
34+
op::AbstractLinearOperator{T},
35+
) where {T, S, V <: AbstractVector{S}}
36+
y = x * op
37+
Δy = Δx * op
2438
return y, Δy
2539
end
26-
function ChainRulesCore.rrule(::typeof(*), x::LinearOperators.Transpose{S, V}, op::AbstractLinearOperator{T}) where {T, S, V <: AbstractVector{S}}
27-
y = x*op
40+
function ChainRulesCore.rrule(
41+
::typeof(*),
42+
x::LinearOperators.Transpose{S, V},
43+
op::AbstractLinearOperator{T},
44+
) where {T, S, V <: AbstractVector{S}}
45+
y = x * op
2846
project_x = ChainRulesCore.ProjectTo(x)
2947
function mul_pullback(ȳ)
30-
# needed to make sure that ȳ is recognized as Transposed
31-
# ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
32-
ȳ_ = transpose(vec(ChainRulesCore.unthunk(ȳ)))
33-
= project_x(ȳ_*adjoint(op))
34-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
48+
# needed to make sure that ȳ is recognized as Transposed
49+
# ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
50+
ȳ_ = transpose(vec(ChainRulesCore.unthunk(ȳ)))
51+
= project_x(ȳ_ * adjoint(op))
52+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
3553
end
3654
return y, mul_pullback
3755
end
38-
function ChainRulesCore.rrule(::typeof(*), x::LinearOperators.Adjoint{S, V}, op::AbstractLinearOperator{T}) where {T, S, V <: AbstractVector{S}}
39-
y = x*op
56+
function ChainRulesCore.rrule(
57+
::typeof(*),
58+
x::LinearOperators.Adjoint{S, V},
59+
op::AbstractLinearOperator{T},
60+
) where {T, S, V <: AbstractVector{S}}
61+
y = x * op
4062
project_x = ChainRulesCore.ProjectTo(x)
4163
function mul_pullback(ȳ)
42-
# needed to make sure that ȳ is recognized as Adjoint
43-
# ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
44-
ȳ_ = adjoint(conj.(vec(ChainRulesCore.unthunk(ȳ))))
45-
= project_x(ȳ_*adjoint(op))
46-
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
64+
# needed to make sure that ȳ is recognized as Adjoint
65+
# ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
66+
ȳ_ = adjoint(conj.(vec(ChainRulesCore.unthunk(ȳ))))
67+
= project_x(ȳ_ * adjoint(op))
68+
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
4769
end
4870
return y, mul_pullback
4971
end
5072

51-
end # module
73+
end # module

src/LinearOperators.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ include("deprecated.jl")
3232
@static if !isdefined(Base, :get_extension)
3333
import Requires
3434
end
35-
35+
3636
@static if !isdefined(Base, :get_extension)
3737
function __init__()
3838
Requires.@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ include("test_callable.jl")
1414
include("test_deprecated.jl")
1515
include("test_normest.jl")
1616
include("test_diag.jl")
17-
include("test_chainrules.jl")
17+
include("test_chainrules.jl")

test/test_chainrules.jl

+19-19
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,51 @@
11
using Zygote
22

3-
function matmulOp(mat::AbstractArray{T}) where T
4-
function prod!(res,x)
5-
for i in axes(mat,1)
6-
res[i] = transpose(mat[i,:])*x
7-
end
3+
function matmulOp(mat::AbstractArray{T}) where {T}
4+
function prod!(res, x)
5+
for i in axes(mat, 1)
6+
res[i] = transpose(mat[i, :]) * x
7+
end
88
end
99

10-
function ctprod!(res,x)
11-
for i in axes(mat,2)
12-
res[i] = dot(mat[:,i],x)
13-
end
10+
function ctprod!(res, x)
11+
for i in axes(mat, 2)
12+
res[i] = dot(mat[:, i], x)
13+
end
1414
end
1515

16-
return LinearOperator{T}(size(mat,1),size(mat,2),false, false, prod!, nothing, ctprod!)
16+
return LinearOperator{T}(size(mat, 1), size(mat, 2), false, false, prod!, nothing, ctprod!)
1717
end
1818

1919
function test_chainrules()
2020
@testset ExtendedTestSet "Chainrules" begin
21-
for (M,N) in zip([2,3,8,7], [2,4,8,16])
21+
for (M, N) in zip([2, 3, 8, 7], [2, 4, 8, 16])
2222
for T in [Float64, ComplexF64]
2323
mat = simple_matrix(T, M, N)
2424
op = matmulOp(mat)
25-
x = rand(T,N)
25+
x = rand(T, N)
2626
xᵀ = transpose(x[1:M])
2727
xᴴ = adjoint(x[1:M])
2828

2929
# test op*x
30-
y, g = Zygote.withgradient(v->sum(abs.(op*v)), x)
31-
y2, g2 = Zygote.withgradient(v->sum(abs.(mat*v)), x)
30+
y, g = Zygote.withgradient(v -> sum(abs.(op * v)), x)
31+
y2, g2 = Zygote.withgradient(v -> sum(abs.(mat * v)), x)
3232
@test isapprox(y, y2)
3333
@test isapprox(g[1], g2[1])
3434

3535
# test xᵀ*op
36-
yt, gt = Zygote.withgradient(v->sum(abs.(v*op)), xᵀ)
37-
yt2, gt2 = Zygote.withgradient(v->sum(abs.(v*mat)), xᵀ)
36+
yt, gt = Zygote.withgradient(v -> sum(abs.(v * op)), xᵀ)
37+
yt2, gt2 = Zygote.withgradient(v -> sum(abs.(v * mat)), xᵀ)
3838
@test isapprox(yt, yt2)
3939
@test isapprox(gt[1], gt2[1])
4040

4141
# test xᴴ*op
42-
yh, gh = Zygote.withgradient(v->sum(abs.(v*op)), xᴴ)
43-
yh2, gh2 = Zygote.withgradient(v->sum(abs.(v*mat)), xᴴ)
42+
yh, gh = Zygote.withgradient(v -> sum(abs.(v * op)), xᴴ)
43+
yh2, gh2 = Zygote.withgradient(v -> sum(abs.(v * mat)), xᴴ)
4444
@test isapprox(yh, yh2)
4545
@test isapprox(gh[1], gh2[1])
4646
end
4747
end
4848
end
4949
end
5050

51-
test_chainrules()
51+
test_chainrules()

0 commit comments

Comments
 (0)