Skip to content

Commit 6c08717

Browse files
author
RAYNAUD Paul (raynaudp)
committed
improve circular shift in push! (faster update than LBFGSOperator)
1 parent b4c2970 commit 6c08717

File tree

1 file changed

+13
-21
lines changed

1 file changed

+13
-21
lines changed

src/compressed_lbfgs.jl

+13-21
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ function columnshift!(A::AbstractMatrix{T}; direction::Int=-1, indicemax::Int=si
6767
return A
6868
end
6969

70+
"""
    columnshift!(A::AbstractMatrix; direction::Int=-1, indicemax::Int=size(A, 2))

Shift the columns of `A` in place and return `A`.

For every `i` in `1-direction:indicemax`, column `i` is copied into column
`i + direction`. With the default `direction = -1`, columns `2:indicemax` move one
position to the left; column `indicemax` itself is left untouched, so the caller
is expected to overwrite it afterwards.

# Keywords
- `direction`: column offset applied to each moved column. Only non-positive
  offsets are supported: a positive offset makes the loop start at an index
  below 1, which is outside the matrix.
- `indicemax`: index of the last column to move. Defaults to the number of
  columns of `A`.
"""
function columnshift!(A::AbstractMatrix{T}; direction::Int=-1, indicemax::Int=size(A, 2)) where T
    # A plain loop is used instead of `map`: only the side effect matters, and
    # `map` would allocate a vector collecting the result of every assignment.
    # Ascending order guarantees that, for a leftward shift, each source column
    # is read before a later iteration overwrites it.
    for i in (1 - direction):indicemax
        view(A, :, i + direction) .= view(A, :, i)
    end
    return A
end
74+
7075
"""
7176
CompressedLBFGS(n::Int; [T=Float64, m=5], gpu::Bool)
7277
@@ -100,21 +105,16 @@ function Base.push!(op::CompressedLBFGS{T,M,V}, s::V, y::V) where {T,M,V<:Abstra
100105
view(op.Yₖ, :, op.k) .= y
101106
view(op.Dₖ.diag, op.k) .= dot(s, y)
102107
mul!(view(op.Lₖ.data, op.k, 1:op.k-1), transpose(view(op.Yₖ, :, 1:op.k-1)), view(op.Sₖ, :, op.k) )
103-
104108
else # k == m: circularly update the intermediary structures
105-
op.Sₖ .= circshift(op.Sₖ, (0, -1))
106-
op.Yₖ .= circshift(op.Yₖ, (0, -1))
109+
columnshift!(op.Sₖ; indicemax=op.k)
110+
columnshift!(op.Yₖ; indicemax=op.k)
107111
op.Dₖ .= circshift(op.Dₖ, (-1, -1))
108-
op.Sₖ[:, op.k] .= s
109-
op.Yₖ[:, op.k] .= y
110-
op.Dₖ.diag[op.k] = dot(s, y)
111-
# circshift doesn't work for a LowerTriangular matrix
112-
# for the time being, reinstantiate completely the Lₖ matrix
113-
for j in 1:op.k
114-
for i in 1:j-1
115-
op.Lₖ.data[j, i] = dot(view(op.Sₖ,:, j), view(op.Yₖ, :, i))
116-
end
117-
end
112+
view(op.Sₖ, :, op.k) .= s
113+
view(op.Yₖ, :, op.k) .= y
114+
view(op.Dₖ.diag, op.k) .= dot(s, y)
115+
116+
map(i-> view(op.Lₖ, i:op.m-1, i-1) .= view(op.Lₖ, i+1:op.m, i), 2:op.m)
117+
mul!(view(op.Lₖ.data, op.k, 1:op.k-1), transpose(view(op.Yₖ, :, 1:op.k-1)), view(op.Sₖ, :, op.k) )
118118
end
119119

120120
# step 4 and 6
@@ -150,15 +150,11 @@ end
150150
function inverse_cholesky(op::CompressedLBFGS{T,M,V}) where {T,M,V}
151151
view(op.intermediate_diagonal.diag, 1:op.k) .= inv.(view(op.Dₖ.diag, 1:op.k))
152152

153-
# view(op.Lₖ, 1:op.k, 1:op.k) * inv(Diagonal(op.Dₖ[1:op.k, 1:op.k])) * transpose(view(op.Lₖ, 1:op.k, 1:op.k))
154153
mul!(view(op.inverse_intermediate_1, 1:op.k, 1:op.k), view(op.intermediate_diagonal, 1:op.k, 1:op.k), transpose(view(op.Lₖ, 1:op.k, 1:op.k)))
155154
mul!(view(op.chol_matrix, 1:op.k, 1:op.k), view(op.Lₖ, 1:op.k, 1:op.k), view(op.inverse_intermediate_1, 1:op.k, 1:op.k))
156155

157-
# view(op.chol_matrix, 1:op.k, 1:op.k) .= op.α .* (transpose(view(op.Sₖ, :, 1:op.k)) * view(op.Sₖ, :, 1:op.k))
158156
mul!(view(op.chol_matrix, 1:op.k, 1:op.k), transpose(view(op.Sₖ, :, 1:op.k)), view(op.Sₖ, :, 1:op.k), op.α, (T)(1))
159157

160-
# view(op.chol_matrix, 1:op.k, 1:op.k) .= op.α .* (transpose(view(op.Sₖ, :, 1:op.k)) * view(op.Sₖ, :, 1:op.k)) .+ view(op.Lₖ, 1:op.k, 1:op.k) * inv(Diagonal(op.Dₖ[1:op.k, 1:op.k])) * transpose(view(op.Lₖ, 1:op.k, 1:op.k))
161-
162158
cholesky!(Symmetric(view(op.chol_matrix, 1:op.k, 1:op.k)))
163159
Jₖ = transpose(UpperTriangular(view(op.chol_matrix, 1:op.k, 1:op.k)))
164160
return Jₖ
@@ -182,10 +178,8 @@ function precompile_iterated_structure!(op::CompressedLBFGS)
182178
# updates related to D^(-1/2)
183179
view(op.intermediate_diagonal.diag, 1:op.k) .= (x -> 1/sqrt(x)).(view(op.Dₖ.diag, 1:op.k))
184180
mul!(view(op.intermediate_1, 1:op.k,op.k+1:2*op.k), view(op.intermediate_diagonal, 1:op.k, 1:op.k), transpose(view(op.Lₖ, 1:op.k, 1:op.k)))
185-
# view(op.intermediate_1, 1:op.k,op.k+1:2*op.k) .= view(op.Dₖ, 1:op.k, 1:op.k)^(-1/2) * transpose(view(op.Lₖ, 1:op.k, 1:op.k))
186181
mul!(view(op.intermediate_2, op.k+1:2*op.k, 1:op.k), view(op.Lₖ, 1:op.k, 1:op.k), view(op.intermediate_diagonal, 1:op.k, 1:op.k))
187182
view(op.intermediate_2, op.k+1:2*op.k, 1:op.k) .= view(op.intermediate_2, op.k+1:2*op.k, 1:op.k) .* -1
188-
# view(op.intermediate_2, op.k+1:2*op.k, 1:op.k) .= .- view(op.Lₖ, 1:op.k, 1:op.k) * view(op.Dₖ, 1:op.k, 1:op.k)^(-1/2)
189183

190184
view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_1[1:2*op.k, 1:2*op.k])
191185
view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) .= inv(op.intermediate_2[1:2*op.k, 1:2*op.k])
@@ -200,12 +194,10 @@ function LinearAlgebra.mul!(Bv::V, op::CompressedLBFGS{T,M,V}, v::V) where {T,M,
200194
# scal!(op.α, view(op.sol, op.k+1:2*op.k)) # more allocation, slower
201195
view(op.sol, op.k+1:2*op.k) .*= op.α
202196

203-
# view(op.sol, 1:2*op.k) .= view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k) * (view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k) * view(op.sol, 1:2*op.k))
204197
mul!(view(op.intermediary_vector, 1:2*op.k), view(op.inverse_intermediate_2, 1:2*op.k, 1:2*op.k), view(op.sol, 1:2*op.k))
205198
mul!(view(op.sol, 1:2*op.k), view(op.inverse_intermediate_1, 1:2*op.k, 1:2*op.k), view(op.intermediary_vector, 1:2*op.k))
206199

207200
# step 7
208-
# Bv .= op.α .* v .- (view(op.Yₖ, :,1:op.k) * view(op.sol, 1:op.k) .+ op.α .* view(op.Sₖ, :, 1:op.k) * view(op.sol, op.k+1:2*op.k))
209201
mul!(Bv, view(op.Yₖ, :, 1:op.k), view(op.sol, 1:op.k))
210202
mul!(Bv, view(op.Sₖ, :, 1:op.k), view(op.sol, op.k+1:2*op.k), - op.α, (T)(-1))
211203
Bv .+= op.α .* v

0 commit comments

Comments
 (0)