Skip to content

Commit 7e38bbe

Browse files
committed
Merge remote-tracking branch 'origin/py/init-prior-uniform' into py/actually-use-init
2 parents 6c4bc4a + 044cb24 commit 7e38bbe

15 files changed

+260
-197
lines changed

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ get_num_produce
341341
set_num_produce!!
342342
increment_num_produce!!
343343
reset_num_produce!!
344-
setorder!
344+
setorder!!
345345
set_retained_vns_del!
346346
```
347347

@@ -368,7 +368,7 @@ DynamicPPL provides the following default accumulators.
368368
```@docs
369369
LogPriorAccumulator
370370
LogLikelihoodAccumulator
371-
NumProduceAccumulator
371+
VariableOrderAccumulator
372372
```
373373

374374
### Common API

src/DynamicPPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export AbstractVarInfo,
5050
AbstractAccumulator,
5151
LogLikelihoodAccumulator,
5252
LogPriorAccumulator,
53-
NumProduceAccumulator,
53+
VariableOrderAccumulator,
5454
push!!,
5555
empty!!,
5656
subset,
@@ -73,7 +73,7 @@ export AbstractVarInfo,
7373
is_flagged,
7474
set_flag!,
7575
unset_flag!,
76-
setorder!,
76+
setorder!!,
7777
istrans,
7878
link,
7979
link!!,

src/abstract_varinfo.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo)
374374
return vi
375375
end
376376

377+
"""
378+
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
379+
380+
Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe`
381+
statements run before sampling `vn`.
382+
"""
383+
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
384+
return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder))
385+
end
386+
387+
"""
388+
getorder(vi::AbstractVarInfo, vn::VarName)
389+
390+
Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
391+
run before sampling `vn`.
392+
"""
393+
getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn]
394+
377395
# Variables and their realizations.
378396
@doc """
379397
keys(vi::AbstractVarInfo)
@@ -980,29 +998,37 @@ end
980998
981999
Return the `num_produce` of `vi`.
9821000
"""
983-
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num
1001+
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce
9841002

9851003
"""
9861004
set_num_produce!!(vi::AbstractVarInfo, n::Integer)
9871005
9881006
Set the `num_produce` field of `vi` to `n`.
9891007
"""
990-
set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n))
1008+
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
1009+
if hasacc(vi, Val(:VariableOrder))
1010+
acc = getacc(vi, Val(:VariableOrder))
1011+
acc = VariableOrderAccumulator(n, acc.order)
1012+
else
1013+
acc = VariableOrderAccumulator(n)
1014+
end
1015+
return setacc!!(vi, acc)
1016+
end
9911017

9921018
"""
9931019
increment_num_produce!!(vi::AbstractVarInfo)
9941020
9951021
Add 1 to `num_produce` in `vi`.
9961022
"""
9971023
increment_num_produce!!(vi::AbstractVarInfo) =
998-
map_accumulator!!(increment, vi, Val(:NumProduce))
1024+
map_accumulator!!(increment, vi, Val(:VariableOrder))
9991025

10001026
"""
10011027
reset_num_produce!!(vi::AbstractVarInfo)
10021028
10031029
Reset the value of `num_produce` in `vi` to 0.
10041030
"""
1005-
reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
1031+
reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))
10061032

10071033
"""
10081034
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])

src/accumulators.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
1313
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
1414
- `accumulate_observe!!(acc::T, right, left, vn)`
1515
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
16+
- `Base.copy(acc::T)`
1617
1718
To be able to work with multi-threading, it should also implement:
1819
- `split(acc::T)`
@@ -138,6 +139,9 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname}
138139
@inline return haskey(at.nt, accname)
139140
end
140141
Base.keys(at::AccumulatorTuple) = keys(at.nt)
142+
Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt
143+
Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h)
144+
Base.copy(at::AccumulatorTuple) = AccumulatorTuple(map(copy, at.nt))
141145

