Skip to content

Commit 2c164b2

Browse files
authored
Revert "Proper support for distributions with embedded support (#462)"
This reverts commit 7b01d25.
1 parent 7b01d25 commit 2c164b2

File tree

13 files changed

+175
-331
lines changed

13 files changed

+175
-331
lines changed

docs/src/api.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,7 @@ DynamicPPL.link!!
206206
DynamicPPL.invlink!!
207207
DynamicPPL.default_transformation
208208
DynamicPPL.maybe_invlink_before_eval!!
209-
DynamicPPL.reconstruct
210-
```
209+
```
211210

212211
#### Utils
213212

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ export AbstractVarInfo,
4343
push!!,
4444
empty!!,
4545
getlogp,
46+
resetlogp!,
4647
setlogp!!,
4748
acclogp!!,
4849
resetlogp!!,

src/abstract_varinfo.jl

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -553,99 +553,6 @@ variables `x` would return
553553
"""
554554
function tonamedtuple end
555555

556-
# TODO: Clean up all this linking stuff once and for all!
557-
"""
558-
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
559-
560-
Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
561-
value is reconstructed to the correct type and shape according to `dist`.
562-
"""
563-
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
564-
x_recon = reconstruct(f, dist, x)
565-
return with_logabsdet_jacobian(f, x_recon)
566-
end
567-
568-
# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
569-
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
570-
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
571-
"""
572-
reconstruct_and_link(dist, val)
573-
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
574-
575-
Return linked `val` but reconstruct before linking, if necessary.
576-
577-
Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
578-
return a reconstructed value, i.e. a value of the same type and shape as expected
579-
by `dist`.
580-
581-
See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
582-
"""
583-
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
584-
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
585-
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
586-
return reconstruct_and_link(dist, val)
587-
end
588-
589-
"""
590-
invlink_and_reconstruct(dist, val)
591-
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
592-
593-
Return invlinked and reconstructed `val`.
594-
595-
See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
596-
"""
597-
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
598-
function invlink_and_reconstruct(dist, val)
599-
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
600-
end
601-
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
602-
return invlink_and_reconstruct(dist, val)
603-
end
604-
605-
"""
606-
maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
607-
608-
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
609-
"""
610-
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
611-
return if istrans(vi, vn)
612-
reconstruct_and_link(vi, vn, dist, val)
613-
else
614-
reconstruct(dist, val)
615-
end
616-
end
617-
618-
"""
619-
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
620-
621-
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
622-
"""
623-
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
624-
return if istrans(vi, vn)
625-
invlink_and_reconstruct(vi, vn, dist, val)
626-
else
627-
reconstruct(dist, val)
628-
end
629-
end
630-
631-
"""
632-
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])
633-
634-
Invlink `x` and compute the logpdf under `dist` including correction from
635-
the invlink-transformation.
636-
637-
If `x` is not provided, `getval(vi, vn)` will be used.
638-
"""
639-
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
640-
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
641-
end
642-
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
643-
# NOTE: Will this cause type-instabilities or will union-splitting save us?
644-
f = istrans(vi, vn) ? invlink_transform(dist) : identity
645-
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
646-
return x, logpdf(dist, x) + logjac
647-
end
648-
649556
# Legacy code that is currently overloaded for the sake of simplicity.
650557
# TODO: Remove when possible.
651558
increment_num_produce!(::AbstractVarInfo) = nothing

src/context_implementations.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ end
194194

195195
# fallback without sampler
196196
function assume(dist::Distribution, vn::VarName, vi)
197-
r, logp = invlink_with_logpdf(vi, vn, dist)
198-
return r, logp, vi
197+
r = vi[vn, dist]
198+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
199199
end
200200

201201
# SampleFromPrior and SampleFromUniform
@@ -211,9 +211,7 @@ function assume(
211211
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
212212
unset_flag!(vi, vn, "del")
213213
r = init(rng, dist, sampler)
214-
BangBang.setindex!!(
215-
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
216-
)
214+
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn)
217215
setorder!(vi, vn, get_num_produce(vi))
218216
else
219217
# Otherwise we just extract it.
@@ -222,17 +220,15 @@ function assume(
222220
else
223221
r = init(rng, dist, sampler)
224222
if istrans(vi)
225-
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
223+
push!!(vi, vn, link(dist, r), dist, sampler)
226224
# By default `push!!` sets the transformed flag to `false`.
227225
settrans!!(vi, true, vn)
228226
else
229227
push!!(vi, vn, r, dist, sampler)
230228
end
231229
end
232230

233-
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
234-
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
235-
return r, logpdf(dist, r) - logjac, vi
231+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
236232
end
237233

238234
# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
@@ -474,11 +470,7 @@ function get_and_set_val!(
474470
r = init(rng, dist, spl, n)
475471
for i in 1:n
476472
vn = vns[i]
477-
setindex!!(
478-
vi,
479-
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])),
480-
vn,
481-
)
473+
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn)
482474
setorder!(vi, vn, get_num_produce(vi))
483475
end
484476
else
@@ -516,17 +508,13 @@ function get_and_set_val!(
516508
for i in eachindex(vns)
517509
vn = vns[i]
518510
dist = dists isa AbstractArray ? dists[i] : dists
519-
setindex!!(
520-
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn
521-
)
511+
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn)
522512
setorder!(vi, vn, get_num_produce(vi))
523513
end
524514
else
525515
# r = reshape(vi[vec(vns)], size(vns))
526-
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
527-
# and fix the lines below.
528516
r_raw = getindex_raw(vi, vec(vns))
529-
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns)))
517+
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
530518
end
531519
else
532520
f = (vn, dist) -> init(rng, dist, spl)
@@ -537,7 +525,7 @@ function get_and_set_val!(
537525
# 2. Define an anonymous function which returns `nothing`, which
538526
# we then broadcast. This will allocate a vector of `nothing` though.
539527
if istrans(vi)
540-
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
528+
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
541529
# NOTE: Need to add the correction.
542530
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
543531
# `push!!` sets the trans-flag to `false` by default.

src/simple_varinfo.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ end
290290

291291
# `NamedTuple`
292292
function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
293-
return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn))
293+
return maybe_invlink(vi, vn, dist, getindex(vi, vn))
294294
end
295295
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
296296
vals_linked = mapreduce(vcat, vns) do vn
@@ -329,9 +329,6 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
329329
return reconstruct(dist, vals, length(vns))
330330
end
331331

332-
# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
333-
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)
334-
335332
Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)
336333

337334
function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
@@ -429,7 +426,7 @@ function assume(
429426
)
430427
value = init(rng, dist, sampler)
431428
# Transform if we're working in unconstrained space.
432-
value_raw = maybe_reconstruct_and_link(vi, vn, dist, value)
429+
value_raw = maybe_link(vi, vn, dist, value)
433430
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
434431
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
435432
end
@@ -447,9 +444,9 @@ function dot_assume(
447444

448445
# Transform if we're working in transformed space.
449446
value_raw = if dists isa Distribution
450-
maybe_reconstruct_and_link.((vi,), vns, (dists,), value)
447+
maybe_link.((vi,), vns, (dists,), value)
451448
else
452-
maybe_reconstruct_and_link.((vi,), vns, dists, value)
449+
maybe_link.((vi,), vns, dists, value)
453450
end
454451

455452
# Update `vi`
@@ -476,7 +473,7 @@ function dot_assume(
476473

477474
# Update `vi`.
478475
for (vn, val) in zip(vns, eachcol(value))
479-
val_linked = maybe_reconstruct_and_link(vi, vn, dist, val)
476+
val_linked = maybe_link(vi, vn, dist, val)
480477
vi = BangBang.setindex!!(vi, val_linked, vn)
481478
end
482479

@@ -491,7 +488,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {
491488
nt_vals = map(keys(vi)) do vn
492489
val = vi[vn]
493490
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
494-
vals = map(copy Base.Fix1(getindex, vi), vns)
491+
vals = map(Base.Fix1(getindex, vi), vns)
495492
(vals, map(string, vns))
496493
end
497494

@@ -504,7 +501,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
504501
# Extract the leaf varnames and values.
505502
val = vi[vn]
506503
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
507-
vals = map(copy Base.Fix1(getindex, vi), vns)
504+
vals = map(Base.Fix1(getindex, vi), vns)
508505

509506
# Determine the corresponding symbol.
510507
sym = only(unique(map(getsym, vns)))

src/threadsafe.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,3 @@ end
178178

179179
istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
180180
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
181-
182-
getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)

src/transforming.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function tilde_assume(
1515

1616
# Only transform if `!isinverse` since `vi[vn, right]`
1717
# already performs the inverse transformation if it's transformed.
18-
r_transformed = isinverse ? r : link_transform(right)(r)
18+
r_transformed = isinverse ? r : bijector(right)(r)
1919
return r, lp, setindex!!(vi, r_transformed, vn)
2020
end
2121

@@ -27,7 +27,7 @@ function dot_tilde_assume(
2727
vi,
2828
) where {isinverse}
2929
r = getindex.((vi,), vns, (dist,))
30-
b = link_transform(dist)
30+
b = bijector(dist)
3131

3232
is_trans_uniques = unique(istrans.((vi,), vns))
3333
@assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables"
@@ -70,7 +70,7 @@ function dot_tilde_assume(
7070
@assert !isinverse "Trying to invlink non-transformed variables"
7171
end
7272

73-
b = link_transform(dist)
73+
b = bijector(dist)
7474
for (vn, ri) in zip(vns, eachcol(r))
7575
# Only transform if `!isinverse` since `vi[vn, right]`
7676
# already performs the inverse transformation if it's transformed.

src/utils.jl

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -177,39 +177,10 @@ function to_namedtuple_expr(syms, vals)
177177
return :(NamedTuple{$names_expr}($vals_expr))
178178
end
179179

180-
"""
181-
link_transform(dist)
182-
183-
Return the constrained-to-unconstrained bijector for distribution `dist`.
184-
185-
By default, this is just `Bijectors.bijector(dist)`.
186-
187-
!!! warning
188-
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
189-
hence that needs to be overloaded separately if the intention is
190-
to change behavior of an existing distribution.
191-
"""
192-
link_transform(dist) = bijector(dist)
193-
194-
"""
195-
invlink_transform(dist)
196-
197-
Return the unconstrained-to-constrained bijector for distribution `dist`.
198-
199-
By default, this is just `inverse(link_transform(dist))`.
200-
201-
!!! warning
202-
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
203-
hence that needs to be overloaded separately if the intention is
204-
to change behavior of an existing distribution.
205-
"""
206-
invlink_transform(dist) = inverse(link_transform(dist))
207-
208180
#####################################################
209181
# Helper functions for vectorize/reconstruct values #
210182
#####################################################
211183

212-
vectorize(d, r) = vec(r)
213184
vectorize(d::UnivariateDistribution, r::Real) = [r]
214185
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
215186
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
@@ -220,23 +191,7 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
220191
# otherwise we will have error for MatrixDistribution.
221192
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
222193
# support for some types related to matrices (like PDMat).
223-
224-
"""
225-
reconstruct([f, ]dist, val)
226-
227-
Reconstruct `val` so that it's compatible with `dist`.
228-
229-
If `f` is also provided, the reconstruct value will be
230-
such that `f(reconstruct_val)` is compatible with `dist`.
231-
"""
232-
reconstruct(f, dist, val) = reconstruct(dist, val)
233-
234-
# No-op versions.
235-
reconstruct(::UnivariateDistribution, val::Real) = val
236-
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
237-
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
238-
# TODO: Implement no-op `reconstruct` for general array variates.
239-
194+
reconstruct(d::UnivariateDistribution, val::Real) = val
240195
reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
241196
reconstruct(::Tuple{}, val::AbstractVector) = val[1]
242197
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)

0 commit comments

Comments
 (0)