Skip to content

Commit 616a07f

Browse files
sunxd3yebai
andauthored
Move ad.jl to DynamicPPL (#2158)
* `getADBackend`s are renamed to `getADType` and moved to `Inference` module as well as `LogDensityProblemsAD.ADgradient(ℓ::LogDensityFunction)` * The ` LogDensityProblemsAD.ADgradient(adtype, ℓ)` specific to RD and FD are moved to DynamicPPL * The idea is that with DynamicPPL, call to `ADgradient` must also gives the `adtype`, in Turing, we just use the `adtype` from the algorithm --------- Co-authored-by: Hong Ge <[email protected]>
1 parent cf647b1 commit 616a07f

File tree

7 files changed

+23
-71
lines changed

7 files changed

+23
-71
lines changed

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# Release 0.30.5
2+
3+
- `essential/ad.jl` is removed, `ForwardDiff` and `ReverseDiff` integrations via `LogDensityProblemsAD` are moved to `DynamicPPL` and live in corresponding package extensions.
4+
- `LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)` (i.e. the single argument method) is moved to `Inference` module. It will create `ADgradient` using the `adtype` information stored in `context` field of ``.
5+
- `getADbackend` function is renamed to `getADType`, the interface is preserved, but packages that previously used `getADbackend` should be updated to use `getADType`.
6+
- `TuringTag` for ForwardDiff is also removed, now `DynamicPPLTag` is defined in `DynamicPPL` package and should serve the same [purpose](https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/).
7+
18
# Release 0.30.0
29

310
- [`ADTypes.jl`](https://github.com/SciML/ADTypes.jl) replaced Turing's global AD backend. Users should now specify the desired `ADType` directly in sampler constructors, e.g., `HMC(0.1, 10; adtype=AutoForwardDiff(; chunksize))`, or `HMC(0.1, 10; adtype=AutoReverseDiff(false))` (`false` indicates not to use compiled tape).

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.30.4"
3+
version = "0.30.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -58,7 +58,7 @@ Distributions = "0.23.3, 0.24, 0.25"
5858
DistributionsAD = "0.6"
5959
DocStringExtensions = "0.8, 0.9"
6060
DynamicHMC = "3.4"
61-
DynamicPPL = "0.24"
61+
DynamicPPL = "0.24.7"
6262
EllipticalSliceSampling = "0.5, 1, 2"
6363
ForwardDiff = "0.10.3"
6464
Libtask = "0.7, 0.8"

src/Turing.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ function setprogress!(progress::Bool)
3232
return progress
3333
end
3434

35-
# Standard tag: Improves stacktraces
36-
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
37-
struct TuringTag end
38-
39-
# Allow Turing tag in gradient etc. calls of the log density function
40-
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true
41-
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:LogDensityFunction}, ::AbstractArray{V}) where {V} = true
42-
4335
# Random probability measures.
4436
include("stdlib/distributions.jl")
4537
include("stdlib/RandomMeasures.jl")

src/essential/Essential.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@ using StatsFuns: logsumexp, softmax
1414
using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote
1515

1616
import AdvancedPS
17-
import LogDensityProblems
18-
import LogDensityProblemsAD
1917

2018
include("container.jl")
21-
include("ad.jl")
2219

2320
export @model,
2421
@varname,

src/essential/ad.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/mcmc/Inference.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import AdvancedHMC; const AHMC = AdvancedHMC
2929
import AdvancedMH; const AMH = AdvancedMH
3030
import AdvancedPS
3131
import BangBang
32-
import ..Essential: getADbackend
3332
import EllipticalSliceSampling
3433
import LogDensityProblems
3534
import LogDensityProblemsAD
@@ -78,7 +77,6 @@ abstract type ParticleInference <: InferenceAlgorithm end
7877
abstract type Hamiltonian <: InferenceAlgorithm end
7978
abstract type StaticHamiltonian <: Hamiltonian end
8079
abstract type AdaptiveHamiltonian <: Hamiltonian end
81-
getADbackend(alg::Hamiltonian) = alg.adtype
8280

8381
"""
8482
ExternalSampler{S<:AbstractSampler}
@@ -98,6 +96,20 @@ Wrap a sampler so it can be used as an inference algorithm.
9896
"""
9997
externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler)
10098

99+
getADType(spl::Sampler) = getADType(spl.alg)
100+
getADType(::SampleFromPrior) = AutoForwardDiff(; chunksize=0)
101+
102+
getADType(ctx::DynamicPPL.SamplingContext) = getADType(ctx.sampler)
103+
getADType(ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.NodeTrait(ctx), ctx)
104+
getADType(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = AutoForwardDiff(; chunksize=0)
105+
getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.childcontext(ctx))
106+
107+
getADType(alg::Hamiltonian) = alg.adtype
108+
109+
function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
110+
return LogDensityProblemsAD.ADgradient(getADType(ℓ.context), ℓ)
111+
end
112+
101113
function LogDensityProblems.logdensity(
102114
f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
103115
x::NamedTuple

test/essential/ad.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,6 @@
165165

166166
end
167167

168-
@testset "tag" begin
169-
for chunksize in (0, 1, 10)
170-
ad = Turing.AutoForwardDiff(; chunksize=chunksize)
171-
@test ad === Turing.AutoForwardDiff(; chunksize=chunksize)
172-
@test Turing.Essential.standardtag(ad)
173-
for standardtag in (false, 0, 1)
174-
@test !Turing.Essential.standardtag(Turing.AutoForwardDiff(; chunksize=chunksize, tag=standardtag))
175-
end
176-
end
177-
end
178-
179168
@testset "ReverseDiff compiled without linking" begin
180169
f = DynamicPPL.LogDensityFunction(gdemo_default)
181170
θ = DynamicPPL.getparams(f)

0 commit comments

Comments
 (0)