Skip to content

Commit d0dba32

Browse files
committed
Remove separate adtype field from LogDensityFunction
1 parent fbef3db commit d0dba32

File tree

2 files changed

+60
-62
lines changed

2 files changed

+60
-62
lines changed

src/logdensityfunction.jl

+41-42
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,17 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818
LogDensityFunction(
1919
model::Model,
2020
varinfo::AbstractVarInfo=VarInfo(model),
21-
context::AbstractContext=DefaultContext();
22-
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
21+
context::AbstractContext=DefaultContext()
2322
)
2423
25-
A struct which contains a model, along with all the information necessary to:
26-
27-
- calculate its log density at a given point;
28-
- and if `adtype` is provided, calculate the gradient of the log density at
29-
that point.
24+
A struct which contains a model, along with all the information necessary to
25+
calculate its log density at a given point.
3026
3127
At its most basic level, a LogDensityFunction wraps the model together with its
3228
the type of varinfo to be used, as well as the evaluation context. These must
3329
be known in order to calculate the log density (using
3430
[`DynamicPPL.evaluate!!`](@ref)).
3531
36-
If the `adtype` keyword argument is provided, then this struct will also store
37-
the adtype along with other information for efficient calculation of the
38-
gradient of the log density. Note that preparing a `LogDensityFunction` with an
39-
AD type `AutoBackend()` requires the AD backend itself to have been loaded
40-
(e.g. with `import Backend`).
41-
42-
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
43-
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
44-
concrete AD backend type, then `logdensity_and_gradient` is also implemented.
45-
4632
# Fields
4733
$(FIELDS)
4834
@@ -84,40 +70,42 @@ julia> # This also respects the context in `model`.
8470
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
8571
true
8672
87-
julia> # If we also need to calculate the gradient, we can specify an AD backend.
73+
julia> # If we also need to calculate the gradient, an AD backend must be specified as part of the model.
8874
import ForwardDiff, ADTypes
8975
90-
julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
76+
julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());
77+
78+
julia> f = LogDensityFunction(model_with_ad);
9179
9280
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
9381
(-2.3378770664093453, [1.0])
9482
```
9583
"""
96-
struct LogDensityFunction{
97-
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
98-
}
84+
struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
9985
"model used for evaluation"
10086
model::M
10187
"varinfo used for evaluation"
10288
varinfo::V
10389
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
10490
context::C
105-
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
106-
adtype::AD
10791
"(internal use only) gradient preparation object for the model"
10892
prep::Union{Nothing,DI.GradientPrep}
10993

11094
function LogDensityFunction(
11195
model::Model,
11296
varinfo::AbstractVarInfo=VarInfo(model),
113-
context::AbstractContext=leafcontext(model.context);
114-
adtype::Union{ADTypes.AbstractADType,Nothing}=model.adtype,
97+
context::AbstractContext=leafcontext(model.context),
11598
)
99+
adtype = model.adtype
116100
if adtype === nothing
117101
prep = nothing
118102
else
119103
# Make backend-specific tweaks to the adtype
104+
# This should arguably be done in the model constructor, but it needs the
105+
# varinfo and context to do so, and it seems excessive to construct a
106+
# varinfo at the point of calling Model().
120107
adtype = tweak_adtype(adtype, model, varinfo, context)
108+
model = Model(model, adtype)
121109
# Check whether it is supported
122110
is_supported(adtype) ||
123111
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -138,8 +126,8 @@ struct LogDensityFunction{
138126
)
139127
end
140128
end
141-
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
142-
model, varinfo, context, adtype, prep
129+
return new{typeof(model),typeof(varinfo),typeof(context)}(
130+
model, varinfo, context, prep
143131
)
144132
end
145133
end
@@ -157,10 +145,10 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g
157145
function LogDensityFunction(
158146
f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
159147
)
160-
return if adtype === f.adtype
148+
return if adtype === f.model.adtype
161149
f # Avoid recomputing prep if not needed
162150
else
163-
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
151+
LogDensityFunction(Model(f.model, adtype), f.varinfo, f.context)
164152
end
165153
end
166154

@@ -187,35 +175,46 @@ end
187175
### LogDensityProblems interface
188176

189177
function LogDensityProblems.capabilities(
190-
::Type{<:LogDensityFunction{M,V,C,Nothing}}
191-
) where {M,V,C}
178+
::Type{
179+
<:LogDensityFunction{
180+
Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C
181+
},
182+
},
183+
) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C}
192184
return LogDensityProblems.LogDensityOrder{0}()
193185
end
194186
function LogDensityProblems.capabilities(
195-
::Type{<:LogDensityFunction{M,V,C,AD}}
196-
) where {M,V,C,AD<:ADTypes.AbstractADType}
187+
::Type{
188+
<:LogDensityFunction{
189+
Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C
190+
},
191+
},
192+
) where {
193+
F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<:ADTypes.AbstractADType
194+
}
197195
return LogDensityProblems.LogDensityOrder{1}()
198196
end
199197
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
200198
return logdensity_at(x, f.model, f.varinfo, f.context)
201199
end
202200
function LogDensityProblems.logdensity_and_gradient(
203-
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
204-
) where {M,V,C,AD<:ADTypes.AbstractADType}
205-
f.prep === nothing &&
206-
error("Gradient preparation not available; this should not happen")
201+
f::LogDensityFunction{M,V,C}, x::AbstractVector
202+
) where {M,V,C}
203+
f.prep === nothing && error(
204+
"Attempted to call logdensity_and_gradient on a LogDensityFunction without an AD backend. You need to set an AD backend in the model before calculating the gradient of logp.",
205+
)
207206
x = map(identity, x) # Concretise type
208207
# Make branching statically inferrable, i.e. type-stable (even if the two
209208
# branches happen to return different types)
210-
return if use_closure(f.adtype)
209+
return if use_closure(f.model.adtype)
211210
DI.value_and_gradient(
212-
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
211+
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.model.adtype, x
213212
)
214213
else
215214
DI.value_and_gradient(
216215
logdensity_at,
217216
f.prep,
218-
f.adtype,
217+
f.model.adtype,
219218
x,
220219
DI.Constant(f.model),
221220
DI.Constant(f.varinfo),
@@ -292,7 +291,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292291
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
293292
"""
294293
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
295-
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
294+
return LogDensityFunction(model, f.varinfo, f.context)
296295
end
297296

298297
"""

test/logdensityfunction.jl

+19-20
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,9 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff
99
end
1010
end
1111

12-
@testset "AD type forwarding from model" begin
13-
@model demo_simple() = x ~ Normal()
14-
model = Model(demo_simple(), AutoForwardDiff())
15-
ldf = DynamicPPL.LogDensityFunction(model)
16-
# Check that the model's AD type is forwarded to the LDF
17-
# Note: can't check ldf.adtype == AutoForwardDiff() because `tweak_adtype`
18-
# modifies the underlying parameters a bit, so just check that it is still
19-
# the correct backend package.
20-
@test ldf.adtype isa AutoForwardDiff
21-
# Check that the gradient can be evaluated on the resulting LDF
22-
@test LogDensityProblems.capabilities(typeof(ldf)) ==
23-
LogDensityProblems.LogDensityOrder{1}()
24-
@test LogDensityProblems.logdensity(ldf, [1.0]) isa Any
25-
@test LogDensityProblems.logdensity_and_gradient(ldf, [1.0]) isa Any
26-
end
27-
2812
@testset "LogDensityFunction" begin
29-
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
13+
@testset "construction from $(nameof(model))" for model in
14+
DynamicPPL.TestUtils.DEMO_MODELS
3015
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
3116
vns = DynamicPPL.TestUtils.varnames(model)
3217
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@@ -39,14 +24,28 @@ end
3924
end
4025
end
4126

42-
@testset "capabilities" begin
43-
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
27+
@testset "LogDensityProblems interface" begin
28+
@model demo_simple() = x ~ Normal()
29+
model = demo_simple()
30+
4431
ldf = DynamicPPL.LogDensityFunction(model)
4532
@test LogDensityProblems.capabilities(typeof(ldf)) ==
4633
LogDensityProblems.LogDensityOrder{0}()
34+
@test LogDensityProblems.logdensity(ldf, [1.0]) isa Any
4735

48-
ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
36+
# Set AD type on model, then reconstruct LDF
37+
model_with_ad = Model(model, AutoForwardDiff())
38+
ldf_with_ad = DynamicPPL.LogDensityFunction(model_with_ad)
4939
@test LogDensityProblems.capabilities(typeof(ldf_with_ad)) ==
5040
LogDensityProblems.LogDensityOrder{1}()
41+
@test LogDensityProblems.logdensity(ldf_with_ad, [1.0]) isa Any
42+
@test LogDensityProblems.logdensity_and_gradient(ldf_with_ad, [1.0]) isa Any
43+
44+
# Set AD type on LDF directly
45+
ldf_with_ad2 = DynamicPPL.LogDensityFunction(ldf, AutoForwardDiff())
46+
@test LogDensityProblems.capabilities(typeof(ldf_with_ad2)) ==
47+
LogDensityProblems.LogDensityOrder{1}()
48+
@test LogDensityProblems.logdensity(ldf_with_ad2, [1.0]) isa Any
49+
@test LogDensityProblems.logdensity_and_gradient(ldf_with_ad2, [1.0]) isa Any
5150
end
5251
end

0 commit comments

Comments
 (0)