Commit 10fe743

Add adtype field to DynamicPPL.Model
1 parent 90c7b26 commit 10fe743

7 files changed, +98 -16 lines changed


HISTORY.md

+18-1
@@ -128,6 +128,22 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 **Other changes**
 
+### Models now store AD backend types
+
+`DynamicPPL.Model` has a new field, `adtype::Union{Nothing,ADTypes.AbstractADType}`, which stores the AD backend to be used when calculating gradients of the log density.
+
+The field can be set by passing an extra argument to the `Model` constructor, but in practice users are more likely to set it on an existing model with `setadtype`:
+
+```julia
+@model f() = ...
+model = f()
+model_with_adtype = setadtype(model, AutoForwardDiff())
+```
+
+As far as `DynamicPPL.Model` itself is concerned, this field has no effect.
+However, when a `LogDensityFunction` is constructed from the model, it inherits the model's `adtype`.
+See below for more information on `LogDensityFunction`.
+
 ### `LogDensityProblems` interface
 
 LogDensityProblemsAD is now removed as a dependency.
@@ -136,7 +152,8 @@ Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now direct
 Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).
 
 However, in this version, `LogDensityFunction` now takes an extra AD type argument.
-If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
+By default, this AD type is inherited from the model that the `LogDensityFunction` is constructed from.
+If the model does not have an AD type, or if the argument is explicitly set to `nothing`, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
 However, if you do pass an AD type, that will allow you to calculate the gradient as well.
 You may thus find that it is easier to instead do this:

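The inheritance rule described in the HISTORY.md entry above can be illustrated with a minimal, dependency-free sketch. It does not use DynamicPPL itself: `ToyADType`, `ToyForwardDiff`, `ToyModel`, `ToyLDF`, `toy_setadtype`, and `has_gradient` are hypothetical stand-ins invented for this illustration, and the only thing the sketch tries to mirror is a keyword default that reads the model's `adtype` field.

```julia
# Hypothetical stand-ins for ADTypes.AbstractADType and AutoForwardDiff.
abstract type ToyADType end
struct ToyForwardDiff <: ToyADType end

# Hypothetical stand-in for DynamicPPL.Model, reduced to the relevant field.
struct ToyModel{TAD<:Union{Nothing,ToyADType}}
    name::Symbol
    adtype::TAD
end
ToyModel(name::Symbol) = ToyModel(name, nothing)

# Mirrors the idea of `setadtype`: build a new model instead of mutating the old one.
toy_setadtype(m::ToyModel, adtype::Union{Nothing,ToyADType}) = ToyModel(m.name, adtype)

# Hypothetical stand-in for LogDensityFunction.
struct ToyLDF{TAD<:Union{Nothing,ToyADType}}
    model::ToyModel
    adtype::TAD
end

# The keyword default is taken from the model, so the AD type is inherited
# unless the caller overrides it (for example with `adtype=nothing`).
function ToyLDF(model::ToyModel; adtype::Union{Nothing,ToyADType}=model.adtype)
    return ToyLDF(model, adtype)
end

# In this toy, a gradient is "available" exactly when an AD type is stored.
has_gradient(ldf::ToyLDF) = ldf.adtype !== nothing

model = toy_setadtype(ToyModel(:demo), ToyForwardDiff())
inherited = ToyLDF(model)                   # picks up ToyForwardDiff() from the model
overridden = ToyLDF(model; adtype=nothing)  # explicit `nothing` switches gradients off

println(has_gradient(inherited))   # true
println(has_gradient(overridden))  # false
```

Reading the default from the model keeps existing call sites unchanged: a model without an AD type still produces a log-density-only object.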
docs/src/api.md

+7
@@ -36,6 +36,13 @@ getargnames
 getmissings
 ```
 
+The context and AD type of a model can be changed with [`contextualize`](@ref) and [`setadtype`](@ref) respectively.
+
+```@docs
+contextualize
+setadtype
+```
+
 ## Evaluation
 
 With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).

src/DynamicPPL.jl

+1
@@ -84,6 +84,7 @@ export AbstractVarInfo,
     getargnames,
     extract_priors,
     values_as_in_model,
+    setadtype,
     # Samplers
     Sampler,
     SampleFromPrior,

src/logdensityfunction.jl

+1-1
@@ -111,7 +111,7 @@ struct LogDensityFunction{
         model::Model,
         varinfo::AbstractVarInfo=VarInfo(model),
         context::AbstractContext=leafcontext(model.context);
-        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+        adtype::Union{ADTypes.AbstractADType,Nothing}=model.adtype,
     )
         if adtype === nothing
             prep = nothing

src/model.jl

+47-14
@@ -1,17 +1,23 @@
 """
-    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
+    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,TAD<:Union{Nothing,ADTypes.AbstractADType}}
         f::F
         args::NamedTuple{argnames,Targs}
         defaults::NamedTuple{defaultnames,Tdefaults}
         context::Ctx=DefaultContext()
+        adtype::TAD=nothing
     end
 
-A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
-types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
-arguments `missings`, and evaluation context of type `Ctx`.
+A `Model` struct contains the following fields:
+- `f`, a model evaluation function of type `F`
+- `args`, arguments of names `argnames` with types `Targs`
+- `defaults`, default arguments of names `defaultnames` with types `Tdefaults`
+- `context`, an evaluation context of type `Ctx`
+- `adtype`, which is either `nothing` or an automatic differentiation backend of type `TAD`
 
+Its missing arguments are also stored as a type parameter `missings`.
 Here `argnames`, `defaultnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
-`context` is by default `DefaultContext()`.
+
+`context` is by default `DefaultContext()`, and `adtype` is by default `nothing`.
 
 An argument with a type of `Missing` will be in `missings` by default. However, in
 non-traditional use-cases `missings` can be defined differently. All variables in `missings`
@@ -33,12 +39,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
 Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
 ```
 """
-struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
-       AbstractProbabilisticProgram
+struct Model{
+    F,
+    argnames,
+    defaultnames,
+    missings,
+    Targs,
+    Tdefaults,
+    Ctx<:AbstractContext,
+    TAD<:Union{Nothing,ADTypes.AbstractADType},
+} <: AbstractProbabilisticProgram
     f::F
     args::NamedTuple{argnames,Targs}
     defaults::NamedTuple{defaultnames,Tdefaults}
     context::Ctx
+    adtype::TAD
 
     @doc """
         Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
@@ -51,9 +66,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
         args::NamedTuple{argnames,Targs},
         defaults::NamedTuple{defaultnames,Tdefaults},
         context::Ctx=DefaultContext(),
-    ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
-        return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
-            f, args, defaults, context
+        adtype::TAD=nothing,
+    ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,TAD}
+        return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD}(
+            f, args, defaults, context, adtype
         )
     end
 end
@@ -71,22 +87,39 @@ model with different arguments.
     args::NamedTuple{argnames,Targs},
     defaults::NamedTuple{kwargnames,Tkwargs},
     context::AbstractContext=DefaultContext(),
-) where {F,argnames,Targs,kwargnames,Tkwargs}
+    adtype::TAD=nothing,
+) where {F,argnames,Targs,kwargnames,Tkwargs,TAD}
     missing_args = Tuple(
         name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
     )
     missing_kwargs = Tuple(
         name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
     )
-    return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
+    return :(Model{$(missing_args..., missing_kwargs...)}(
+        f, args, defaults, context, adtype
+    ))
 end
 
 function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
-    return Model(f, args, NamedTuple(kwargs), context)
+    return Model(f, args, NamedTuple(kwargs), context, nothing)
 end
 
+"""
+    contextualize(model::Model, context::AbstractContext)
+
+Set the context of `model` to `context`.
+"""
 function contextualize(model::Model, context::AbstractContext)
-    return Model(model.f, model.args, model.defaults, context)
+    return Model(model.f, model.args, model.defaults, context, model.adtype)
+end
+
+"""
+    setadtype(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})
+
+Set the automatic differentiation backend of `model` to `adtype`.
+"""
+function setadtype(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})
+    return Model(model.f, model.args, model.defaults, model.context, adtype)
 end
 
 """

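One detail of the src/model.jl change is that `contextualize` now forwards `model.adtype`; without that, changing the context would silently reset the AD type to the constructor default. The sketch below reproduces the pattern on a small immutable struct. All names in it (`MiniModel`, `ToyContext`, `ToyAD`, `mini_contextualize`, `mini_setadtype`, and so on) are hypothetical and only loosely modelled on the diff.

```julia
# Hypothetical stand-ins for AbstractContext and ADTypes.AbstractADType.
abstract type ToyContext end
struct ToyDefaultContext <: ToyContext end
struct ToyPriorContext <: ToyContext end

abstract type ToyAD end
struct ToyForwardAD <: ToyAD end

# Immutable struct whose type parameters track both the context and the AD type.
struct MiniModel{Ctx<:ToyContext,TAD<:Union{Nothing,ToyAD}}
    f::Function
    context::Ctx
    adtype::TAD
end
MiniModel(f) = MiniModel(f, ToyDefaultContext(), nothing)

# Each "setter" rebuilds the struct and copies every field it does not change.
mini_contextualize(m::MiniModel, ctx::ToyContext) = MiniModel(m.f, ctx, m.adtype)
mini_setadtype(m::MiniModel, ad::Union{Nothing,ToyAD}) = MiniModel(m.f, m.context, ad)

m0 = MiniModel(identity)
m1 = mini_setadtype(m0, ToyForwardAD())
m2 = mini_contextualize(m1, ToyPriorContext())

println(typeof(m0) == typeof(m1))  # false: the TAD type parameter changed
println(m0.adtype === nothing)     # true: the original model is untouched
println(m2.adtype === m1.adtype)   # true: the AD type survives the context change
```

Because the AD type is a type parameter, setting it yields a value of a different concrete type, so code that dispatches on the model type can specialise on the backend.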
test/logdensityfunction.jl

+14
@@ -9,6 +9,20 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff
     end
 end
 
+@testset "AD type forwarding from model" begin
+    @model demo_simple() = x ~ Normal()
+    adtype = AutoForwardDiff()
+    model = setadtype(demo_simple(), adtype)
+    ldf = DynamicPPL.LogDensityFunction(model)
+    # Check that the model's AD type is forwarded to the LDF
+    @test ldf.adtype == adtype
+    # Check that the gradient can be evaluated on the resulting LDF
+    @test LogDensityProblems.capabilities(typeof(ldf)) ==
+        LogDensityProblems.LogDensityOrder{1}()
+    @test LogDensityProblems.logdensity(ldf, [1.0]) isa Any
+    @test LogDensityProblems.logdensity_and_gradient(ldf, [1.0]) isa Any
+end
+
 @testset "LogDensityFunction" begin
     @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
         example_values = DynamicPPL.TestUtils.rand_prior_true(model)

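The `isa Any` suffix in the tests above is a smoke-test idiom: `@test` expects a Boolean, so a bare call that returns a number or a tuple is reported as an error rather than a pass. A minimal sketch using only the `Test` standard library follows; `toy_value_and_gradient` is a hypothetical stand-in for `logdensity_and_gradient`, not part of any package.

```julia
using Test

# Hypothetical stand-in: log density of a standard normal (up to a constant)
# together with its gradient, returned as a (value, gradient) tuple.
toy_value_and_gradient(x::Vector{Float64}) = (-0.5 * sum(abs2, x), -x)

@testset "smoke-test idiom" begin
    # `isa Any` is always true, so this passes as long as the call does not throw.
    @test toy_value_and_gradient([1.0]) isa Any
    # A stricter variant also pins down the shape of the result.
    @test toy_value_and_gradient([1.0]) isa Tuple{Float64,Vector{Float64}}
end
```

The stricter form may be preferable when the return type is part of the contract being tested.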
test/model.jl

+10
@@ -100,6 +100,16 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         end
     end
 
+    @testset "model adtype" begin
+        # Check that adtype can be set and unset
+        @model demo_adtype() = x ~ Normal()
+        adtype = AutoForwardDiff()
+        model = setadtype(demo_adtype(), adtype)
+        @test model.adtype == adtype
+        model = setadtype(model, nothing)
+        @test model.adtype === nothing
+    end
+
     @testset "model de/conditioning" begin
         @model function demo_condition()
             x ~ Normal()