142146
function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
143147
return AccumulatorTuple(convert(T, accs.nt))

src/context_implementations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ function assume(
155155
f = to_maybe_linked_internal_transform(vi, vn, dist)
156156
# TODO(mhauru) This should probably call a function called setindex_internal!
157157
vi = BangBang.setindex!!(vi, f(r), vn)
158-
setorder!(vi, vn, get_num_produce(vi))
159158
else
160159
# Otherwise we just extract it.
161160
r = vi[vn, dist]

src/debug_utils.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,7 @@ function check_model_and_trace(
410410
model::Model, varinfo::AbstractVarInfo; error_on_failure=false
411411
)
412412
# Add debug accumulator to the VarInfo.
413-
# Need a NumProduceAccumulator as well or else get_num_produce may throw
414-
# TODO(mhauru) Remove this once VariableOrderAccumulator stuff is done.
415-
varinfo = DynamicPPL.setaccs!!(
416-
deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator())
417-
)
413+
varinfo = DynamicPPL.setaccs!!(deepcopy(varinfo), (DebugAccumulator(error_on_failure),))
418414

419415
# Perform checks before evaluating the model.
420416
issuccess = check_model_pre_evaluation(model)

src/default_accumulators.jl

Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,52 +41,102 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
4141
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
4242

4343
"""
44-
NumProduceAccumulator{T} <: AbstractAccumulator
44+
VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
4545
46-
An accumulator that tracks the number of observations during model execution.
46+
An accumulator that tracks the order of variables in a `VarInfo`.
47+
48+
This doesn't track the full ordering, but rather how many observations have taken place
49+
before the assume statement for each variable. This is needed for particle methods, where
50+
the model is segmented into parts by each observation, and we need to know which part each
51+
assume statement is in.
4752
4853
# Fields
4954
$(TYPEDFIELDS)
5055
"""
51-
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
56+
struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
5257
"the number of observations"
53-
num::T
58+
num_produce::Eltype
59+
"mapping of variable names to their order in the model"
60+
order::Dict{VNType,Eltype}
5461
end
5562

5663
"""
57-
NumProduceAccumulator{T<:Integer}()
64+
VariableOrderAccumulator{T<:Integer}(n=zero(T))
5865
59-
Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
66+
Create a new `VariableOrderAccumulator` with the number of observations set to `n`.
6067
"""
61-
NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
62-
NumProduceAccumulator() = NumProduceAccumulator{Int}()
68+
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
69+
VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
70+
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
71+
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
72+
73+
Base.copy(acc::LogPriorAccumulator) = acc
74+
Base.copy(acc::LogLikelihoodAccumulator) = acc
75+
function Base.copy(acc::VariableOrderAccumulator)
76+
return VariableOrderAccumulator(acc.num_produce, copy(acc.order))
77+
end
6378

6479
function Base.show(io::IO, acc::LogPriorAccumulator)
6580
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
6681
end
6782
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
6883
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
6984
end
70-
function Base.show(io::IO, acc::NumProduceAccumulator)
71-
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
85+
function Base.show(io::IO, acc::VariableOrderAccumulator)
86+
return print(
87+
io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))"
88+
)
89+
end
90+
91+
# Note that == and isequal are different, and equality under the latter should imply
92+
# equality of hashes. Both of the below implementations are also different from the default
93+
# implementation for structs.
94+
Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp
95+
function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
96+
return acc1.logp == acc2.logp
97+
end
98+
function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
99+
return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order
100+
end
101+
102+
function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
103+
return isequal(acc1.logp, acc2.logp)
104+
end
105+
function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
106+
return isequal(acc1.logp, acc2.logp)
107+
end
108+
function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
109+
return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order)
110+
end
111+
112+
Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h)
113+
function Base.hash(acc::LogLikelihoodAccumulator, h::UInt)
114+
return hash((LogLikelihoodAccumulator, acc.logp), h)
115+
end
116+
function Base.hash(acc::VariableOrderAccumulator, h::UInt)
117+
return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h)
72118
end
73119

