Skip to content

Commit 2597193

Browse files
committed
Fix remaining bugs e.g. in nested decondition_context
1 parent f3d4fa9 commit 2597193

File tree

4 files changed

+49
-13
lines changed

4 files changed

+49
-13
lines changed

src/contexts.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,17 @@ function decondition_context(context::ConditionContext, sym, syms...)
447447
# No more values left, can unwrap
448448
decondition_context(childcontext(context), syms...)
449449
else
450-
ConditionContext(new_values, decondition_context(childcontext(context), syms...))
450+
ConditionContext(
451+
new_values, decondition_context(childcontext(context), sym, syms...)
452+
)
451453
end
452454
end
455+
function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym}
456+
return ConditionContext(
457+
BangBang.delete!!(context.values, sym),
458+
decondition_context(childcontext(context), vn),
459+
)
460+
end
453461

454462
"""
455463
conditioned(context::AbstractContext)

src/model.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ Return a `Model` which now treats variables on the right-hand side as observatio
9696
9797
See [`condition`](@ref) for more information and examples.
9898
"""
99-
Base.:|(model::Model, values::Union{Tuple,NamedTuple,AbstractDict{<:VarName}}) =
99+
Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
100100
condition(model, values)
101101

102102
"""
@@ -281,14 +281,14 @@ end
281281
282282
Convert different types of input to either a `NamedTuple` or `AbstractDict` of
283283
conditioning values, suitable for storage in a `ConditionContext`.
284+
285+
This handles all the cases where `vals` is either already a NamedTuple or
286+
AbstractDict (e.g. `model | (x=1, y=2)`), as well as if they are splatted (e.g.
287+
`condition(model, x=1, y=2)`).
284288
"""
285-
# Case 1: Already in the right format, e.g. condition(model, (x=1, y=2))
286289
_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values
287-
# Case 2: condition(model, (@varname(x) => 1, @varname(y) => 2))
288290
_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values)
289-
# Case 3: Case 1 but splatted, e.g. condition(model, x=1, y=2)
290291
_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...)
291-
# Case 4: Case 2 but splatted, e.g. condition(model, @varname(x) => 1, @varname(y) => 2)
292292
_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...)
293293

294294
"""
@@ -401,7 +401,7 @@ true
401401
```
402402
"""
403403
function AbstractPPL.decondition(model::Model, syms...)
404-
return contextualize(model, decondition(model.context, syms...))
404+
return contextualize(model, decondition_context(model.context, syms...))
405405
end
406406

407407
"""
@@ -435,7 +435,7 @@ julia> # Returns all the variables we have conditioned on + their values.
435435
(x = 100.0, m = 1.0)
436436
437437
julia> # Nested ones also work (note that `PrefixContext` does nothing to the result).
438-
cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0);
438+
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0);
439439
440440
julia> conditioned(cm)
441441
(x = 100.0, m = 1.0)
@@ -447,15 +447,15 @@ julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed
447447
a.m
448448
449449
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
450-
cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0);
450+
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
451451
452452
julia> conditioned(cm).x
453453
100.0
454454
455455
julia> conditioned(cm).var"a.m"
456456
1.0
457457
458-
julia> keys(VarInfo(cm)) # <= no variables are sampled
458+
julia> keys(VarInfo(cm)) # No variables are sampled
459459
VarName[]
460460
```
461461
"""

test/contexts.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,23 @@ end
259259
@test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa
260260
DefaultContext
261261
end
262+
263+
@testset "Nesting" begin
264+
ctx = ConditionContext(
265+
(x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4))
266+
)
267+
# Decondition an outer variable
268+
dctx = decondition_context(ctx, :x)
269+
@test dctx.values == (y=2,)
270+
@test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4)
271+
# Decondition an inner variable
272+
dctx = decondition_context(ctx, @varname(a))
273+
@test dctx.values == (x=1, y=2)
274+
@test childcontext(dctx).values == Dict(@varname(b) => 4)
275+
# Try deconditioning everything
276+
dctx = decondition_context(ctx)
277+
@test dctx isa DefaultContext
278+
end
262279
end
263280
end
264281

test/model.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,25 +100,36 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
100100
end
101101
end
102102

103-
@testset "model conditioning with various arguments" begin
103+
@testset "model de/conditioning" begin
104104
@model function demo_condition()
105105
x ~ Normal()
106106
return y ~ Normal(x)
107107
end
108108
model = demo_condition()
109+
109110
# Test that different syntaxes work and give the same underlying ConditionContext
110-
@testset "NamedTuple ConditionContext" begin
111+
@testset "conditioning NamedTuple" begin
111112
expected_values = (y=2,)
112113
@test condition(model, (y=2,)).context.values == expected_values
113114
@test condition(model; y=2).context.values == expected_values
114115
@test condition(model; y=2).context.values == expected_values
115116
@test (model | (y=2,)).context.values == expected_values
117+
conditioned_model = condition(model, (y=2,))
118+
@test keys(VarInfo(conditioned_model)) == [@varname(x)]
116119
end
117-
@testset "AbstractDict ConditionContext" begin
120+
@testset "conditioning AbstractDict" begin
118121
expected_values = Dict(@varname(y) => 2)
119122
@test condition(model, Dict(@varname(y) => 2)).context.values == expected_values
120123
@test condition(model, @varname(y) => 2).context.values == expected_values
121124
@test (model | (@varname(y) => 2,)).context.values == expected_values
125+
conditioned_model = condition(model, Dict(@varname(y) => 2))
126+
@test keys(VarInfo(conditioned_model)) == [@varname(x)]
127+
end
128+
129+
@testset "deconditioning" begin
130+
conditioned_model = condition(model, (y=2,))
131+
deconditioned_model = decondition(conditioned_model)
132+
@test keys(VarInfo(deconditioned_model)) == [@varname(x), @varname(y)]
122133
end
123134
end
124135

0 commit comments

Comments
 (0)