Skip to content

Commit 69e611b

Browse files
Add reset function for Diagonal QN operators, fix bug SpectralGradient (#266)
* add reset for Diagonal QN operators, fix bug SpectralGradient * Apply suggestions from code review Co-authored-by: tmigot <[email protected]> Co-authored-by: tmigot <[email protected]>
1 parent f2fe71c commit 69e611b

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

src/DiagonalHessianApproximation.jl

+29-4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,18 @@ function push!(
7878
return B
7979
end
8080

81+
"""
82+
reset!(op::DiagonalQN)
83+
Resets the DiagonalQN data of the given operator.
84+
"""
85+
function reset!(op::DiagonalQN{T}) where {T <: Real}
86+
op.d .= one(T)
87+
op.nprod = 0
88+
op.ntprod = 0
89+
op.nctprod = 0
90+
return op
91+
end
92+
8193
"""
8294
Implementation of a spectral gradient quasi-Newton approximation described in
8395
@@ -87,7 +99,7 @@ https://doi.org/10.18637/jss.v060.i03
8799
"""
88100
mutable struct SpectralGradient{T <: Real, I <: Integer, F} <:
89101
AbstractDiagonalQuasiNewtonOperator{T}
90-
d::T # Diagonal coefficient of the operator (multiple of the identity)
102+
d::Vector{T} # Diagonal coefficient of the operator (multiple of the identity)
91103
nrow::I
92104
ncol::I
93105
symmetric::Bool
@@ -114,8 +126,9 @@ The approximation is defined as σI.
114126
- `σ::Real`: initial positive multiple of the identity;
115127
- `n::Int`: operator size.
116128
"""
117-
function SpectralGradient(d::T, n::I) where {T <: Real, I <: Integer}
118-
@assert d > 0
129+
function SpectralGradient::T, n::I) where {T <: Real, I <: Integer}
130+
@assert σ > 0
131+
d = [σ]
119132
prod = (res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β)
120133
SpectralGradient(d, n, n, true, true, prod, prod, prod, 0, 0, 0, true, true, true)
121134
end
@@ -131,6 +144,18 @@ function push!(
131144
if all(s .== 0)
132145
error("Cannot divide by zero and s .= 0")
133146
end
134-
B.d = dot(s, y) / dot(s, s)
147+
B.d[1] = dot(s, y) / dot(s, s)
135148
return B
136149
end
150+
151+
"""
152+
reset!(op::SpectralGradient)
153+
Resets the SpectralGradient data of the given operator.
154+
"""
155+
function reset!(op::SpectralGradient{T}) where {T <: Real}
156+
op.d[1] = one(T)
157+
op.nprod = 0
158+
op.ntprod = 0
159+
op.nctprod = 0
160+
return op
161+
end

test/test_diag.jl

+17-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
Bref = -5 / 2
6464
end
6565
push!(B, s, y)
66-
@test abs(B.d - Bref) <= 1e-10
66+
@test abs(B.d[1] - Bref) <= 1e-10
6767
end
6868
end
6969

@@ -81,3 +81,19 @@ end
8181
mul!(u, C, v)
8282
@test (@allocated mul!(u, C, v)) == 0
8383
end
84+
85+
@testset "reset" begin
86+
B = DiagonalQN([1.0, -1.0, 1.0], false)
87+
s = x1 - x0
88+
y = ∇f(x1) - ∇f(x0)
89+
push!(B, s, y)
90+
reset!(B)
91+
@test B * x0 == x0
92+
93+
B = SpectralGradient(2.5, 3)
94+
s = x1 - x0
95+
y = ∇f(x1) - ∇f(x0)
96+
push!(B, s, y)
97+
reset!(B)
98+
@test B * x0 == x0
99+
end

0 commit comments

Comments
 (0)