@@ -41,52 +41,102 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
41
41
LogLikelihoodAccumulator () = LogLikelihoodAccumulator {LogProbType} ()
42
42
43
43
"""
44
- NumProduceAccumulator {T} <: AbstractAccumulator
44
+ VariableOrderAccumulator {T} <: AbstractAccumulator
45
45
46
- An accumulator that tracks the number of observations during model execution.
46
+ An accumulator that tracks the order of variables in a `VarInfo`.
47
+
48
+ This doesn't track the full ordering, but rather how many observations have taken place
49
+ before the assume statement for each variable. This is needed for particle methods, where
50
+ the model is segmented into parts by each observation, and we need to know which part each
51
+ assume statement is in.
47
52
48
53
# Fields
49
54
$(TYPEDFIELDS)
50
55
"""
51
- struct NumProduceAccumulator{T <: Integer } <: AbstractAccumulator
56
+ struct VariableOrderAccumulator{Eltype <: Integer ,VNType <: VarName } <: AbstractAccumulator
52
57
" the number of observations"
53
- num:: T
58
+ num_produce:: Eltype
59
+ " mapping of variable names to their order in the model"
60
+ order:: Dict{VNType,Eltype}
54
61
end
55
62
56
63
"""
57
- NumProduceAccumulator {T<:Integer}()
64
+ VariableOrderAccumulator {T<:Integer}(n=zero(T) )
58
65
59
- Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero .
66
+ Create a new `VariableOrderAccumulator` with the number of observations set to `n` .
60
67
"""
61
- NumProduceAccumulator {T} () where {T<: Integer } = NumProduceAccumulator (zero (T))
62
- NumProduceAccumulator () = NumProduceAccumulator {Int} ()
68
+ VariableOrderAccumulator {T} (n= zero (T)) where {T<: Integer } =
69
+ VariableOrderAccumulator (convert (T, n), Dict {VarName,T} ())
70
+ VariableOrderAccumulator (n) = VariableOrderAccumulator {typeof(n)} (n)
71
+ VariableOrderAccumulator () = VariableOrderAccumulator {Int} ()
72
+
73
+ Base. copy (acc:: LogPriorAccumulator ) = acc
74
+ Base. copy (acc:: LogLikelihoodAccumulator ) = acc
75
+ function Base. copy (acc:: VariableOrderAccumulator )
76
+ return VariableOrderAccumulator (acc. num_produce, copy (acc. order))
77
+ end
63
78
64
79
function Base. show (io:: IO , acc:: LogPriorAccumulator )
65
80
return print (io, " LogPriorAccumulator($(repr (acc. logp)) )" )
66
81
end
67
82
function Base. show (io:: IO , acc:: LogLikelihoodAccumulator )
68
83
return print (io, " LogLikelihoodAccumulator($(repr (acc. logp)) )" )
69
84
end
70
- function Base. show (io:: IO , acc:: NumProduceAccumulator )
71
- return print (io, " NumProduceAccumulator($(repr (acc. num)) )" )
85
+ function Base. show (io:: IO , acc:: VariableOrderAccumulator )
86
+ return print (
87
+ io, " VariableOrderAccumulator($(repr (acc. num_produce)) , $(repr (acc. order)) )"
88
+ )
89
+ end
90
+
91
+ # Note that == and isequal are different, and equality under the latter should imply
92
+ # equality of hashes. Both of the below implementations are also different from the default
93
+ # implementation for structs.
94
+ Base.:(== )(acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator ) = acc1. logp == acc2. logp
95
+ function Base.:(== )(acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
96
+ return acc1. logp == acc2. logp
97
+ end
98
+ function Base.:(== )(acc1:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
99
+ return acc1. num_produce == acc2. num_produce && acc1. order == acc2. order
100
+ end
101
+
102
+ function Base. isequal (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
103
+ return isequal (acc1. logp, acc2. logp)
104
+ end
105
+ function Base. isequal (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
106
+ return isequal (acc1. logp, acc2. logp)
107
+ end
108
+ function Base. isequal (acc1:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
109
+ return isequal (acc1. num_produce, acc2. num_produce) && isequal (acc1. order, acc2. order)
110
+ end
111
+
112
+ Base. hash (acc:: LogPriorAccumulator , h:: UInt ) = hash ((LogPriorAccumulator, acc. logp), h)
113
+ function Base. hash (acc:: LogLikelihoodAccumulator , h:: UInt )
114
+ return hash ((LogLikelihoodAccumulator, acc. logp), h)
115
+ end
116
+ function Base. hash (acc:: VariableOrderAccumulator , h:: UInt )
117
+ return hash ((VariableOrderAccumulator, acc. num_produce, acc. order), h)
72
118
end
73
119
74
120
accumulator_name (:: Type{<:LogPriorAccumulator} ) = :LogPrior
75
121
accumulator_name (:: Type{<:LogLikelihoodAccumulator} ) = :LogLikelihood
76
- accumulator_name (:: Type{<:NumProduceAccumulator } ) = :NumProduce
122
+ accumulator_name (:: Type{<:VariableOrderAccumulator } ) = :VariableOrder
77
123
78
124
split (:: LogPriorAccumulator{T} ) where {T} = LogPriorAccumulator (zero (T))
79
125
split (:: LogLikelihoodAccumulator{T} ) where {T} = LogLikelihoodAccumulator (zero (T))
80
- split (acc:: NumProduceAccumulator ) = acc
126
+ split (acc:: VariableOrderAccumulator ) = copy ( acc)
81
127
82
128
function combine (acc:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
83
129
return LogPriorAccumulator (acc. logp + acc2. logp)
84
130
end
85
131
function combine (acc:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
86
132
return LogLikelihoodAccumulator (acc. logp + acc2. logp)
87
133
end
88
- function combine (acc:: NumProduceAccumulator , acc2:: NumProduceAccumulator )
89
- return NumProduceAccumulator (max (acc. num, acc2. num))
134
+ function combine (acc:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
135
+ # Note that assumptions are not allowed in parallelised blocks, and thus the
136
+ # dictionaries should be identical.
137
+ return VariableOrderAccumulator (
138
+ max (acc. num_produce, acc2. num_produce), merge (acc. order, acc2. order)
139
+ )
90
140
end
91
141
92
142
function Base.:+ (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
95
145
function Base.:+ (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
96
146
return LogLikelihoodAccumulator (acc1. logp + acc2. logp)
97
147
end
98
- increment (acc:: NumProduceAccumulator ) = NumProduceAccumulator (acc. num + oneunit (acc. num))
148
+ function increment (acc:: VariableOrderAccumulator )
149
+ return VariableOrderAccumulator (acc. num_produce + oneunit (acc. num_produce), acc. order)
150
+ end
99
151
100
152
Base. zero (acc:: LogPriorAccumulator ) = LogPriorAccumulator (zero (acc. logp))
101
153
Base. zero (acc:: LogLikelihoodAccumulator ) = LogLikelihoodAccumulator (zero (acc. logp))
102
- Base. zero (acc:: NumProduceAccumulator ) = NumProduceAccumulator (zero (acc. num))
103
154
104
155
function accumulate_assume!! (acc:: LogPriorAccumulator , val, logjac, vn, right)
105
156
return acc + LogPriorAccumulator (logpdf (right, val) + logjac)
@@ -114,8 +165,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
114
165
return acc + LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
115
166
end
116
167
117
- accumulate_assume!! (acc:: NumProduceAccumulator , val, logjac, vn, right) = acc
118
- accumulate_observe!! (acc:: NumProduceAccumulator , right, left, vn) = increment (acc)
168
+ function accumulate_assume!! (acc:: VariableOrderAccumulator , val, logjac, vn, right)
169
+ acc. order[vn] = acc. num_produce
170
+ return acc
171
+ end
172
+ accumulate_observe!! (acc:: VariableOrderAccumulator , right, left, vn) = increment (acc)
119
173
120
174
function Base. convert (:: Type{LogPriorAccumulator{T}} , acc:: LogPriorAccumulator ) where {T}
121
175
return LogPriorAccumulator (convert (T, acc. logp))
@@ -126,15 +180,19 @@ function Base.convert(
126
180
return LogLikelihoodAccumulator (convert (T, acc. logp))
127
181
end
128
182
function Base. convert (
129
- :: Type{NumProduceAccumulator{T}} , acc:: NumProduceAccumulator
130
- ) where {T}
131
- return NumProduceAccumulator (convert (T, acc. num))
183
+ :: Type{VariableOrderAccumulator{ElType,VnType}} , acc:: VariableOrderAccumulator
184
+ ) where {ElType,VnType}
185
+ order = Dict {VnType,ElType} ()
186
+ for (k, v) in acc. order
187
+ order[convert (VnType, k)] = convert (ElType, v)
188
+ end
189
+ return VariableOrderAccumulator (convert (ElType, acc. num_produce), order)
132
190
end
133
191
134
192
# TODO (mhauru)
135
- # We ignore the convert_eltype calls for NumProduceAccumulator , by letting them fallback on
193
+ # We ignore the convert_eltype calls for VariableOrderAccumulator , by letting them fallback on
136
194
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
137
- # deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator . This is
195
+ # deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator . This is
138
196
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
139
197
function convert_eltype (:: Type{T} , acc:: LogPriorAccumulator ) where {T}
140
198
return LogPriorAccumulator (convert (T, acc. logp))
@@ -149,6 +207,6 @@ function default_accumulators(
149
207
return AccumulatorTuple (
150
208
LogPriorAccumulator {FloatT} (),
151
209
LogLikelihoodAccumulator {FloatT} (),
152
- NumProduceAccumulator {IntT} (),
210
+ VariableOrderAccumulator {IntT} (),
153
211
)
154
212
end
0 commit comments