-
Notifications
You must be signed in to change notification settings - Fork 36
Proper support for distributions with embedded support #462
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fc16d3e
a0de0ac
4fe5eea
dfaf7be
8d77d78
16ac9e1
2e3b006
425ca5d
c1f0b3b
73f4bd2
1b3c581
c2fbded
3b156db
6070e3f
613eb1b
0af6e29
0fcd481
29faba0
cc1bb7b
2501510
6957e2e
90a3edb
7817920
ade35c8
d7841e5
7a6ef1b
360283f
d958d84
3cf6e07
a25891d
02dd8bf
4765ea9
f02fdd9
603e027
de47598
8cd2610
752e40b
b0a67a9
2765b08
ed03864
96c0690
7b5521d
595d9ee
d81217e
9b516b8
61e832b
d06cc8a
dbecece
aa76c08
f756e44
9785594
78e6332
cec5bd3
b6dc3ec
f126448
1e4d688
b9a8c16
e3ce20d
e5d19a8
ba7c24c
bf73961
eec87ee
42ed9df
6fa0f72
3d5ead3
64c3a07
3fd55d4
e683cce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,7 +43,6 @@ export AbstractVarInfo, | |
push!!, | ||
empty!!, | ||
getlogp, | ||
resetlogp!, | ||
setlogp!!, | ||
acclogp!!, | ||
resetlogp!!, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,8 +194,8 @@ end | |
|
||
# fallback without sampler | ||
function assume(dist::Distribution, vn::VarName, vi) | ||
r = vi[vn, dist] | ||
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi | ||
r, logp = invlink_with_logpdf(vi, vn, dist) | ||
return r, logp, vi | ||
end | ||
|
||
# SampleFromPrior and SampleFromUniform | ||
|
@@ -211,7 +211,9 @@ function assume( | |
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") | ||
unset_flag!(vi, vn, "del") | ||
r = init(rng, dist, sampler) | ||
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn) | ||
BangBang.setindex!!( | ||
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn | ||
) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
else | ||
# Otherwise we just extract it. | ||
|
@@ -220,15 +222,17 @@ function assume( | |
else | ||
r = init(rng, dist, sampler) | ||
if istrans(vi) | ||
push!!(vi, vn, link(dist, r), dist, sampler) | ||
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler) | ||
# By default `push!!` sets the transformed flag to `false`. | ||
settrans!!(vi, true, vn) | ||
else | ||
push!!(vi, vn, r, dist, sampler) | ||
end | ||
end | ||
|
||
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi | ||
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, that's confusing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I really want to remove this entire function. |
||
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) | ||
return r, logpdf(dist, r) - logjac, vi | ||
end | ||
|
||
# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) | ||
|
@@ -470,7 +474,11 @@ function get_and_set_val!( | |
r = init(rng, dist, spl, n) | ||
for i in 1:n | ||
vn = vns[i] | ||
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn) | ||
setindex!!( | ||
vi, | ||
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])), | ||
vn, | ||
) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
end | ||
else | ||
|
@@ -508,13 +516,17 @@ function get_and_set_val!( | |
for i in eachindex(vns) | ||
vn = vns[i] | ||
dist = dists isa AbstractArray ? dists[i] : dists | ||
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn) | ||
setindex!!( | ||
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn | ||
) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
end | ||
else | ||
# r = reshape(vi[vec(vns)], size(vns)) | ||
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)` | ||
# and fix the lines below. | ||
r_raw = getindex_raw(vi, vec(vns)) | ||
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns))) | ||
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns))) | ||
end | ||
else | ||
f = (vn, dist) -> init(rng, dist, spl) | ||
|
@@ -525,7 +537,7 @@ function get_and_set_val!( | |
# 2. Define an anonymous function which returns `nothing`, which | ||
# we then broadcast. This will allocate a vector of `nothing` though. | ||
if istrans(vi) | ||
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) | ||
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) | ||
# NOTE: Need to add the correction. | ||
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) | ||
# `push!!` sets the trans-flag to `false` by default. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,10 +177,39 @@ function to_namedtuple_expr(syms, vals) | |
return :(NamedTuple{$names_expr}($vals_expr)) | ||
end | ||
|
||
""" | ||
link_transform(dist) | ||
|
||
Return the constrained-to-unconstrained bijector for distribution `dist`. | ||
|
||
By default, this is just `Bijectors.bijector(dist)`. | ||
|
||
!!! warning | ||
Note that currently this is not used by `Bijectors.logpdf_with_trans`, | ||
hence that needs to be overloaded separately if the intention is | ||
to change behavior of an existing distribution. | ||
""" | ||
link_transform(dist) = bijector(dist) | ||
|
||
""" | ||
invlink_transform(dist) | ||
|
||
Return the unconstrained-to-constrained bijector for distribution `dist`. | ||
|
||
By default, this is just `inverse(link_transform(dist))`. | ||
|
||
!!! warning | ||
Note that currently this is not used by `Bijectors.logpdf_with_trans`, | ||
hence that needs to be overloaded separately if the intention is | ||
to change behavior of an existing distribution. | ||
""" | ||
invlink_transform(dist) = inverse(link_transform(dist)) | ||
|
||
##################################################### | ||
# Helper functions for vectorize/reconstruct values # | ||
##################################################### | ||
|
||
vectorize(d, r) = vec(r) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like this is untested and causes method ambiguity issues (when eg tested with test_method_ambiguities). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know of this |
||
vectorize(d::UnivariateDistribution, r::Real) = [r] | ||
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r) | ||
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) | ||
|
@@ -191,7 +220,23 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) | |
# otherwise we will have error for MatrixDistribution. | ||
# Note this is not the case for MultivariateDistribution so I guess this might be lack of | ||
# support for some types related to matrices (like PDMat). | ||
reconstruct(d::UnivariateDistribution, val::Real) = val | ||
|
||
""" | ||
reconstruct([f, ]dist, val) | ||
|
||
Reconstruct `val` so that it's compatible with `dist`. | ||
|
||
If `f` is also provided, the reconstruct value will be | ||
such that `f(reconstruct_val)` is compatible with `dist`. | ||
""" | ||
reconstruct(f, dist, val) = reconstruct(dist, val) | ||
|
||
# No-op versions. | ||
reconstruct(::UnivariateDistribution, val::Real) = val | ||
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) | ||
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) | ||
# TODO: Implement no-op `reconstruct` for general array variates. | ||
|
||
reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) | ||
reconstruct(::Tuple{}, val::AbstractVector) = val[1] | ||
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) | ||
|
Uh oh!
There was an error while loading. Please reload this page.