module AD

using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = AutoForwardDiff()

"""
    ADResult

Data structure to store the results of the AD correctness test.

Note that correctness is asserted with `@test` inside [`run_ad`](@ref), so a
failing comparison surfaces as a test failure rather than being stored here.
To inspect the comparison yourself, compare `value_actual` and `grad_actual`
against `value_expected` and `grad_expected`.
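
For instance (an illustrative sketch, not a runnable doctest; `mymodel` stands
for any DynamicPPL model):

```julia
result = run_ad(mymodel, AutoForwardDiff(); benchmark=true)
result.value_actual     # logp as computed by the tested backend
result.grad_actual      # its gradient
result.time_vs_primal   # gradient time divided by primal evaluation time
```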
"""
struct ADResult
    "The DynamicPPL model that was tested"
    model::Model
    "The values at which the model was evaluated"
    params::Vector{<:Real}
    "The AD backend that was tested"
    adtype::AbstractADType
    "The absolute tolerance for the value of logp"
    value_atol::Real
    "The absolute tolerance for the gradient of logp"
    grad_atol::Real
    "The expected value of logp"
    value_expected::Union{Nothing,Float64}
    "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Float64}}
    "The value of logp (calculated using `adtype`)"
    value_actual::Union{Nothing,Real}
    "The gradient of logp (calculated using `adtype`)"
    grad_actual::Union{Nothing,Vector{Float64}}
    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
    time_vs_primal::Union{Nothing,Float64}
end

"""
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        test=true,
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult

Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

This function is not as complicated as its signature makes it look. There are
two things that must be provided (see the minimal example below):

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.
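
For instance, a minimal invocation might look like this (an illustrative
sketch; `mymodel` stands for any DynamicPPL model):

```julia
run_ad(mymodel, AutoForwardDiff())
```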

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._ DynamicPPL contains several different types of
VarInfo objects which change the way model evaluation occurs. If you want to
use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
it will default to using a `TypedVarInfo` generated from the model.

2. _How to specify the parameters._ For maximum control over this, generate a
vector of parameters yourself and pass this as the `params` argument. If you
don't specify this, it will be taken from the contents of the VarInfo. Note
that if the VarInfo is not specified (and thus automatically generated), the
parameters in it will have been sampled from the prior of the model. If you
want to seed the parameter generation, the easiest way is to pass an `rng`
argument to the VarInfo constructor, i.e. `VarInfo(rng, model)` (see the
examples below).

3. _How to specify the results to compare against._ (Only if `test=true`.) Once
logp and its gradient have been calculated with the specified `adtype`, they
must be tested for correctness. This can be done either by specifying
`reference_adtype`, in which case logp and its gradient will also be calculated
with this reference in order to obtain the ground truth; or by using
`expected_value_and_grad`, which is a tuple of (logp, gradient) that the
calculated values must match. The latter is useful if you are testing multiple
AD backends and want to avoid recalculating the ground truth multiple times
(see the examples below). If neither of these is specified, the default
reference backend, ForwardDiff, is used to calculate the ground truth.

4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
the value and gradient can be set using `value_atol` and `grad_atol`. These
default to 1e-6.

5. _Whether to output extra logging information._ By default, this function
prints a message when it runs. To silence it, set `verbose=false`.
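
# Examples

These sketches are illustrative rather than copy-paste runnable: `mymodel`
stands for any DynamicPPL model, and the backends shown are arbitrary choices.

To seed the parameter generation, construct the VarInfo with an explicit RNG:

```julia
using Random: Xoshiro

vi = VarInfo(Xoshiro(468), mymodel)
run_ad(mymodel, AutoForwardDiff(); varinfo=vi)
```

To calculate the ground truth once and reuse it for several backends, pass it
via `expected_value_and_grad`:

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff
using LogDensityProblems: logdensity_and_gradient

# Compute (logp, gradient) once with the reference backend...
ldf = LogDensityFunction(mymodel; adtype=AutoForwardDiff())
params = VarInfo(mymodel)[:]
truth = logdensity_and_gradient(ldf, params)

# ...then test each backend against it.
for adtype in [AutoForwardDiff(), AutoReverseDiff()]
    run_ad(mymodel, adtype; params=params, expected_value_and_grad=truth)
end
```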
"""
function run_ad(
    model::Model,
    adtype::AbstractADType;
    test=true,
    benchmark=false,
    value_atol=1e-6,
    grad_atol=1e-6,
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
    verbose=true,
)::ADResult
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    # Concretise the element type of `params` (e.g. Vector{Real} -> Vector{Float64})
    params = map(identity, params)
    verbose && println("       params : $(params)")
    ldf = LogDensityFunction(model; adtype=adtype)

    value, grad = logdensity_and_gradient(ldf, params)
    if !(grad isa Vector{Float64})
        # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
        grad = collect(grad)
    end
    verbose && println("       actual : $((value, grad))")

    if test
        # Calculate ground truth to compare against
        value_true, grad_true = if expected_value_and_grad === nothing
            ldf_reference = LogDensityFunction(model; adtype=reference_adtype)
            logdensity_and_gradient(ldf_reference, params)
        else
            expected_value_and_grad
        end
        verbose && println("     expected : $((value_true, grad_true))")
        # Then compare
        @test isapprox(value, value_true; atol=value_atol)
        @test isapprox(grad, grad_true; atol=grad_atol)
    else
        value_true = nothing
        grad_true = nothing
    end

    time_vs_primal = if benchmark
        # In Chairmarks, `_` refers to the setup tuple `(ldf, params)`
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
        t = median(grad_benchmark).time / median(primal_benchmark).time
        verbose && println("grad / primal : $(t)")
        t
    else
        nothing
    end

    return ADResult(
        model,
        params,
        adtype,
        value_atol,
        grad_atol,
        value_true,
        grad_true,
        value,
        grad,
        time_vs_primal,
    )
end

end # module DynamicPPL.TestUtils.AD