
Commit 90c7b26

penelopeysm and gdalle authored
Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction (#806)
* Remove LogDensityProblemsAD
* Implement LogDensityFunctionWithGrad in place of ADgradient
* Dynamically decide whether to use closure vs constant
* Combine LogDensityFunction{,WithGrad} into one (#811)
* Warn if unsupported AD type is used
* Update changelog
* Update DI compat bound
* Don't store with_closure inside LogDensityFunction
* setadtype --> LogDensityFunction
* Re-add ForwardDiffExt (including tests)
* Add more tests for coverage

Co-authored-by: Guillaume Dalle <[email protected]>
1 parent f5e84f4 commit 90c7b26

13 files changed: +420 -187 lines

HISTORY.md (+47 -1)

@@ -2,7 +2,7 @@
 
 ## 0.35.0
 
-**Breaking**
+**Breaking changes**
 
 ### `.~` right hand side must be a univariate distribution
 
@@ -119,6 +119,52 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 This change also affects sampling in Turing.jl.
 
+### `LogDensityFunction` argument order
+
+- The method `LogDensityFunction(varinfo, model, context)` has been removed.
+  The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`.
+  (For an explanation of `adtype`, see below.)
+  The varinfo and context arguments are both still optional.
+
+**Other changes**
+
+### `LogDensityProblems` interface
+
+LogDensityProblemsAD is now removed as a dependency.
+Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters.
+
+Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).
+
+However, in this version, `LogDensityFunction` now takes an extra AD type argument.
+If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
+However, if you do pass an AD type, that will allow you to calculate the gradient as well.
+You may thus find that it is easier to instead do this:
+
+```julia
+@model f() = ...
+
+ldf = LogDensityFunction(f(); adtype=AutoForwardDiff())
+```
+
+This will return an object which satisfies the `LogDensityProblems` interface to first order, i.e. you can now directly call both
+
+```
+LogDensityProblems.logdensity(ldf, params)
+LogDensityProblems.logdensity_and_gradient(ldf, params)
+```
+
+without having to construct a separate `ADgradient` object.
+
+If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards.
+The model, varinfo, and context will be taken from the original `LogDensityFunction`:
+
+```julia
+@model f() = ...
+
+ldf = LogDensityFunction(f()) # by default, no adtype set
+ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff())
+```
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
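The changelog snippets above can be combined into a fuller end-to-end sketch. The `demo` model and parameter value below are hypothetical, chosen only for illustration; this assumes DynamicPPL 0.35 with Distributions, ADTypes, ForwardDiff, and LogDensityProblems installed:

```julia
using DynamicPPL, Distributions, ForwardDiff, LogDensityProblems
using ADTypes: AutoForwardDiff

# A hypothetical single-parameter model for illustration.
@model function demo()
    x ~ Normal(0, 1)
end

# Passing an adtype makes the returned object support gradient
# computation via DifferentiationInterface.
ldf = LogDensityFunction(demo(); adtype=AutoForwardDiff())

params = [0.5]
lp = LogDensityProblems.logdensity(ldf, params)
lp2, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
# For x ~ Normal(0, 1), the derivative of the log density at x = 0.5 is -0.5.
```

No separate `ADgradient` construction is needed; the same `ldf` object serves both calls.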

Project.toml (+3 -4)

@@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -51,15 +51,14 @@ Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
+DifferentiationInterface = "0.6.41"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
-KernelAbstractions = "0.9.33"
 EnzymeCore = "0.6 - 0.8"
-ForwardDiff = "0.10"
 JET = "0.9"
+KernelAbstractions = "0.9.33"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "6"
 MacroTools = "0.5.6"
 Mooncake = "0.4.95"

docs/src/api.md (+1 -1)

@@ -54,7 +54,7 @@ logjoint
 
 ### LogDensityProblems.jl interface
 
-The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
+The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.
 
 ```@docs
 DynamicPPL.LogDensityFunction

ext/DynamicPPLForwardDiffExt.jl (+29 -43)

@@ -1,54 +1,40 @@
 module DynamicPPLForwardDiffExt
 
-if isdefined(Base, :get_extension)
-    using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ForwardDiff
-else
-    using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ..ForwardDiff
-end
-
-getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk
-
-standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
-standardtag(::ADTypes.AutoForwardDiff) = false
-
-function LogDensityProblemsAD.ADgradient(
-    ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
-)
-    θ = DynamicPPL.getparams(ℓ)
-    f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
-
-    # Define configuration for ForwardDiff.
-    tag = if standardtag(ad)
-        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
+using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
+using ForwardDiff
+
+# check if the AD type already has a tag
+use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
+use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false
+
+function DynamicPPL.tweak_adtype(
+    ad::ADTypes.AutoForwardDiff{chunk_size},
+    ::DynamicPPL.Model,
+    vi::DynamicPPL.AbstractVarInfo,
+    ::DynamicPPL.AbstractContext,
+) where {chunk_size}
+    params = vi[:]
+
+    # Use the DynamicPPL tag to improve stack traces
+    # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
+    # NOTE: DifferentiationInterface disables tag checking if the
+    # tag inside the AutoForwardDiff type is not nothing. See
+    # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
+    # So we don't currently need to override ForwardDiff.checktag as well.
+    tag = if use_dynamicppl_tag(ad)
+        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
     else
-        ForwardDiff.Tag(f, eltype(θ))
+        ad.tag
     end
-    chunk_size = getchunksize(ad)
+
+    # Optimise chunk size according to size of model
     chunk = if chunk_size == 0 || chunk_size === nothing
-        ForwardDiff.Chunk(θ)
+        ForwardDiff.Chunk(params)
     else
-        ForwardDiff.Chunk(length(θ), chunk_size)
+        ForwardDiff.Chunk(length(params), chunk_size)
     end
 
-    return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
-end
-
-# Allow Turing tag in gradient etc. calls of the log density function
-function ForwardDiff.checktag(
-    ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
-    ::DynamicPPL.LogDensityFunction,
-    ::AbstractArray{W},
-) where {V,W}
-    return true
-end
-function ForwardDiff.checktag(
-    ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
-    ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
-    ::AbstractArray{W},
-) where {V,W}
-    return true
+    return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
 end
 
 end # module
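The chunk-size branch in `tweak_adtype` can be exercised on its own. A small sketch, where the `choose_chunk` helper is hypothetical and merely mirrors the branch in the extension above (assumes ForwardDiff is installed):

```julia
using ForwardDiff

# Hypothetical helper mirroring the chunk-size selection in tweak_adtype:
# a chunk size of 0 (or nothing) means "let ForwardDiff choose based on
# the input length"; otherwise the requested size acts as an upper bound.
function choose_chunk(params::AbstractVector, chunk_size)
    if chunk_size == 0 || chunk_size === nothing
        ForwardDiff.Chunk(params)
    else
        ForwardDiff.Chunk(length(params), chunk_size)
    end
end

params = zeros(4)
ForwardDiff.chunksize(choose_chunk(params, 0))  # ForwardDiff picks 4 here
ForwardDiff.chunksize(choose_chunk(params, 2))  # bounded by the requested 2
```

Returning a plain `AutoForwardDiff` with the resolved `chunksize` and `tag`, as the extension does, lets DifferentiationInterface handle the actual gradient preparation.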

src/DynamicPPL.jl (-1)

@@ -14,7 +14,6 @@ using MacroTools: MacroTools
 using ConstructionBase: ConstructionBase
 using Accessors: Accessors
 using LogDensityProblems: LogDensityProblems
-using LogDensityProblemsAD: LogDensityProblemsAD
 
 using LinearAlgebra: LinearAlgebra, Cholesky
 
src/contexts.jl (+1)

@@ -184,6 +184,7 @@ at which point it will return the sampler of that context.
 getsampler(context::SamplingContext) = context.sampler
 getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
 getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
+getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
 
 """
 struct DefaultContext <: AbstractContext end
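The added `IsLeaf` method completes the trait recursion: parent contexts forward the lookup to their child, and a leaf without a sampler now raises a descriptive error instead of a `MethodError`. A minimal, self-contained sketch of this dispatch pattern (the context types here are hypothetical stand-ins, not the actual DynamicPPL implementation):

```julia
abstract type AbstractContext end
struct IsLeaf end
struct IsParent end

# Hypothetical leaf context carrying no sampler.
struct DefaultCtx <: AbstractContext end
NodeTrait(::DefaultCtx) = IsLeaf()

# Hypothetical parent context that holds a sampler and a child.
struct SamplingCtx{S,C<:AbstractContext} <: AbstractContext
    sampler::S
    child::C
end
NodeTrait(::SamplingCtx) = IsParent()
childcontext(ctx::SamplingCtx) = ctx.child

# Hypothetical parent context that only wraps a child.
struct PrefixCtx{C<:AbstractContext} <: AbstractContext
    child::C
end
NodeTrait(::PrefixCtx) = IsParent()
childcontext(ctx::PrefixCtx) = ctx.child

getsampler(ctx::SamplingCtx) = ctx.sampler
getsampler(ctx::AbstractContext) = getsampler(NodeTrait(ctx), ctx)
getsampler(::IsParent, ctx::AbstractContext) = getsampler(childcontext(ctx))
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

getsampler(PrefixCtx(SamplingCtx(:mh, DefaultCtx())))  # recurses to :mh
# getsampler(PrefixCtx(DefaultCtx())) would now error with
# "No sampler found in context" rather than a MethodError.
```

Recursion terminates either at the first context that defines a sampler or at a leaf, whichever comes first.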
