Skip to content

Commit 6718636

Browse files
committed
Implement AD testing and benchmarking (hand rolled)
1 parent eed80e5 commit 6718636

File tree

6 files changed

+205
-6
lines changed

6 files changed

+205
-6
lines changed

Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
1010
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1111
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
1212
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
1314
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1415
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1516
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -22,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2223
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2324
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2425
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
26+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2527
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2628

2729
[weakdeps]
@@ -49,6 +51,7 @@ Accessors = "0.1"
4951
BangBang = "0.4.1"
5052
Bijectors = "0.13.18, 0.14, 0.15"
5153
ChainRulesCore = "1"
54+
Chairmarks = "1.3.1"
5255
Compat = "4"
5356
ConstructionBase = "1.5.4"
5457
DifferentiationInterface = "0.6.41"
@@ -67,5 +70,6 @@ Mooncake = "0.4.95"
6770
OrderedCollections = "1"
6871
Random = "1.6"
6972
Requires = "1"
73+
Statistics = "1"
7074
Test = "1.6"
7175
julia = "1.10"

docs/src/api.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,16 @@ values_as_in_model
205205
NamedDist
206206
```
207207

208-
## Testing Utilities
208+
## AD testing and benchmarking utilities
209+
210+
To test and/or benchmark the performance of an AD backend on a model, DynamicPPL provides the following utilities:
211+
212+
```@docs
213+
DynamicPPL.TestUtils.AD.run_ad
214+
DynamicPPL.TestUtils.AD.ADResult
215+
```
216+
217+
## Demo models
209218

210219
DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.
211220

src/DynamicPPL.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ include("context_implementations.jl")
175175
include("compiler.jl")
176176
include("pointwise_logdensities.jl")
177177
include("submodel_macro.jl")
178-
include("test_utils.jl")
179178
include("transforming.jl")
180179
include("logdensityfunction.jl")
181180
include("model_utils.jl")
@@ -184,6 +183,7 @@ include("values_as_in_model.jl")
184183

185184
include("debug_utils.jl")
186185
using .DebugUtils
186+
include("test_utils.jl")
187187

188188
include("experimental.jl")
189189
include("deprecated.jl")

src/test_utils.jl

+1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
1818
include("test_utils/contexts.jl")
1919
include("test_utils/varinfo.jl")
2020
include("test_utils/sampler.jl")
21+
include("test_utils/ad.jl")
2122

2223
end

src/test_utils/ad.jl

+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
module AD

using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison when no explicit ground truth is
supplied. In this case, ForwardDiff.jl, since it's the default AD backend used
in Turing.jl.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()
22+
23+
"""
    ADResult

Data structure to store the results of the AD correctness test and/or
benchmark performed by [`run_ad`](@ref).

NOTE(review): this struct has no `error` field, so success/failure cannot be
read off the result object; a correctness failure surfaces as a `@test`
failure inside `run_ad` itself.
"""
struct ADResult
    "The DynamicPPL model that was tested"
    model::Model
    "The values at which the model was evaluated"
    params::Vector{<:Real}
    "The AD backend that was tested"
    adtype::AbstractADType
    "The absolute tolerance for the value of logp"
    value_atol::Real
    "The absolute tolerance for the gradient of logp"
    grad_atol::Real
    "The expected value of logp"
    value_expected::Union{Nothing,Float64}
    "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Float64}}
    "The value of logp (calculated using `adtype`)"
    value_actual::Union{Nothing,Real}
    "The gradient of logp (calculated using `adtype`)"
    grad_actual::Union{Nothing,Vector{Float64}}
    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
    time_vs_primal::Union{Nothing,Float64}
end
53+
54+
"""
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        test=true,
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

This function is not as complicated as its signature makes it look. There are
two things that must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._ DynamicPPL contains several different types of
VarInfo objects which change the way model evaluation occurs. If you want to
use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
it will default to using a `TypedVarInfo` generated from the model.

2. _How to specify the parameters._ For maximum control over this, generate a
vector of parameters yourself and pass this as the `params` argument. If you
don't specify this, it will be taken from the contents of the VarInfo. Note
that if the VarInfo is not specified (and thus automatically generated) the
parameters in it will have been sampled from the prior of the model. If you
want to seed the parameter generation, the easiest way is to pass a `rng`
argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).

3. _How to specify the results to compare against._ (Only if `test=true`.) Once
logp and its gradient has been calculated with the specified `adtype`, it must
be tested for correctness. This can be done either by specifying
`reference_adtype`, in which case logp and its gradient will also be calculated
with this reference in order to obtain the ground truth; or by using
`expected_value_and_grad`, which is a tuple of (logp, gradient) that the
calculated values must match. The latter is useful if you are testing multiple
AD backends and want to avoid recalculating the ground truth multiple times.
The default reference backend is ForwardDiff. If none of these parameters are
specified, that will be used to calculate the ground truth.

4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
the value and gradient can be set using `value_atol` and `grad_atol`. These
default to 1e-6.

5. _Whether to output extra logging information._ By default, this function
prints a message when it runs. To silence it, set `verbose=false`.
"""
function run_ad(
    model::Model,
    adtype::AbstractADType;
    test=true,
    benchmark=false,
    value_atol=1e-6,
    grad_atol=1e-6,
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
    verbose=true,
)::ADResult
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    # `map(identity, ...)` narrows the element type of `params` to the
    # tightest concrete type (e.g. Vector{Real} -> Vector{Float64}), which
    # keeps the downstream AD calls type-stable.
    params = map(identity, params)
    verbose && println(" params : $(params)")
    ldf = LogDensityFunction(model; adtype=adtype)

    value, grad = logdensity_and_gradient(ldf, params)
    if !(grad isa Vector{Float64})
        # Some backends return a non-Vector gradient container; normalise so
        # the comparison and the ADResult field type both work.
        # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
        grad = collect(grad)
    end
    verbose && println(" actual : $((value, grad))")

    if test
        # Calculate ground truth to compare against: either the caller-supplied
        # (logp, gradient) tuple, or a fresh evaluation with the reference backend.
        value_true, grad_true = if expected_value_and_grad === nothing
            ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
            logdensity_and_gradient(ldf_reference, params)
        else
            expected_value_and_grad
        end
        verbose && println(" expected : $((value_true, grad_true))")
        # Then compare
        @test isapprox(value, value_true; atol=value_atol)
        @test isapprox(grad, grad_true; atol=grad_atol)
    else
        value_true = nothing
        grad_true = nothing
    end

    time_vs_primal = if benchmark
        # Benchmark gradient relative to the primal evaluation; the ratio is
        # robust to machine speed, unlike absolute timings.
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
        t = median(grad_benchmark).time / median(primal_benchmark).time
        verbose && println("grad / primal : $(t)")
        t
    else
        nothing
    end

    return ADResult(
        model,
        params,
        adtype,
        value_atol,
        grad_atol,
        value_true,
        grad_true,
        value,
        grad,
        time_vs_primal,
    )
end
182+
183+
end # module DynamicPPL.TestUtils.AD

test/ad.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ using DynamicPPL: LogDensityFunction
5656
ref_ldf, adtype
5757
)
5858
else
59-
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
60-
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
61-
@test grad ref_grad
62-
@test logp ref_logp
59+
DynamicPPL.TestUtils.AD.run_ad(
60+
m,
61+
adtype;
62+
varinfo=varinfo,
63+
expected_value_and_grad=(ref_logp, ref_grad),
64+
)
6365
end
6466
end
6567
end

0 commit comments

Comments
 (0)