Implement AD testing and benchmarking (hand rolled) #882

Open · wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions Project.toml
@@ -10,6 +10,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -22,6 +23,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
@@ -49,6 +51,7 @@ Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Chairmarks = "1.3.1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.41"
@@ -67,5 +70,6 @@ Mooncake = "0.4.95"
OrderedCollections = "1"
Random = "1.6"
Requires = "1"
Statistics = "1"
Test = "1.6"
julia = "1.10"
11 changes: 10 additions & 1 deletion docs/src/api.md
@@ -205,7 +205,16 @@ values_as_in_model
NamedDist
```

## Testing Utilities
## AD testing and benchmarking utilities

To test and/or benchmark the performance of an AD backend on a model, DynamicPPL provides the following utilities:

```@docs
DynamicPPL.TestUtils.AD.run_ad
DynamicPPL.TestUtils.AD.ADResult
```
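
For example, a quick correctness check and benchmark of ForwardDiff might look
like the following (a sketch: `gdemo` is a toy model defined here purely for
illustration):

```julia
using ADTypes: AutoForwardDiff
using Distributions
using DynamicPPL

@model function gdemo(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

result = DynamicPPL.TestUtils.AD.run_ad(gdemo(1.0), AutoForwardDiff(); benchmark=true)
result.time_vs_primal  # gradient time relative to plain model evaluation
```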

## Demo models

DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.

2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -175,7 +175,6 @@ include("context_implementations.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
@@ -184,6 +183,7 @@ include("values_as_in_model.jl")

include("debug_utils.jl")
using .DebugUtils
include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")
1 change: 1 addition & 0 deletions src/test_utils.jl
@@ -18,5 +18,6 @@ include("test_utils/models.jl")
include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")
include("test_utils/ad.jl")

end
190 changes: 190 additions & 0 deletions src/test_utils/ad.jl
@@ -0,0 +1,190 @@
module AD

using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

# This function is needed to work around the fact that different backends can
# return different AbstractArrays for the gradient. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
# context.
_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)

"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()

"""
    ADResult

Data structure to store the results of the AD correctness test and/or
benchmark. Note that when `test=true`, correctness is checked inside
[`run_ad`](@ref) itself using `@test`, so by the time an `ADResult` is
returned, the comparison against the reference has already been made.
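
For example, a benchmarked result can be inspected as follows (a sketch;
`model` and `adtype` are assumed to already be in scope):

```julia
result = run_ad(model, adtype; benchmark=true)
result.grad_actual     # gradient of logp computed with `adtype`
result.time_vs_primal  # `nothing` unless `benchmark=true` was passed
```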
"""
struct ADResult
"The DynamicPPL model that was tested"
model::Model
"The VarInfo that was used"
varinfo::AbstractVarInfo
"The values at which the model was evaluated"
params::Vector{<:Real}
"The AD backend that was tested"
adtype::AbstractADType
"The absolute tolerance for the value of logp"
value_atol::Real
"The absolute tolerance for the gradient of logp"
grad_atol::Real
"The expected value of logp"
value_expected::Union{Nothing,Float64}
"The expected gradient of logp"
grad_expected::Union{Nothing,Vector{Float64}}
"The value of logp (calculated using `adtype`)"
value_actual::Union{Nothing,Real}
"The gradient of logp (calculated using `adtype`)"
grad_actual::Union{Nothing,Vector{Float64}}
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
time_vs_primal::Union{Nothing,Float64}
end

"""
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        test=true,
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

This function is not as complicated as its signature makes it look. There are
two things that must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._ DynamicPPL contains several different types of
VarInfo objects which change the way model evaluation occurs. If you want to
use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
it will default to using a `TypedVarInfo` generated from the model.
2. _How to specify the parameters._ For maximum control over this, generate a
vector of parameters yourself and pass this as the `params` argument. If you
don't specify this, it will be taken from the contents of the VarInfo. Note
that if the VarInfo is not specified (and thus automatically generated) the
parameters in it will have been sampled from the prior of the model. If you
want to seed the parameter generation, the easiest way is to pass a `rng`
argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
3. _How to specify the results to compare against._ (Only if `test=true`.) Once
logp and its gradient have been calculated with the specified `adtype`, they
must be tested for correctness. This can be done either by specifying
`reference_adtype`, in which case logp and its gradient will also be calculated
with this reference in order to obtain the ground truth; or by using
`expected_value_and_grad`, which is a tuple of (logp, gradient) that the
calculated values must match. The latter is useful if you are testing multiple
AD backends and want to avoid recalculating the ground truth multiple times.
If neither `reference_adtype` nor `expected_value_and_grad` is specified, the
default reference backend, ForwardDiff, is used to calculate the ground truth
(see also the example at the end of this docstring).
4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
the value and gradient can be set using `value_atol` and `grad_atol`. These
default to 1e-6.
5. _Whether to output extra logging information._ By default, this function
prints a message when it runs. To silence it, set `verbose=false`.
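
# Example

A minimal sketch, assuming a model `m` has already been defined with `@model`,
and that ReverseDiff.jl is loaded so that `AutoReverseDiff` can be used:

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff

# Calculate the ground truth once with the reference backend (ForwardDiff)...
ref = run_ad(m, AutoForwardDiff())

# ...then reuse it when testing another backend, so that the ground truth
# does not have to be recomputed for every backend under test.
run_ad(
    m,
    AutoReverseDiff();
    expected_value_and_grad=(ref.value_actual, ref.grad_actual),
)
```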
"""
function run_ad(
model::Model,
adtype::AbstractADType;
test=true,
benchmark=false,
value_atol=1e-6,
grad_atol=1e-6,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
verbose=true,
)::ADResult
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
params = map(identity, params)  # concretise eltype (e.g. Vector{Real} -> Vector{Float64})
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
grad = _to_vec_f64(grad)
verbose && println(" actual : $((value, grad))")

if test
# Calculate ground truth to compare against
value_true, grad_true = if expected_value_and_grad === nothing
ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
logdensity_and_gradient(ldf_reference, params)
else
expected_value_and_grad
end
verbose && println(" expected : $((value_true, grad_true))")
grad_true = _to_vec_f64(grad_true)
# Then compare
@test isapprox(value, value_true; atol=value_atol)
@test isapprox(grad, grad_true; atol=grad_atol)
else
value_true = nothing
grad_true = nothing
end

time_vs_primal = if benchmark
primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
t = median(grad_benchmark).time / median(primal_benchmark).time
verbose && println("grad / primal : $(t)")
t
else
nothing
end

return ADResult(
model,
varinfo,
params,
adtype,
value_atol,
grad_atol,
value_true,
grad_true,
value,
grad,
time_vs_primal,
)
end

end # module DynamicPPL.TestUtils.AD
10 changes: 6 additions & 4 deletions test/ad.jl
@@ -56,10 +56,12 @@ using DynamicPPL: LogDensityFunction
ref_ldf, adtype
)
else
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
@test grad ≈ ref_grad
@test logp ≈ ref_logp
DynamicPPL.TestUtils.AD.run_ad(
m,
adtype;
varinfo=varinfo,
expected_value_and_grad=(ref_logp, ref_grad),
)
end
end
end