Description
For certain distributions, the random variable represented by the `Distribution` has support that is lower-dimensional than the return type indicates; that is, the returned realizations are embedded in a higher-dimensional space.
For example, `LKJ` is a distribution over correlation matrices. Correlation matrices are required to be symmetric positive-definite (PD) and to have 1 along the diagonal. Being symmetric PD means that we only have $\binom{n}{2} + n$ degrees of freedom, and fixing the diagonal to 1 removes the additional $n$, leaving us with only $\binom{n}{2}$ degrees of freedom. That is, the set of $n \times n$ correlation matrices has dimension $\binom{n}{2}$, not $n \times n = n^2$ as might be suggested by the `Matrix{Float64}` returned from `rand(::LKJ)`!
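To make the counting concrete, here is a small plain-Julia sketch (no packages; the helper names `corr_dof`, `sym_pd_dof` and `matrix_entries` are made up for illustration) that tabulates the degrees of freedom against the number of entries a `Matrix{Float64}` actually stores. For $n = 3$ it gives 3 free parameters versus 9 stored entries.

```julia
# Degrees of freedom of an n×n correlation matrix versus the number of
# entries in its Matrix representation. Helper names are hypothetical.
corr_dof(n) = binomial(n, 2)          # free off-diagonal entries (upper triangle only)
sym_pd_dof(n) = binomial(n, 2) + n    # symmetric PD: upper triangle plus the diagonal
matrix_entries(n) = n * n             # what a Matrix{Float64} stores

for n in 2:5
    println("n = $n: correlation dof = $(corr_dof(n)), ",
            "symmetric PD dof = $(sym_pd_dof(n)), stored entries = $(matrix_entries(n))")
end
```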
For `SimpleVarInfo`, this is trivial to support because `SimpleVarInfo` only contains the realizations themselves, with no information related to the distributions etc. Therefore, with something like TuringLang/Bijectors.jl#246, things just work:
```julia
julia> using DynamicPPL, Distributions, Bijectors

julia> # Switch the bijector used to the `VecCorrBijector` from the aforementioned PR.
       Bijectors.bijector(::LKJ) = Bijectors.VecCorrBijector();

julia> @model demo() = x ~ LKJ(3, 1);

julia> model = demo();

julia> vi = SimpleVarInfo(model);

julia> # Now it's a matrix.
       vi[@varname(x)]
3×3 Matrix{Float64}:
 1.0 -0.00803721 -0.849602
 -0.00803721 1.0 0.00190424
 -0.849602 0.00190424 1.0

julia> vi_transformed = link!!(vi, model);

julia> # Now it's a vector.
       vi_transformed[@varname(x)]
3-element Vector{Float64}:
 -0.00803738468434096
 -1.2547213956880081
 -0.0093368799126288

julia> logjoint(model, vi_transformed) # (✓) Works!
-3.515748926181343
```
In contrast, with `VarInfo` things are not so simple:

```julia
julia> vi = VarInfo(model);

julia> vi[@varname(x)]
3×3 Matrix{Float64}:
 1.0 0.382085 0.607741
 0.382085 1.0 -0.173265
 0.607741 -0.173265 1.0

julia> vi_transformed = link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 3 elements to 9 destinations
Stacktrace:
...
```
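The error message is Julia's generic one for assigning a shorter array into a longer slice. The following plain-Julia sketch contrasts two toy stores. They are hypothetical stand-ins, not DynamicPPL's real data structures: a `Dict` loosely mimics `SimpleVarInfo` holding the realization itself, and a preallocated flat `Vector{Float64}` loosely mimics the flattened storage in `VarInfo`. The 3-element `y` is only a placeholder, not the actual output of `VecCorrBijector`.

```julia
# Hypothetical stand-ins, NOT DynamicPPL's actual data structures.
x = [1.0 0.3 0.5; 0.3 1.0 -0.2; 0.5 -0.2 1.0]   # a 3×3 correlation matrix
y = [0.1, 0.2, 0.3]                               # placeholder for a 3-element unconstrained vector

# Dict-based store (SimpleVarInfo-like): the value is simply replaced,
# so its shape is free to change from Matrix to Vector.
dict_store = Dict{Symbol,Any}(:x => x)
dict_store[:x] = y

# Flat preallocated store (VarInfo-like): in-place assignment must keep the length fixed.
flat_store = vec(copy(x))        # 9 entries
try
    flat_store[1:9] = y          # 3 values into 9 slots throws a DimensionMismatch
catch err
    println(sprint(showerror, err))
end
```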
With `VarInfo` there are multiple challenges:

1. `link!!` occurs in-place and expects the same shape as the original (untransformed) value.
2. `getindex(vi, vn, dist)` uses `reconstruct(dist, val)` to reshape the underlying flattened representation in `VarInfo` to what `dist` expects. This is done before passing it to the bijector/transformation, so if we're working with a `Vector` (because we're in transformed space), then calling `reconstruct(dist, val::Vector)` gives us back a `Matrix`, and then the inverse transformation, which expects a `Vector`, fails.

We could start looking into adding the transformation used to the `reconstruct` call, i.e. letting `(dist, transform)`-pairs define the `reconstruct` rather than just `dist`. The problem then is that in `VarInfo`, whether a variable is transformed or not is decided at runtime, which in turn causes type instabilities: `reconstruct` would return a `Vector` in some cases and a `Matrix` in others, decided upon at runtime.
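A minimal illustration of that type instability in plain Julia. `reconstruct_sketch` is a hypothetical stand-in, not DynamicPPL's `reconstruct`. Its return type depends on a `Bool` that is only known at runtime, mirroring how `VarInfo` tracks whether a variable is transformed, and inference can only give a `Union` of `Vector` and `Matrix`.

```julia
# Hypothetical sketch, NOT DynamicPPL's `reconstruct`: whether we are in
# transformed space is a runtime Bool, so the output shape is a runtime decision.
function reconstruct_sketch(n::Int, val::Vector{Float64}, istrans::Bool)
    if istrans
        return val                  # transformed space: keep the binomial(n, 2)-vector
    else
        return reshape(val, n, n)   # original space: n×n matrix
    end
end

# Ask inference for the return type given only the argument types.
rt = only(Base.return_types(reconstruct_sketch, (Int, Vector{Float64}, Bool)))
println(rt)   # a Union of Vector{Float64} and Matrix{Float64}
```

Any code consuming the result must then handle both branches, which is exactly what makes the runtime-decided approach costly.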
So. We need a good way of doing this with `VarInfo`, and I figured I'd make an issue so we can discuss this in more detail together.