Skip to content

Commit 5826564

Browse files
committed
Implement AD testing and benchmarking (hand rolled)
1 parent eed80e5 commit 5826564

File tree

6 files changed

+208
-6
lines changed

6 files changed

+208
-6
lines changed

Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
1010
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1111
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
1212
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
1314
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1415
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1516
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -22,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2223
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2324
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2425
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
26+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2527
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2628

2729
[weakdeps]
@@ -49,6 +51,7 @@ Accessors = "0.1"
4951
BangBang = "0.4.1"
5052
Bijectors = "0.13.18, 0.14, 0.15"
5153
ChainRulesCore = "1"
54+
Chairmarks = "1.3.1"
5255
Compat = "4"
5356
ConstructionBase = "1.5.4"
5457
DifferentiationInterface = "0.6.41"
@@ -67,5 +70,6 @@ Mooncake = "0.4.95"
6770
OrderedCollections = "1"
6871
Random = "1.6"
6972
Requires = "1"
73+
Statistics = "1"
7074
Test = "1.6"
7175
julia = "1.10"

docs/src/api.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,16 @@ values_as_in_model
205205
NamedDist
206206
```
207207

208-
## Testing Utilities
208+
## AD testing and benchmarking utilities
209+
210+
To test and/or benchmark the performance of an AD backend on a model, DynamicPPL provides the following utilities:
211+
212+
```@docs
213+
DynamicPPL.TestUtils.AD.run_ad
214+
DynamicPPL.TestUtils.AD.ADResult
215+
```
216+
217+
## Demo models
209218

210219
DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.
211220

src/DynamicPPL.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ include("context_implementations.jl")
175175
include("compiler.jl")
176176
include("pointwise_logdensities.jl")
177177
include("submodel_macro.jl")
178-
include("test_utils.jl")
179178
include("transforming.jl")
180179
include("logdensityfunction.jl")
181180
include("model_utils.jl")
@@ -184,6 +183,7 @@ include("values_as_in_model.jl")
184183

185184
include("debug_utils.jl")
186185
using .DebugUtils
186+
include("test_utils.jl")
187187

188188
include("experimental.jl")
189189
include("deprecated.jl")

src/test_utils.jl

+1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
1818
include("test_utils/contexts.jl")
1919
include("test_utils/varinfo.jl")
2020
include("test_utils/sampler.jl")
21+
include("test_utils/ad.jl")
2122

2223
end

src/test_utils/ad.jl

+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
module AD

using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()

"""
    ADResult

Data structure to store the results of the AD correctness test.

Note that when `test=true` is passed to [`run_ad`](@ref), correctness is
checked with `@test`, so failures are reported through the enclosing test set
rather than being stored in this struct; the struct records the inputs used
and the values and gradients obtained.
"""
struct ADResult
    "The DynamicPPL model that was tested"
    model::Model
    "The VarInfo that was used"
    varinfo::AbstractVarInfo
    "The values at which the model was evaluated"
    params::Vector{<:Real}
    "The AD backend that was tested"
    adtype::AbstractADType
    "The absolute tolerance for the value of logp"
    value_atol::Real
    "The absolute tolerance for the gradient of logp"
    grad_atol::Real
    "The expected value of logp"
    value_expected::Union{Nothing,Float64}
    "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Float64}}
    "The value of logp (calculated using `adtype`)"
    value_actual::Union{Nothing,Real}
    "The gradient of logp (calculated using `adtype`)"
    grad_actual::Union{Nothing,Vector{Float64}}
    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
    time_vs_primal::Union{Nothing,Float64}
end

"""
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        test=true,
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

This function is not as complicated as its signature makes it look. There are
two things that must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._ DynamicPPL contains several different types of
VarInfo objects which change the way model evaluation occurs. If you want to
use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
it will default to using a `TypedVarInfo` generated from the model.

2. _How to specify the parameters._ For maximum control over this, generate a
vector of parameters yourself and pass this as the `params` argument. If you
don't specify this, it will be taken from the contents of the VarInfo. Note
that if the VarInfo is not specified (and thus automatically generated) the
parameters in it will have been sampled from the prior of the model. If you
want to seed the parameter generation, the easiest way is to pass a `rng`
argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).

3. _How to specify the results to compare against._ (Only if `test=true`.) Once
logp and its gradient have been calculated with the specified `adtype`, they
must be tested for correctness. This can be done either by specifying
`reference_adtype`, in which case logp and its gradient will also be calculated
with this reference in order to obtain the ground truth; or by using
`expected_value_and_grad`, which is a tuple of (logp, gradient) that the
calculated values must match. The latter is useful if you are testing multiple
AD backends and want to avoid recalculating the ground truth multiple times.
The default reference backend is ForwardDiff. If none of these parameters are
specified, that will be used to calculate the ground truth.

4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
the value and gradient can be set using `value_atol` and `grad_atol`. These
default to 1e-6.

5. _Whether to output extra logging information._ By default, this function
prints a message when it runs. To silence it, set `verbose=false`.
"""
function run_ad(
    model::Model,
    adtype::AbstractADType;
    test=true,
    benchmark=false,
    value_atol=1e-6,
    grad_atol=1e-6,
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
    verbose=true,
)::ADResult
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    # Narrow the eltype (e.g. Vector{Real} -> Vector{Float64}) so the AD
    # backend works on a concretely-typed parameter vector.
    params = map(identity, params)
    verbose && println("       params : $(params)")
    ldf = LogDensityFunction(model; adtype=adtype)

    value, grad = logdensity_and_gradient(ldf, params)
    if !(grad isa Vector{Float64})
        # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
        grad = collect(grad)
    end
    verbose && println("       actual : $((value, grad))")

    if test
        # Calculate ground truth to compare against
        value_true, grad_true = if expected_value_and_grad === nothing
            ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
            logdensity_and_gradient(ldf_reference, params)
        else
            expected_value_and_grad
        end
        verbose && println("     expected : $((value_true, grad_true))")
        # Then compare
        @test isapprox(value, value_true; atol=value_atol)
        @test isapprox(grad, grad_true; atol=grad_atol)
    else
        value_true = nothing
        grad_true = nothing
    end

    time_vs_primal = if benchmark
        # Benchmark the primal and the gradient separately; report their ratio
        # so results are comparable across machines.
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
        t = median(grad_benchmark).time / median(primal_benchmark).time
        verbose && println("grad / primal : $(t)")
        t
    else
        nothing
    end

    return ADResult(
        model,
        varinfo,
        params,
        adtype,
        value_atol,
        grad_atol,
        value_true,
        grad_true,
        value,
        grad,
        time_vs_primal,
    )
end

end # module DynamicPPL.TestUtils.AD

test/ad.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ using DynamicPPL: LogDensityFunction
5656
ref_ldf, adtype
5757
)
5858
else
59-
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
60-
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
61-
@test grad ≈ ref_grad
62-
@test logp ≈ ref_logp
59+
DynamicPPL.TestUtils.AD.run_ad(
60+
m,
61+
adtype;
62+
varinfo=varinfo,
63+
expected_value_and_grad=(ref_logp, ref_grad),
64+
)
6365
end
6466
end
6567
end

0 commit comments

Comments
 (0)