Skip to content

Commit ab0de4e

Browse files
Eliminate useless allocation for diagonal quasi-Newton operators (#337)
Co-authored-by: Dominique <[email protected]>
1 parent 28fd5ae commit ab0de4e

File tree

2 files changed

+60
-22
lines changed

2 files changed

+60
-22
lines changed

src/DiagonalHessianApproximation.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,22 @@ end
4747
# y = ∇f(x_{k+1}) - ∇f(x_k)
4848
function push!(
4949
B::DiagonalPSB{T, I, V, F},
50-
s0::V,
51-
y0::V,
50+
s::V,
51+
y::V,
5252
) where {T <: Real, I <: Integer, V <: AbstractVector{T}, F}
53-
s0Norm = norm(s0, 2)
54-
if s0Norm == 0
53+
sNorm = norm(s, 2)
54+
if sNorm == 0
5555
error("Cannot update DiagonalQN operator with s=0")
5656
end
5757
# sᵀBs = sᵀy can be scaled by ||s||² without changing the update
58-
s = (si / s0Norm for si s0)
5958
s2 = (si^2 for si s)
60-
y = (yi / s0Norm for yi y0)
61-
trA2 = dot(s2, s2)
62-
sT_y = dot(s, y)
63-
sT_B_s = dot(s2, B.d)
59+
sNorm2 = sNorm^2
60+
trA2 = dot(s2, s2) / sNorm2^2
61+
sT_y = dot(s, y) / sNorm2
62+
sT_B_s = dot(s2, B.d) / sNorm2
6463
q = sT_y - sT_B_s
6564
q /= trA2
66-
B.d .+= q .* s .^ 2
65+
B.d .+= q / sNorm2 .* s .^ 2
6766
return B
6867
end
6968

@@ -126,25 +125,24 @@ end
126125
# y = ∇f(x_{k+1}) - ∇f(x_k)
127126
function push!(
128127
B::DiagonalAndrei{T, I, V, F},
129-
s0::V,
130-
y0::V,
128+
s::V,
129+
y::V,
131130
) where {T <: Real, I <: Integer, V <: AbstractVector{T}, F}
132-
s0Norm = norm(s0, 2)
133-
if s0Norm == 0
131+
sNorm = norm(s, 2)
132+
if sNorm == 0
134133
error("Cannot update DiagonalQN operator with s=0")
135134
end
136135
# sᵀBs = sᵀy can be scaled by ||s||² without changing the update
137-
s = (si / s0Norm for si s0)
138136
s2 = (si^2 for si s)
139-
y = (yi / s0Norm for yi y0)
140-
trA2 = dot(s2, s2)
141-
sT_y = dot(s, y)
142-
sT_B_s = dot(s2, B.d)
137+
sNorm2 = sNorm^2
138+
trA2 = dot(s2, s2) / sNorm2^2
139+
sT_y = dot(s, y) / sNorm2
140+
sT_B_s = dot(s2, B.d) / sNorm2
143141
q = sT_y - sT_B_s
144-
sT_s = dot(s, s)
142+
sT_s = dot(s, s) / sNorm2
145143
q += sT_s
146144
q /= trA2
147-
B.d .+= q .* s .^ 2 .- 1
145+
B.d .+= q / sNorm2 .* s .^ 2 .- 1
148146
return B
149147
end
150148

@@ -199,7 +197,7 @@ function push!(
199197
s::V,
200198
y::V,
201199
) where {T <: Real, I <: Integer, F, V <: AbstractVector{T}}
202-
if all(s .== 0)
200+
if all(x -> x == 0, s)
203201
error("Cannot divide by zero and s .= 0")
204202
end
205203
B.d[1] = dot(s, y) / dot(s, s)

test/test_diag.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,40 @@
1+
"""
2+
@wrappedallocs(expr)
3+
4+
Given an expression, this macro wraps that expression inside a new function
5+
which will evaluate that expression and measure the amount of memory allocated
6+
by the expression. Wrapping the expression in a new function allows for more
7+
accurate memory allocation detection when using global variables (e.g. when
8+
at the REPL).
9+
10+
This code is based on that of https://github.com/JuliaAlgebra/TypedPolynomials.jl/blob/master/test/runtests.jl
11+
12+
For example, `@wrappedallocs(x + y)` produces:
13+
14+
```julia
15+
function g(x1, x2)
16+
@allocated x1 + x2
17+
end
18+
g(x, y)
19+
```
20+
21+
You can use this macro in a unit test to verify that a function does not
22+
allocate:
23+
24+
```
25+
@test @wrappedallocs(x + y) == 0
26+
```
27+
"""
28+
macro wrappedallocs(expr)
29+
argnames = [gensym() for a in expr.args]
30+
quote
31+
function g($(argnames...))
32+
@allocated $(Expr(expr.head, argnames...))
33+
end
34+
$(Expr(:call, :g, [esc(a) for a in expr.args]...))
35+
end
36+
end
37+
138
# Points
239
x0 = [-1.0, 1.0, -1.0]
340
x1 = x0 + [1.0, 0.0, 1.0]
@@ -74,12 +111,15 @@ end
74111
u = similar(v)
75112
mul!(u, A, v)
76113
@test (@allocated mul!(u, A, v)) == 0
114+
@test (@wrappedallocs push!(A, u, v)) == 0
77115
B = DiagonalPSB(d)
78116
mul!(u, B, v)
79117
@test (@allocated mul!(u, B, v)) == 0
118+
@test (@wrappedallocs push!(B, u, v)) == 0
80119
C = SpectralGradient(rand(), 5)
81120
mul!(u, C, v)
82121
@test (@allocated mul!(u, C, v)) == 0
122+
@test (@wrappedallocs push!(C, u, v)) == 0
83123
end
84124

85125
@testset "reset" begin

0 commit comments

Comments
 (0)