@@ -237,18 +237,46 @@ end
237
237
end
238
238
239
239
@testset " LKJCholesky" begin
240
+ # Convert Cholesky factor to its free parameters, i.e. its off-diagonal elements
241
+ function chol_3by3_to_free_params (x:: Cholesky )
242
+ if x. uplo == :U
243
+ return [x. U[1 , 2 ], x. U[1 , 3 ], x. U[2 , 3 ]]
244
+ else
245
+ return [x. L[2 , 1 ], x. L[3 , 1 ], x. L[3 , 2 ]]
246
+ end
247
+ # TODO : Generalise to arbitrary dimension using this code:
248
+ # inds = [
249
+ # LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
250
+ # (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
251
+ # ]
252
+ end
253
+
254
+ # Reconstruct Cholesky factor from its free parameters
255
+ # Note that x[i, i] is always positive so we don't need to worry about the sign
256
+ function free_params_to_chol_3by3 (free_params:: AbstractVector , uplo:: Symbol )
257
+ x = UpperTriangular (zeros (eltype (free_params), 3 , 3 ))
258
+ x[1 , 1 ] = 1
259
+ x[1 , 2 ] = free_params[1 ]
260
+ x[1 , 3 ] = free_params[2 ]
261
+ x[2 , 2 ] = sqrt (1 - free_params[1 ]^ 2 )
262
+ x[2 , 3 ] = free_params[3 ]
263
+ x[3 , 3 ] = sqrt (1 - free_params[2 ]^ 2 - free_params[3 ]^ 2 )
264
+ if uplo == :U
265
+ return Cholesky (x)
266
+ else
267
+ return Cholesky (transpose (x))
268
+ end
269
+ end
270
+
240
271
@testset " uplo: $uplo " for uplo in [:L , :U ]
241
272
dist = LKJCholesky (3 , 1 , uplo)
242
273
single_sample_tests (dist)
243
274
244
275
x = rand (dist)
245
-
246
- inds = [
247
- LinearIndices (size (x))[I] for I in CartesianIndices (size (x)) if
248
- (uplo === :L && I[2 ] < I[1 ]) || (uplo === :U && I[2 ] > I[1 ])
249
- ]
250
- J = ForwardDiff. jacobian (z -> link (dist, Cholesky (z, x. uplo, x. info)), x. UL)
251
- J = J[:, inds]
276
+ # Here, we need to pass ForwardDiff only the free parameters of the
277
+ # Cholesky factor so that we get a square Jacobian matrix
278
+ free_params = chol_3by3_to_free_params (x)
279
+ J = ForwardDiff. jacobian (z -> link (dist, free_params_to_chol_3by3 (z, uplo)), free_params)
252
280
logpdf_turing = logpdf_with_trans (dist, x, true )
253
281
@test logpdf (dist, x) - _logabsdet (J) ≈ logpdf_turing
254
282
end
0 commit comments