@@ -18,31 +18,17 @@ is_supported(::ADTypes.AutoReverseDiff) = true
18
18
LogDensityFunction(
19
19
model::Model,
20
20
varinfo::AbstractVarInfo=VarInfo(model),
21
- context::AbstractContext=DefaultContext();
22
- adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
21
+ context::AbstractContext=DefaultContext()
23
22
)
24
23
25
- A struct which contains a model, along with all the information necessary to:
26
-
27
- - calculate its log density at a given point;
28
- - and if `adtype` is provided, calculate the gradient of the log density at
29
- that point.
24
+ A struct which contains a model, along with all the information necessary to
25
+ calculate its log density at a given point.
30
26
31
27
At its most basic level, a LogDensityFunction wraps the model together with its
32
28
the type of varinfo to be used, as well as the evaluation context. These must
33
29
be known in order to calculate the log density (using
34
30
[`DynamicPPL.evaluate!!`](@ref)).
35
31
36
- If the `adtype` keyword argument is provided, then this struct will also store
37
- the adtype along with other information for efficient calculation of the
38
- gradient of the log density. Note that preparing a `LogDensityFunction` with an
39
- AD type `AutoBackend()` requires the AD backend itself to have been loaded
40
- (e.g. with `import Backend`).
41
-
42
- `DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
43
- If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
44
- concrete AD backend type, then `logdensity_and_gradient` is also implemented.
45
-
46
32
# Fields
47
33
$(FIELDS)
48
34
@@ -84,40 +70,42 @@ julia> # This also respects the context in `model`.
84
70
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
85
71
true
86
72
87
- julia> # If we also need to calculate the gradient, we can specify an AD backend.
73
+ julia> # If we also need to calculate the gradient, an AD backend must be specified as part of the model .
88
74
import ForwardDiff, ADTypes
89
75
90
- julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
76
+ julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());
77
+
78
+ julia> f = LogDensityFunction(model_with_ad);
91
79
92
80
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
93
81
(-2.3378770664093453, [1.0])
94
82
```
95
83
"""
96
- struct LogDensityFunction{
97
- M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext ,AD<: Union{Nothing,ADTypes.AbstractADType}
98
- }
84
+ struct LogDensityFunction{M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext }
99
85
" model used for evaluation"
100
86
model:: M
101
87
" varinfo used for evaluation"
102
88
varinfo:: V
103
89
" context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
104
90
context:: C
105
- " AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
106
- adtype:: AD
107
91
" (internal use only) gradient preparation object for the model"
108
92
prep:: Union{Nothing,DI.GradientPrep}
109
93
110
94
function LogDensityFunction (
111
95
model:: Model ,
112
96
varinfo:: AbstractVarInfo = VarInfo (model),
113
- context:: AbstractContext = leafcontext (model. context);
114
- adtype:: Union{ADTypes.AbstractADType,Nothing} = model. adtype,
97
+ context:: AbstractContext = leafcontext (model. context),
115
98
)
99
+ adtype = model. adtype
116
100
if adtype === nothing
117
101
prep = nothing
118
102
else
119
103
# Make backend-specific tweaks to the adtype
104
+ # This should arguably be done in the model constructor, but it needs the
105
+ # varinfo and context to do so, and it seems excessive to construct a
106
+ # varinfo at the point of calling Model().
120
107
adtype = tweak_adtype (adtype, model, varinfo, context)
108
+ model = Model (model, adtype)
121
109
# Check whether it is supported
122
110
is_supported (adtype) ||
123
111
@warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -138,8 +126,8 @@ struct LogDensityFunction{
138
126
)
139
127
end
140
128
end
141
- return new {typeof(model),typeof(varinfo),typeof(context),typeof(adtype) } (
142
- model, varinfo, context, adtype, prep
129
+ return new {typeof(model),typeof(varinfo),typeof(context)} (
130
+ model, varinfo, context, prep
143
131
)
144
132
end
145
133
end
@@ -157,10 +145,10 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g
157
145
function LogDensityFunction (
158
146
f:: LogDensityFunction , adtype:: Union{Nothing,ADTypes.AbstractADType}
159
147
)
160
- return if adtype === f. adtype
148
+ return if adtype === f. model . adtype
161
149
f # Avoid recomputing prep if not needed
162
150
else
163
- LogDensityFunction (f. model, f. varinfo, f. context; adtype = adtype )
151
+ LogDensityFunction (Model ( f. model, adtype), f. varinfo, f. context)
164
152
end
165
153
end
166
154
@@ -187,35 +175,46 @@ end
187
175
# ## LogDensityProblems interface
188
176
189
177
function LogDensityProblems. capabilities (
190
- :: Type{<:LogDensityFunction{M,V,C,Nothing}}
191
- ) where {M,V,C}
178
+ :: Type {
179
+ <: LogDensityFunction {
180
+ Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C
181
+ },
182
+ },
183
+ ) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C}
192
184
return LogDensityProblems. LogDensityOrder {0} ()
193
185
end
194
186
function LogDensityProblems. capabilities (
195
- :: Type{<:LogDensityFunction{M,V,C,AD}}
196
- ) where {M,V,C,AD<: ADTypes.AbstractADType }
187
+ :: Type {
188
+ <: LogDensityFunction {
189
+ Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C
190
+ },
191
+ },
192
+ ) where {
193
+ F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<: ADTypes.AbstractADType
194
+ }
197
195
return LogDensityProblems. LogDensityOrder {1} ()
198
196
end
199
197
function LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
200
198
return logdensity_at (x, f. model, f. varinfo, f. context)
201
199
end
202
200
function LogDensityProblems. logdensity_and_gradient (
203
- f:: LogDensityFunction{M,V,C,AD} , x:: AbstractVector
204
- ) where {M,V,C,AD<: ADTypes.AbstractADType }
205
- f. prep === nothing &&
206
- error (" Gradient preparation not available; this should not happen" )
201
+ f:: LogDensityFunction{M,V,C} , x:: AbstractVector
202
+ ) where {M,V,C}
203
+ f. prep === nothing && error (
204
+ " Attempted to call logdensity_and_gradient on a LogDensityFunction without an AD backend. You need to set an AD backend in the model before calculating the gradient of logp." ,
205
+ )
207
206
x = map (identity, x) # Concretise type
208
207
# Make branching statically inferrable, i.e. type-stable (even if the two
209
208
# branches happen to return different types)
210
- return if use_closure (f. adtype)
209
+ return if use_closure (f. model . adtype)
211
210
DI. value_and_gradient (
212
- x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. adtype, x
211
+ x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. model . adtype, x
213
212
)
214
213
else
215
214
DI. value_and_gradient (
216
215
logdensity_at,
217
216
f. prep,
218
- f. adtype,
217
+ f. model . adtype,
219
218
x,
220
219
DI. Constant (f. model),
221
220
DI. Constant (f. varinfo),
@@ -292,7 +291,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292
291
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
293
292
"""
294
293
function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
295
- return LogDensityFunction (model, f. varinfo, f. context; adtype = f . adtype )
294
+ return LogDensityFunction (model, f. varinfo, f. context)
296
295
end
297
296
298
297
"""
0 commit comments