Skip to content

Commit 9c87d53

Browse files
committed
Find an edge case
1 parent 19a7c4f commit 9c87d53

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ end
4444

4545
function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
4646
_check_varname_indexing(c)
47-
d = Dict{VarName}()
47+
d = Dict{DynamicPPL.VarName,Any}()
4848
for vn in DynamicPPL.varnames(c)
4949
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
5050
end
@@ -271,10 +271,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
271271
# return the model's retval (`first`).
272272
first(
273273
DynamicPPL.init!!(
274-
rng,
275-
model,
276-
varinfo,
277-
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
274+
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
278275
),
279276
)
280277
end

src/contexts/init.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,23 @@ struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
9494
end
9595
end
9696
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
97+
# TODO(penelopeysm): Fix this. If anything in p.params _subsumes_ vn,
98+
# we don't know how to handle it. This is just another corollary of
99+
# https://github.com/TuringLang/DynamicPPL.jl/issues/814
100+
# This used to be handled by nested_setindex_maybe, which I'd really like
101+
# to get rid of.
102+
if p.params isa AbstractDict{<:VarName}
103+
strictly_subsumed = filter(
104+
vn_in_params -> vn_in_params != vn && subsumes(vn, vn_in_params), keys(p.params)
105+
)
106+
if !isempty(strictly_subsumed)
107+
throw(
108+
ArgumentError(
109+
"The given dictionary of parameters contain the following sub-variables of $(vn): $(strictly_subsumed). ParamsInit doesn't know how to deal with this.",
110+
),
111+
)
112+
end
113+
end
97114
return if hasvalue(p.params, vn)
98115
x = getvalue(p.params, vn)
99116
if x === missing
@@ -159,12 +176,15 @@ function tilde_assume(
159176
y = f(x)
160177
logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x)
161178
# Add the new value to the VarInfo. `push!!` errors if the value already
162-
# exists, hence the need for setindex!!
179+
# exists, hence the need for setindex!!.
163180
if in_varinfo
164181
vi = setindex!!(vi, y, vn)
165182
else
166183
vi = push!!(vi, vn, y, dist)
167184
end
185+
# Neither of these set the `trans` flag so we have to do it manually if
186+
# necessary.
187+
insert_transformed_value && settrans!!(vi, true, vn)
168188
# `accumulate_assume!!` wants untransformed values as the second argument.
169189
vi = accumulate_assume!!(vi, x, -logjac, vn, dist)
170190
# We always return the untransformed value here, as that will determine

src/simple_varinfo.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,6 @@ end
463463

464464
# Context implementations
465465

466-
# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
467466
function settrans!!(vi::SimpleVarInfo, trans)
468467
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
469468
end
@@ -473,6 +472,9 @@ end
473472
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
474473
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
475474
end
475+
function settrans!!(::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName)
476+
@info "Attempting to call `settrans!!` on a `SimpleVarInfo` for a specific variable `$vn`; this will be ignored"
477+
end
476478

477479
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
478480
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)

test/test_util.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
8181
varnames = collect(varnames)
8282
# Construct matrix of values
8383
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
84+
# Construct mapping of varnames to symbols
85+
vns_to_syms = Dict{VarName,Symbol}(zip(varnames, Symbol.(varnames)))
8486
# Construct and return the Chains object
85-
return Chains(vals, varnames)
87+
return Chains(vals, varnames; info=(varname_to_symbol=vns_to_syms,))
8688
end
8789
function make_chain_from_prior(model::Model, n_iters::Int)
8890
return make_chain_from_prior(Random.default_rng(), model, n_iters)

0 commit comments

Comments
 (0)