Commit fd60cc1

Implement AD testing and benchmarking (hand rolled)
1 parent eed80e5 commit fd60cc1

File tree

5 files changed, +230 -5 lines changed


Project.toml

+4
@@ -10,6 +10,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -22,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [weakdeps]
@@ -49,6 +51,7 @@ Accessors = "0.1"
 BangBang = "0.4.1"
 Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
+Chairmarks = "1.3.1"
 Compat = "4"
 ConstructionBase = "1.5.4"
 DifferentiationInterface = "0.6.41"
@@ -67,5 +70,6 @@ Mooncake = "0.4.95"
 OrderedCollections = "1"
 Random = "1.6"
 Requires = "1"
+Statistics = "1.11.1"
 Test = "1.6"
 julia = "1.10"
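
The two new dependencies (Chairmarks and Statistics) support the benchmarking path added in src/test_utils/ad.jl below. As a rough sketch of how they are used there (`model` stands for any DynamicPPL model; ForwardDiff is only an illustrative backend):

```julia
using Chairmarks: @be
using Statistics: median
using DynamicPPL
using ADTypes: AutoForwardDiff
using LogDensityProblems: logdensity, logdensity_and_gradient
import ForwardDiff

# `model` is assumed to be any DynamicPPL model; the backend choice is illustrative.
ldf = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
params = DynamicPPL.VarInfo(model)[:]

# Benchmark the primal and the gradient separately, then report the gradient
# cost relative to a plain logdensity evaluation (this is `time_vs_primal`).
primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
time_vs_primal = median(grad_benchmark).time / median(primal_benchmark).time
```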

src/DynamicPPL.jl

+1-1
@@ -175,7 +175,6 @@ include("context_implementations.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
 include("submodel_macro.jl")
-include("test_utils.jl")
 include("transforming.jl")
 include("logdensityfunction.jl")
 include("model_utils.jl")
@@ -184,6 +183,7 @@ include("values_as_in_model.jl")

 include("debug_utils.jl")
 using .DebugUtils
+include("test_utils.jl")

 include("experimental.jl")
 include("deprecated.jl")

src/test_utils.jl

+1
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
 include("test_utils/contexts.jl")
 include("test_utils/varinfo.jl")
 include("test_utils/sampler.jl")
+include("test_utils/ad.jl")

 end

src/test_utils/ad.jl

+218
@@ -0,0 +1,218 @@
+module AD
+
+using ADTypes: AbstractADType, AutoForwardDiff
+using Chairmarks: @be
+import DifferentiationInterface as DI
+using DocStringExtensions
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
+using LogDensityProblems: logdensity, logdensity_and_gradient
+using Random: Random, Xoshiro
+using Statistics: median
+using Test: @test
+
+export ADTestResult, run_ad, make_function, make_params
+
+"""
+    REFERENCE_ADTYPE
+
+Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
+it's the default AD backend used in Turing.jl.
+"""
+const REFERENCE_ADTYPE = AutoForwardDiff()
+
+"""
+    ADTestResult
+
+Data structure to store the results of the AD correctness test.
+
+Note that `run_ad` performs the correctness check with `@test`, so a failed
+comparison is reported through the enclosing test set rather than stored in
+this struct.
+"""
+struct ADTestResult
+    "The DynamicPPL model that was tested"
+    model::Model
+    "The values at which the model was evaluated"
+    params::Vector{<:Real}
+    "The AD backend that was tested"
+    adtype::AbstractADType
+    "The absolute tolerance for the value of logp"
+    value_atol::Real
+    "The absolute tolerance for the gradient of logp"
+    grad_atol::Real
+    "If the test ran, the expected value of logp (calculated using the reference AD backend)"
+    value_expected::Union{Nothing,Float64}
+    "If the test ran, the expected gradient of logp (calculated using the reference AD backend)"
+    grad_expected::Union{Nothing,Vector{Float64}}
+    "The actual value of logp (calculated using `adtype`)"
+    value_actual::Union{Nothing,Real}
+    "The actual gradient of logp (calculated using `adtype`)"
+    grad_actual::Union{Nothing,Vector{Float64}}
+    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
+    time_vs_primal::Union{Nothing,Float64}
+end
+
+"""
+    run_ad(
+        model::Model,
+        adtype::ADTypes.AbstractADType;
+        test=true,
+        benchmark=false,
+        value_atol=1e-6,
+        grad_atol=1e-6,
+        varinfo::AbstractVarInfo=VarInfo(model),
+        params::Vector{<:Real}=varinfo[:],
+        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
+        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+        verbose=true,
+    )::ADTestResult
+
+Test the correctness and/or benchmark the AD backend `adtype` for the model
+`model`.
+
+Whether to test and benchmark is controlled by the `test` and `benchmark`
+keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
+
+Returns an [`ADTestResult`](@ref) object, which contains the results of the
+test and/or benchmark.
+
+This function is not as complicated as its signature makes it look. There are
+two things that must be provided:
+
+1. `model` - The model being tested.
+2. `adtype` - The AD backend being tested.
+
+Everything else is optional, and can be categorised into several groups:
+
+1. _How to specify the VarInfo._ DynamicPPL contains several different types of
+   VarInfo objects which change the way model evaluation occurs. If you want to
+   use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
+   it will default to using a `TypedVarInfo` generated from the model.
+
+2. _How to specify the parameters._ For maximum control over this, generate a
+   vector of parameters yourself and pass this as the `params` argument. If you
+   don't specify this, it will be taken from the contents of the VarInfo. Note
+   that if the VarInfo is not specified (and thus automatically generated), the
+   parameters in it will have been sampled from the prior of the model. If you
+   want to seed the parameter generation, the easiest way is to pass an `rng`
+   argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+
+3. _How to specify the results to compare against._ (Only if `test=true`.) Once
+   logp and its gradient have been calculated with the specified `adtype`, they
+   must be tested for correctness. This can be done either by specifying
+   `reference_adtype`, in which case logp and its gradient will also be calculated
+   with this reference backend in order to obtain the ground truth; or by using
+   `expected_value_and_grad`, which is a tuple of (logp, gradient) that the
+   calculated values must match. The latter is useful if you are testing multiple
+   AD backends and want to avoid recalculating the ground truth multiple times.
+   If neither is specified, the default reference backend, ForwardDiff, is used
+   to calculate the ground truth.
+
+4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
+   the value and gradient can be set using `value_atol` and `grad_atol`. These
+   default to 1e-6.
+
+5. _Whether to output extra logging information._ By default, this function
+   prints a message when it runs. To silence it, set `verbose=false`.
+"""
+function run_ad(
+    model::Model,
+    adtype::AbstractADType;
+    test=true,
+    benchmark=false,
+    value_atol=1e-6,
+    grad_atol=1e-6,
+    varinfo::AbstractVarInfo=VarInfo(model),
+    params::Vector{<:Real}=varinfo[:],
+    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
+    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+    verbose=true,
+)::ADTestResult
+    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
+    ldf = LogDensityFunction(model; adtype=adtype)
+    verbose && println("  params : $(params)")
+
+    value, grad = logdensity_and_gradient(ldf, params)
+    if !(grad isa Vector{Float64})
+        # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+        grad = collect(grad)
+    end
+    verbose && println("  actual : $((value, grad))")
+
+    if test
+        # Calculate ground truth to compare against
+        value_true, grad_true = if expected_value_and_grad === nothing
+            ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
+            logdensity_and_gradient(ldf_reference, params)
+        else
+            expected_value_and_grad
+        end
+        verbose && println("  expected : $((value_true, grad_true))")
+        @test isapprox(value, value_true; atol=value_atol)
+        @test isapprox(grad, grad_true; atol=grad_atol)
+    else
+        value_true = nothing
+        grad_true = nothing
+    end
+
+    time_vs_primal = if benchmark
+        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
+        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
+        t = median(grad_benchmark).time / median(primal_benchmark).time
+        verbose && println("grad / primal : $(t)")
+        t
+    else
+        nothing
+    end
+
+    return ADTestResult(
+        model,
+        params,
+        adtype,
+        value_atol,
+        grad_atol,
+        value_true,
+        grad_true,
+        value,
+        grad,
+        time_vs_primal,
+    )
+end
+
+"""
+    make_function(model)
+
+Generate the function to be differentiated. Specifically,
+`make_function(model)` returns a function which takes a single argument
+`params` and returns the logdensity of `model` evaluated at `params`.
+
+Thus, if you have an AD package that does not have integrations with either
+LogDensityProblemsAD.jl or DifferentiationInterface.jl (in which case you could
+simply use [`run_ad`](@ref)), you can test whether your AD package works with
+Turing.jl models using:
+
+```julia
+f = make_function(model)
+params = make_params(model)
+value, grad = YourADPackage.gradient(f, params)
+```
+
+and compare the results against those obtained from [`run_ad`](@ref) with an
+AD backend that is supported.
+
+See also: `make_params`.
+"""
+make_function(model::Model) = Base.Fix1(logdensity, LogDensityFunction(model))
+
+"""
+    make_params([rng, ]model)
+
+Generate a vector of parameters sampled from the prior distribution of `model`.
+This can be used as the input to the function to be differentiated. See
+`make_function` for more details.
+"""
+function make_params(rng::Random.AbstractRNG, model::Model)
+    return VarInfo(rng, model)[:]
+end
+make_params(model::Model) = make_params(Random.default_rng(), model)
+
+end # module DynamicPPL.TestUtils.AD
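
A minimal usage sketch of the new `run_ad` entry point (the model and the ReverseDiff backend here are illustrative assumptions, not part of this commit):

```julia
using DynamicPPL, Distributions
using ADTypes: AutoReverseDiff
import ForwardDiff   # needed for the default ForwardDiff reference comparison
import ReverseDiff   # the backend under test in this sketch

@model function demo(x)
    mu ~ Normal()
    x ~ Normal(mu, 1)
end
model = demo(1.0)

# Test ReverseDiff against the ForwardDiff reference and also benchmark it.
result = DynamicPPL.TestUtils.AD.run_ad(model, AutoReverseDiff(); benchmark=true)
result.time_vs_primal   # gradient time relative to one logdensity evaluation
```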

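Along the same lines, a sketch of the `make_function`/`make_params` path, with `ForwardDiff.gradient` standing in for an AD package called directly (`model` is any DynamicPPL model, e.g. the one in the previous sketch):

```julia
import ForwardDiff
using DynamicPPL.TestUtils.AD: make_function, make_params

f = make_function(model)      # params -> logdensity of `model` at `params`
params = make_params(model)   # flat parameter vector drawn from the prior
value = f(params)
grad = ForwardDiff.gradient(f, params)
```
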
test/ad.jl

+6-4
@@ -56,10 +56,12 @@ using DynamicPPL: LogDensityFunction
                 ref_ldf, adtype
             )
         else
-            ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
-            logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
-            @test grad ≈ ref_grad
-            @test logp ≈ ref_logp
+            DynamicPPL.TestUtils.AD.run_ad(
+                m,
+                adtype;
+                varinfo=varinfo,
+                expected_value_and_grad=(ref_logp, ref_grad),
+            )
         end
     end
 end
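
Schematically, the updated test computes the reference value and gradient once and reuses it for every backend under test via `expected_value_and_grad`; the names below (`ref_ldf`, `varinfo`, `m`, `test_adtypes`) follow the hunk above and are assumed to be set up by the surrounding test code:

```julia
import LogDensityProblems

# Ground truth computed once with the reference LogDensityFunction ...
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, varinfo[:])

# ... then reused for each backend, so the reference gradient is not recomputed.
for adtype in test_adtypes
    DynamicPPL.TestUtils.AD.run_ad(
        m,
        adtype;
        varinfo=varinfo,
        expected_value_and_grad=(ref_logp, ref_grad),
    )
end
```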
