
Support for linking distributions with embedded support #461

@torfjelde

For certain distributions, the random variable represented by the Distribution has support which is lower-dimensional than the return type indicates; that is, the returned realizations are embedded in a higher-dimensional space.

For example, LKJ is a distribution over correlation matrices. Correlation matrices are required to be symmetric positive-definite (PD) and have 1 along the diagonal. Symmetry means that we only have (n choose 2) + n degrees of freedom, and 1 along the diagonal removes the additional n, leaving us with only (n choose 2) degrees of freedom. (Positive-definiteness restricts the entries to an open region but does not remove any further degrees of freedom.) That is, the dimension of the space of correlation matrices is actually just (n choose 2), not n × n as might be indicated by the returned Matrix{Float64} from rand(::LKJ)!
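To make the counting concrete, here is the same argument written out for a general n × n matrix; this is just a worked version of the paragraph above, not something taken from the LKJ implementation:

```math
\underbrace{n^2}_{\text{all entries}}
\;\longrightarrow\;
\underbrace{\frac{n(n+1)}{2} = \binom{n}{2} + n}_{\text{symmetric}}
\;\longrightarrow\;
\underbrace{\binom{n}{2} = \frac{n(n-1)}{2}}_{\text{unit diagonal}}
```

For n = 3 this gives 3 free parameters rather than 9, which matches the 3-element Vector that the linked representation produces in the example below.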

For SimpleVarInfo, this is trivial to support because SimpleVarInfo only contains the realizations themselves, with no information related to the distributions, etc. Therefore, with something like TuringLang/Bijectors.jl#246, things just work:

julia> using DynamicPPL, Distributions, Bijectors

julia> # Switch the bijector used to the `VecCorrBijector` from the aforementioned PR.
       Bijectors.bijector(::LKJ) = Bijectors.VecCorrBijector();

julia> @model demo() = x ~ LKJ(3, 1);

julia> model = demo();

julia> vi = SimpleVarInfo(model);

julia> # Now it's a matrix.
       vi[@varname(x)]
3×3 Matrix{Float64}:
  1.0         -0.00803721  -0.849602
 -0.00803721   1.0          0.00190424
 -0.849602     0.00190424   1.0

julia> vi_transformed = link!!(vi, model);

julia> # Now it's a vector.
       vi_transformed[@varname(x)]
3-element Vector{Float64}:
 -0.00803738468434096
 -1.2547213956880081
 -0.0093368799126288

julia> logjoint(model, vi_transformed)  # (✓) Works!
-3.515748926181343

In contrast, with VarInfo things are not so simple:

julia> vi = VarInfo(model);

julia> vi[@varname(x)]
3×3 Matrix{Float64}:
 1.0        0.382085   0.607741
 0.382085   1.0       -0.173265
 0.607741  -0.173265   1.0

julia> vi_transformed = link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 3 elements to 9 destinations
Stacktrace:
...

With VarInfo there are multiple challenges:

  1. link!! occurs in-place and expects the same shape as the original (untransformed) value.
  2. getindex(vi, vn, dist) uses reconstruct(dist, val) to reshape the underlying flattened representation in VarInfo to what dist expects. This is done before passing it to the bijector/transformation, and so if we're working with a Vector (because we're in transformed space) and then call reconstruct(dist, val::Vector), we get back a Matrix, and the inverse transformation, which expects a Vector, fails. We could start looking into adding the transformation used to the reconstruct call, i.e. letting (dist, transform)-pairs define the reconstruct rather than just dist. The problem then is that in VarInfo whether a variable is transformed or not is decided at runtime, which in turn causes type-instabilities (reconstruct would return a Vector in some cases and a Matrix in others, decided at runtime).
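Both challenges can be illustrated without DynamicPPL at all. The following is only a rough sketch in plain Julia: `flat`, `linked`, and `reconstruct_sketch` are hypothetical names made up for illustration, not DynamicPPL or Bijectors internals, and the `is_transformed` flag merely stands in for whatever VarInfo tracks at runtime.

```julia
# Hypothetical sketch; none of these names are DynamicPPL internals.

# (1) A flat buffer sized for the untransformed 3×3 matrix has 9 slots,
#     but the linked representation only has binomial(3, 2) == 3 entries.
flat = zeros(9)
linked = [-0.008, -1.25, -0.009]
inplace_failed = try
    flat[1:9] = linked   # shape check fails: 3 values for 9 destinations
    false
catch err
    err isa DimensionMismatch
end

# (2) If a runtime flag decides whether we reshape, the return type
#     depends on a value rather than on the argument types.
function reconstruct_sketch(val::Vector{Float64}, n::Int, is_transformed::Bool)
    return is_transformed ? val : reshape(val, n, n)
end

# Ask the compiler what it can infer from the argument types alone.
inferred = Base.return_types(reconstruct_sketch, (Vector{Float64}, Int, Bool))
```

Here `inplace_failed` should end up `true`, and `inferred` should contain a `Union` of the Vector and Matrix types rather than a single concrete type, which is the kind of instability described in point 2. Moving the flag into the type domain (e.g. dispatching on a `Val`) would make the result inferable, but only if the transformed/untransformed status were itself known at compile time, which is exactly what VarInfo does not currently provide.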

So. We need a good way of doing this with VarInfo and I figured I'd make an issue so we can discuss this in more detail together.
