Skip to content

InitContext, part 5 - Remove SamplingContext, SampleFrom{Prior,Uniform}, {tilde_,}assume #985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: py/actually-use-init
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,12 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
SamplingContext
DefaultContext
PrefixContext
ConditionContext
Expand Down Expand Up @@ -495,15 +495,7 @@ DynamicPPL.init

### Samplers

In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.

```@docs
SampleFromPrior
SampleFromUniform
```

Additionally, a generic sampler for inference is implemented.
In DynamicPPL a generic sampler for inference is implemented.

```@docs
Sampler
Expand All @@ -529,9 +521,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
tilde_assume
```
2 changes: 0 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ else
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =
Expand Down
3 changes: 0 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,10 @@ export AbstractVarInfo,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
SampleFromUniform,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
ConditionContext,
Expand Down
137 changes: 13 additions & 124 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
@@ -1,94 +1,39 @@
# assume
"""
tilde_assume(context::SamplingContext, right, vn, vi)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value with a context associated
with a sampler.

Falls back to
```julia
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
```
"""
function tilde_assume(context::SamplingContext, right, vn, vi)
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
end

function tilde_assume(context::AbstractContext, args...)
return tilde_assume(childcontext(context), args...)
end
function tilde_assume(::DefaultContext, right, vn, vi)
return assume(right, vn, vi)
end

function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
return tilde_assume(rng, childcontext(context), args...)
function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi)
return tilde_assume!!(childcontext(context), right, vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi)
@warn(
"Encountered SamplingContext->InitContext. This method will be removed in the next PR.",
)
# just pretend the `InitContext` isn't there for now.
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
# same as above but no rng
return assume(Random.default_rng(), sampler, right, vn, vi)
function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, right)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, right)
return x, vi
end

function tilde_assume(context::PrefixContext, right, vn, vi)
function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
# Note that we can't use something like this here:
# new_vn = prefix(context, vn)
# return tilde_assume(childcontext(context), right, new_vn, vi)
# return tilde_assume!!(childcontext(context), right, new_vn, vi)
# This is because `prefix` applies _all_ prefixes in a given context to a
# variable name. Thus, if we had two levels of nested prefixes e.g.
# `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the
# first call would apply the prefix `a.b._`, and the recursive call
# would apply the prefix `b._`, resulting in `b.a.b._`.
# This is why we need a special function, `prefix_and_strip_contexts`.
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(new_context, right, new_vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
)
new_vn, new_context = prefix_and_strip_contexts(context, vn)
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
return tilde_assume!!(new_context, right, new_vn, vi)
end

"""
tilde_assume!!(context, right, vn, vi)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value and updated `vi`.

By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
return if right isa DynamicPPL.Submodel
_evaluate!!(right, vi, context, vn)
else
tilde_assume(context, right, vn, vi)
end
function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi)
return _evaluate!!(right, vi, context, vn)
end

# observe
"""
tilde_observe!!(context::SamplingContext, right, left, vi)

Handle observed constants with a `context` associated with a sampler.

Falls back to `tilde_observe!!(context.context, right, left, vi)`.
"""
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
return tilde_observe!!(context.context, right, left, vn, vi)
end

function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
return tilde_observe!!(childcontext(context), right, left, vn, vi)
end
Expand Down Expand Up @@ -121,59 +66,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
vi = accumulate_observe!!(vi, right, left, vn)
return left, vi
end

function assume(::Random.AbstractRNG, spl::Sampler, dist)
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, dist)
x, logjac = with_logabsdet_jacobian(f, y)
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
return x, vi
end

# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
rng::Random.AbstractRNG,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
# if that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
# TODO(mhauru) This should probably be call a function called setindex_internal!
vi = BangBang.setindex!!(vi, f(r), vn)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, vn, dist)
vi = push!!(vi, vn, f(r), dist)
# By default `push!!` sets the transformed flag to `false`.
vi = settrans!!(vi, true, vn)
else
vi = push!!(vi, vn, r, dist)
end
end

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
return r, vi
end
71 changes: 2 additions & 69 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ effectively updating the child context.
```jldoctest
julia> using DynamicPPL: DynamicTransformationContext

julia> ctx = SamplingContext();
julia> ctx = ConditionContext((; a = 1);

julia> DynamicPPL.childcontext(ctx)
DefaultContext()
Expand Down Expand Up @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right

# Contexts
"""
SamplingContext(
[rng::Random.AbstractRNG=Random.default_rng()],
[sampler::AbstractSampler=SampleFromPrior()],
[context::AbstractContext=DefaultContext()],
)

Create a context that allows you to sample parameters with the `sampler` when running the model.
The `context` determines how the returned log density is computed when running the model.

See also: [`DefaultContext`](@ref)
"""
struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
rng::R
sampler::S
context::C
end

function SamplingContext(
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
)
return SamplingContext(rng, sampler, DefaultContext())
end

function SamplingContext(
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
)
return SamplingContext(Random.default_rng(), sampler, context)
end

function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
return SamplingContext(rng, SampleFromPrior(), context)
end

function SamplingContext(context::AbstractContext)
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
end

NodeTrait(context::SamplingContext) = IsParent()
childcontext(context::SamplingContext) = context.context
function setchildcontext(parent::SamplingContext, child)
return SamplingContext(parent.rng, parent.sampler, child)
end

"""
hassampler(context)

Return `true` if `context` has a sampler.
"""
hassampler(::SamplingContext) = true
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
hassampler(::IsLeaf, context::AbstractContext) = false
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))

"""
getsampler(context)

Return the sampler of the context `context`.

This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
at which point it will return the sampler of that context.
"""
getsampler(context::SamplingContext) = context.sampler
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

"""
struct DefaultContext <: AbstractContext end

Expand Down Expand Up @@ -252,7 +185,7 @@ PrefixContexts removed.

NOTE: This does _not_ modify any variables in any `ConditionContext` and
`FixedContext` that may be present in the context stack. This is because this
function is only used in `tilde_assume`, which is lower in the tilde-pipeline
function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline
than `contextual_isassumption` and `contextual_isfixed` (the functions which
actually use the `ConditionContext` and `FixedContext` values). Thus, by this
time, any `ConditionContext`s and `FixedContext`s present have already served
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon
end
NodeTrait(::InitContext) = IsLeaf()

function tilde_assume(
function tilde_assume!!(
ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
)
in_varinfo = haskey(vi, vn)
Expand Down
17 changes: 5 additions & 12 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,9 @@ function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varin
return nothing
end

function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
function DynamicPPL.tilde_assume!!(context::DebugContext, right, vn, vi)
record_pre_tilde_assume!(context, vn, right, vi)
value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
record_post_tilde_assume!(context, vn, right, value, vi)
return value, vi
end
function DynamicPPL.tilde_assume(
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
)
record_pre_tilde_assume!(context, vn, right, vi)
value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
value, vi = DynamicPPL.tilde_assume!!(childcontext(context), right, vn, vi)
record_post_tilde_assume!(context, vn, right, value, vi)
return value, vi
end
Expand Down Expand Up @@ -438,9 +430,10 @@ function check_model_and_trace(
kwargs...,
)
# Execute the model with the debug context.
debug_context = DebugContext(
SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs...
new_context = DynamicPPL.setleafcontext(
model.context, DynamicPPL.InitContext(rng, DynamicPPL.PriorInit())
)
debug_context = DebugContext(new_context; error_on_failure=error_on_failure, kwargs...)
debug_model = DynamicPPL.contextualize(model, debug_context)

# Perform checks before evaluating the model.
Expand Down
Loading
Loading