74120
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
75121
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
76-
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
122+
accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder
77123

78124
split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
79125
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
80-
split(acc::NumProduceAccumulator) = acc
126+
split(acc::VariableOrderAccumulator) = copy(acc)
81127

82128
function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
83129
return LogPriorAccumulator(acc.logp + acc2.logp)
84130
end
85131
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
86132
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
87133
end
88-
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
89-
return NumProduceAccumulator(max(acc.num, acc2.num))
134+
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
135+
# Note that assumptions are not allowed in parallelised blocks, and thus the
136+
# dictionaries should be identical.
137+
return VariableOrderAccumulator(
138+
max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order)
139+
)
90140
end
91141

92142
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
@@ -95,11 +145,12 @@ end
95145
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
96146
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
97147
end
98-
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
148+
function increment(acc::VariableOrderAccumulator)
149+
return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
150+
end
99151

100152
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
101153
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
102-
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))
103154

104155
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
105156
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
@@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
114165
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
115166
end
116167

117-
accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
118-
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
168+
function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
169+
acc.order[vn] = acc.num_produce
170+
return acc
171+
end
172+
accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)
119173

120174
function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
121175
return LogPriorAccumulator(convert(T, acc.logp))
@@ -126,15 +180,19 @@ function Base.convert(
126180
return LogLikelihoodAccumulator(convert(T, acc.logp))
127181
end
128182
function Base.convert(
129-
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
130-
) where {T}
131-
return NumProduceAccumulator(convert(T, acc.num))
183+
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
184+
) where {ElType,VnType}
185+
order = Dict{VnType,ElType}()
186+
for (k, v) in acc.order
187+
order[convert(VnType, k)] = convert(ElType, v)
188+
end
189+
return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
132190
end
133191

134192
# TODO(mhauru)
135-
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
193+
# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fall back on
136194
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
137-
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
195+
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
138196
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
139197
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
140198
return LogPriorAccumulator(convert(T, acc.logp))
@@ -149,6 +207,6 @@ function default_accumulators(
149207
return AccumulatorTuple(
150208
LogPriorAccumulator{FloatT}(),
151209
LogLikelihoodAccumulator{FloatT}(),
152-
NumProduceAccumulator{IntT}(),
210+
VariableOrderAccumulator{IntT}(),
153211
)
154212
end

src/extract_priors.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ end
44

55
PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}())
66

7+
function Base.copy(acc::PriorDistributionAccumulator)
8+
return PriorDistributionAccumulator(copy(acc.priors))
9+
end
10+
711
accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator
812

913
split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))
@@ -112,10 +116,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) =
112116
extract_priors(Random.default_rng(), args...)
113117
function extract_priors(rng::Random.AbstractRNG, model::Model)
114118
varinfo = VarInfo()
115-
# TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
116-
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117-
# can't push new variables without knowing the num_produce. Remove this when possible.
118-
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
119+
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),))
119120
varinfo = last(init!!(rng, model, varinfo))
120121
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
121122
end
@@ -129,12 +130,7 @@ This is done by evaluating the model at the values present in `varinfo`
129130
and recording the distributions that are present at each tilde statement.
130131
"""
131132
function extract_priors(model::Model, varinfo::AbstractVarInfo)
132-
# TODO(mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
133-
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
134-
# can't push new variables without knowing the num_produce. Remove this when possible.
135-
varinfo = setaccs!!(
136-
deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator())
137-
)
133+
varinfo = setaccs!!(deepcopy(varinfo), (PriorDistributionAccumulator(),))
138134
varinfo = last(evaluate!!(model, varinfo))
139135
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
140136
end

src/pointwise_logdensities.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
3232
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
3333
end
3434

35+
function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
36+
return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps))
37+
end
38+
3539
function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp)
3640
logps = acc.logps
3741
# The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys.

0 commit comments

Comments
 (0)