Skip to content

Commit a02dfac

Browse files
committed
fix more varinfo tests
1 parent 9c87d53 commit a02dfac

File tree

3 files changed

+11
-30
lines changed

3 files changed

+11
-30
lines changed

src/contexts/init.jl

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,25 +94,11 @@ struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
9494
end
9595
end
9696
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit)
97-
# TODO(penelopeysm): Fix this. If anything in p.params _subsumes_ vn,
98-
# we don't know how to handle it. This is just another corollary of
99-
# https://github.com/TuringLang/DynamicPPL.jl/issues/814
100-
# This used to be handled by nested_setindex_maybe, which I'd really like
101-
# to get rid of.
102-
if p.params isa AbstractDict{<:VarName}
103-
strictly_subsumed = filter(
104-
vn_in_params -> vn_in_params != vn && subsumes(vn, vn_in_params), keys(p.params)
105-
)
106-
if !isempty(strictly_subsumed)
107-
throw(
108-
ArgumentError(
109-
"The given dictionary of parameters contain the following sub-variables of $(vn): $(strictly_subsumed). ParamsInit doesn't know how to deal with this.",
110-
),
111-
)
112-
end
113-
end
114-
return if hasvalue(p.params, vn)
115-
x = getvalue(p.params, vn)
97+
# TODO(penelopeysm): We should really do a quick check to make sure that all of the
98+
# parameters in `p.params` were actually used, and either warn or error if
99+
# they aren't.
100+
return if hasvalue(p.params, vn, dist)
101+
x = getvalue(p.params, vn, dist)
116102
if x === missing
117103
init(rng, vn, dist, p.default)
118104
else

src/simple_varinfo.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,8 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
259259
end
260260

261261
function untyped_simple_varinfo(model::Model)
262-
<<<<<<< HEAD
263262
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
264-
return last(evaluate_and_sample!!(model, varinfo))
265-
=======
266-
varinfo = SimpleVarInfo(OrderedDict())
267263
return last(init!!(model, varinfo))
268-
>>>>>>> bc16e09 (WIP: InitContext)
269264
end
270265

271266
function typed_simple_varinfo(model::Model)

test/runtests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ include("test_util.jl")
6060
# include("utils.jl")
6161
# include("accumulators.jl")
6262
# include("compiler.jl")
63-
include("varnamedvector.jl")
64-
include("varinfo.jl")
65-
# include("simple_varinfo.jl")
66-
# include("model.jl")
67-
# include("sampler.jl")
68-
# include("distribution_wrappers.jl")
63+
# include("varnamedvector.jl")
64+
# include("varinfo.jl")
65+
include("simple_varinfo.jl")
66+
include("model.jl")
67+
include("sampler.jl")
68+
include("distribution_wrappers.jl")
6969
# include("logdensityfunction.jl")
7070
# include("linking.jl")
7171
# include("serialization.jl")

0 commit comments

Comments
 (0)