Skip to content

Commit 06fd64f

Browse files
pull out inner function for multithreaded nested loops to make logic clearer
1 parent 844bb44 commit 06fd64f

File tree

2 files changed

+44
-28
lines changed

2 files changed

+44
-28
lines changed

src/GPE.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,20 @@ function update_mll!(gp::GPE; noise::Bool=true, domean::Bool=true, kern::Bool=tr
215215
gp
216216
end
217217

218+
"""
    _dmll_kern_row!(dmll, buf, k, ααinvcKI, X, data, j, dim, nparams)

Add to `dmll[1:nparams]` the contributions of the index pairs `(i, j)` with
`i <= j`, for a fixed `j`.

For each pair, `dKij_dθ!` is called with `buf` as its first argument
(presumably filling `buf` with the derivatives of kernel entry `(i, j)` with
respect to each of the `nparams` kernel parameters -- confirm against its
definition). `buf[p]` is then weighted by `ααinvcKI[i, j]` and accumulated
into `dmll[p]`.

The diagonal pair `(j, j)` is additionally divided by `2.0`; off-diagonal
pairs are only visited for `i < j`.
NOTE(review): this looks like it relies on `ααinvcKI` being symmetric so that
each off-diagonal pair needs to be visited only once -- confirm with caller.

Both `dmll` and `buf` are mutated. Returns `nothing`.
"""
function _dmll_kern_row!(dmll, buf, k, ααinvcKI, X, data, j, dim, nparams)
    # Diagonal pair (j, j): weighted by one half.
    dKij_dθ!(buf, k, X, X, data, j, j, dim, nparams)
    @inbounds for p in 1:nparams
        dmll[p] += buf[p] * ααinvcKI[j, j] / 2.0
    end

    # Strictly upper part of column j: pairs (row, j) with row < j.
    @inbounds for row in 1:(j - 1)
        dKij_dθ!(buf, k, X, X, data, row, j, dim, nparams)
        @simd for p in 1:nparams
            dmll[p] += buf[p] * ααinvcKI[row, j]
        end
    end
    return nothing
end
218232
"""
219233
dmll_kern!(dmll::AbstractVector, k::Kernel, X::AbstractMatrix, data::KernelData, ααinvcKI::AbstractMatrix)
220234
@@ -236,18 +250,8 @@ function dmll_kern!(dmll::AbstractVector, k::Kernel, X::AbstractMatrix, data::Ke
236250
kthread = kcopies[Threads.threadid()]
237251
bufthread = buffercopies[Threads.threadid()]
238252
dmllthread = dmllcopies[Threads.threadid()]
239-
# diagonal
240-
dKij_dθ!(bufthread, kthread, X, X, data, j, j, dim, nparams)
241-
for iparam in 1:nparams
242-
dmllthread[iparam] += bufthread[iparam] * ααinvcKI[j, j] / 2.0
243-
end
244-
# off-diagonal
245-
for i in j+1:nobs
246-
dKij_dθ!(bufthread, kthread, X, X, data, i, j, dim, nparams)
247-
@simd for iparam in 1:nparams
248-
dmllthread[iparam] += bufthread[iparam] * ααinvcKI[i, j]
249-
end
250-
end
253+
_dmll_kern_row!(dmllthread, bufthread, kthread,
254+
ααinvcKI, X, data, j, dim, nparams)
251255
end
252256

253257
dmll[:] = sum(dmllcopies) # sum up the results from all threads

src/kernels/kernels.jl

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,28 @@ function cov(k::Kernel, X1::AbstractMatrix, X2::AbstractMatrix, data::KernelData
3636
cov!(cK, k, X1, X2, data)
3737
end
3838

39+
"""
    _cov_row!(cK, k, X::AbstractMatrix, data, j, dim)

Fill the entries of `cK` that involve index `j` paired with every index
`i <= j`, using `cov_ij(k, X, X, data, i, j, dim)`.

The diagonal entry `cK[j, j]` is written first. Each value computed for
`(i, j)` with `i < j` is stored at both `cK[i, j]` and `cK[j, i]`, so the
matrix is filled symmetrically. Returns `nothing`.
"""
function _cov_row!(cK, k, X::AbstractMatrix, data, j, dim)
    cK[j, j] = cov_ij(k, X, X, data, j, j, dim)
    @inbounds for row in 1:(j - 1)
        kval = cov_ij(k, X, X, data, row, j, dim)
        # Mirror the value across the diagonal instead of recomputing it.
        cK[row, j] = kval
        cK[j, row] = kval
    end
    return nothing
end
3946
"""
    cov!(cK::AbstractMatrix, k::Kernel, X::AbstractMatrix, data::KernelData=EmptyData())

Fill `cK` in place with the values `cov_ij(k, X, X, data, i, j, dim)` for all
pairs of the `nobs` observations in `X`, where `dim, nobs = size(X)`, and
return `cK`.

The work is split over `j` with `Threads.@threads`; each `j` is handled by
`_cov_row!`, which writes `cK[j, j]` and the mirrored pairs
`cK[i, j]`/`cK[j, i]` for `i < j`.

# Throws
- `ArgumentError` if `size(cK) != (nobs, nobs)`.
"""
function cov!(cK::AbstractMatrix, k::Kernel, X::AbstractMatrix, data::KernelData=EmptyData())
    dim, nobs = size(X)
    (nobs,nobs) == size(cK) || throw(ArgumentError("cK has size $(size(cK)) and X has size $(size(X))"))
    kcopies = [deepcopy(k) for _ in 1:Threads.nthreads()] # in case k is not threadsafe (e.g. ADkernel)
    # NOTE(review): indexing per-thread state with `Threads.threadid()` assumes
    # an iteration stays on one thread; confirm this holds for the Julia
    # versions this package supports.
    @inbounds Threads.@threads for j in 1:nobs
        kthread = kcopies[Threads.threadid()]
        # Pass the per-thread copy, not the shared `k`: otherwise the copies
        # above are never used and a non-threadsafe kernel is shared by all
        # threads.
        _cov_row!(cK, kthread, X, data, j, dim)
    end
    return cK
end
56+
"""
    _cov_row!(cK, k, X1::AbstractMatrix, X2::AbstractMatrix, data, i, dim, nobs2)

Fill row `i` of `cK`: for every `j` in `1:nobs2`, store
`cov_ij(k, X1, X2, data, i, j, dim)` in `cK[i, j]`. Returns `nothing`.
"""
function _cov_row!(cK, k, X1::AbstractMatrix, X2::AbstractMatrix, data, i, dim, nobs2)
    @inbounds for col in 1:nobs2
        cK[i, col] = cov_ij(k, X1, X2, data, i, col, dim)
    end
    return nothing
end
5361
"""
5462
cov!(cK::AbstractMatrix, k::Kernel, X1::AbstractMatrix, X2::AbstractMatrix, data::KernelData=EmptyData())
5563
@@ -67,9 +75,7 @@ function cov!(cK::AbstractMatrix, k::Kernel, X1::AbstractMatrix, X2::AbstractMat
6775
kcopies = [deepcopy(k) for _ in 1:Threads.nthreads()]
6876
@inbounds Threads.@threads for i in 1:nobs1
6977
kthread = kcopies[Threads.threadid()]
70-
for j in 1:nobs2
71-
cK[i,j] = cov_ij(kthread, X1, X2, data, i, j, dim)
72-
end
78+
_cov_row!(cK, kthread, X1, X2, data, i, dim, nobs2)
7379
end
7480
return cK
7581
end
@@ -97,20 +103,28 @@ cov(k::Kernel, X::AbstractMatrix, data::KernelData=EmptyData()) = cov(k, X, X, d
97103
end
98104
end
99105

106+
"""
    _grad_slice_row!(dK, k, X::AbstractMatrix, data, j, p, dim)

Fill the entries of `dK` that involve index `j` paired with every index
`i <= j`, using `dKij_dθp(k, X, X, data, i, j, p, dim)`.

The diagonal entry `dK[j, j]` is written first. Each value computed for
`(i, j)` with `i < j` is stored at both `dK[i, j]` and `dK[j, i]`.
Returns `nothing`.
"""
function _grad_slice_row!(dK, k, X::AbstractMatrix, data, j, p, dim)
    dK[j, j] = dKij_dθp(k, X, X, data, j, j, p, dim)
    @inbounds @simd for row in 1:(j - 1)
        dval = dKij_dθp(k, X, X, data, row, j, p, dim)
        # Mirror the value across the diagonal instead of recomputing it.
        dK[row, j] = dval
        dK[j, row] = dval
    end
    return nothing
end
100113
"""
    grad_slice!(dK::AbstractMatrix, k::Kernel, X::AbstractMatrix, data::KernelData, p::Int)

Fill `dK` in place with the values `dKij_dθp(k, X, X, data, i, j, p, dim)` for
all pairs of the `nobs` observations in `X`, where `dim, nobs = size(X)`, and
return `dK`.

The work is split over `j` with `Threads.@threads`; each thread uses its own
`deepcopy` of `k` and delegates to `_grad_slice_row!`.

# Throws
- `ArgumentError` if `size(dK) != (nobs, nobs)`.
"""
function grad_slice!(dK::AbstractMatrix, k::Kernel, X::AbstractMatrix, data::KernelData, p::Int)
    dim, nobs = size(X)
    if size(dK) != (nobs, nobs)
        throw(ArgumentError("dK has size $(size(dK)) and X has size $(size(X))"))
    end
    # One private kernel copy per thread, so threads never share `k`.
    kernels = map(_ -> deepcopy(k), 1:Threads.nthreads())
    @inbounds Threads.@threads for j in 1:nobs
        klocal = kernels[Threads.threadid()]
        _grad_slice_row!(dK, klocal, X, data, j, p, dim)
    end
    return dK
end
123+
"""
    _grad_slice_row!(dK, k, X1::AbstractMatrix, X2::AbstractMatrix, data, i, p, dim, nobs2)

Fill row `i` of `dK`: for every `j` in `1:nobs2`, store
`dKij_dθp(k, X1, X2, data, i, j, p, dim)` in `dK[i, j]`. Returns `nothing`.
"""
function _grad_slice_row!(dK, k, X1::AbstractMatrix, X2::AbstractMatrix, data, i, p, dim, nobs2)
    @inbounds @simd for col in 1:nobs2
        dK[i, col] = dKij_dθp(k, X1, X2, data, i, col, p, dim)
    end
    return nothing
end
114128
function grad_slice!(dK::AbstractMatrix, k::Kernel, X1::AbstractMatrix, X2::AbstractMatrix, data::KernelData, p::Int)
115129
if X1 === X2
116130
return grad_slice!(dK, k, X1, data, p)
@@ -123,9 +137,7 @@ function grad_slice!(dK::AbstractMatrix, k::Kernel, X1::AbstractMatrix, X2::Abst
123137
kcopies = [deepcopy(k) for _ in 1:Threads.nthreads()]
124138
@inbounds Threads.@threads for i in 1:nobs1
125139
kthread = kcopies[Threads.threadid()]
126-
@simd for j in 1:nobs2
127-
dK[i,j] = dKij_dθp(kthread,X1,X2,data,i,j,p,dim)
128-
end
140+
_grad_slice_row!(dK, kthread, X1, X2, data, i, p, dim, nobs2)
129141
end
130142
return dK
131143
end

0 commit comments

Comments
 (0)