Skip to content

Use NoCache to improve set_to_zero!! performance with Mooncake #975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.36.15

Improved performance for some models with Mooncake.jl by using `NoCache` with `Mooncake.set_to_zero!!` for DynamicPPL types.

## 0.36.14

Added compatibility with AbstractPPL@0.12.
Expand Down
101 changes: 101 additions & 0 deletions ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,110 @@
module DynamicPPLMooncakeExt

__precompile__(false)

using DynamicPPL: DynamicPPL, istrans
using Mooncake: Mooncake
import Mooncake: set_to_zero!!
using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!!

# This is purely an optimisation.
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}

"""
Check if a tangent has the expected structure for a given type.
"""
function has_expected_structure(
    x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields
)
    # Anything that is not the requested tangent wrapper, or that lacks a `fields`
    # container, cannot match.
    x isa expected_type || return false
    hasfield(typeof(x), :fields) || return false

    present = propertynames(x.fields)

    # A `Tuple` of names demands an exact match (same names, same order).
    expected_fields isa Tuple && return present == expected_fields

    # Any other iterable only requires each listed name to be present.
    return issubset(expected_fields, present)
end

"""
Check if a tangent corresponds to a DynamicPPL.LogDensityFunction
"""
function is_dppl_ldf_tangent(x)
    # The outer tangent must expose exactly these five fields, in this order.
    ldf_fields = (:model, :varinfo, :context, :adtype, :prep)
    has_expected_structure(x, Tangent, ldf_fields) || return false

    # The nested `varinfo` and `model` tangents must themselves have the expected shape.
    inner = x.fields
    return is_dppl_varinfo_tangent(inner.varinfo) && is_dppl_model_tangent(inner.model)
end

"""
Check if a tangent corresponds to a DynamicPPL.VarInfo
"""
function is_dppl_varinfo_tangent(x)
    # Passing a tuple requests an exact, ordered match on the field names.
    varinfo_fields = (:metadata, :logp, :num_produce)
    return has_expected_structure(x, Tangent, varinfo_fields)
end

"""
Check if a tangent corresponds to a DynamicPPL.Model
"""
function is_dppl_model_tangent(x)
    # Passing a tuple requests an exact, ordered match on the field names.
    model_fields = (:f, :args, :defaults, :context)
    return has_expected_structure(x, Tangent, model_fields)
end

"""
Check if a MutableTangent corresponds to DynamicPPL.Metadata
"""
function is_dppl_metadata_tangent(x)
    # Unlike the other checks this one looks for a `MutableTangent`; the tuple requests an
    # exact, ordered match on the field names.
    metadata_fields = (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
    return has_expected_structure(x, MutableTangent, metadata_fields)
end

"""
Check if a model function tangent represents a closure.
"""
function is_closure_model(model_f_tangent)
    # A mutable tangent for the model function is always treated as a closure.
    model_f_tangent isa MutableTangent && return true

    # Otherwise only an immutable `Tangent` with a `fields` container can qualify.
    model_f_tangent isa Tangent || return false
    hasfield(typeof(model_f_tangent), :fields) || return false

    # It qualifies when at least one captured field has the untyped-`contents` shape.
    return any(_has_untyped_contents, values(model_f_tangent.fields))
end

# Return `true` if `field_tangent` is a `MutableTangent` whose `fields` container has a
# `contents` entry of type `Mooncake.PossiblyUninitTangent{Any}`.
# NOTE(review): presumably this is what the tangent of a boxed captured variable looks
# like — confirm against Mooncake.
function _has_untyped_contents(field_tangent)
    return field_tangent isa MutableTangent &&
           hasfield(typeof(field_tangent), :fields) &&
           hasfield(typeof(field_tangent.fields), :contents) &&
           field_tangent.fields.contents isa Mooncake.PossiblyUninitTangent{Any}
end

# Zero the tangent `x`, skipping the `IdDict` visited-cache for tangents that look like
# DynamicPPL types, and falling back to an `IdDict{Any,Bool}` cache for everything else.
#
# NOTE(review): this defines a method on an untyped argument of a function owned by
# Mooncake, so it presumably replaces Mooncake's own generic method for every tangent
# (type piracy) — confirm this is acceptable, or move the dispatch into Mooncake.
function Mooncake.set_to_zero!!(x)
    if is_dppl_ldf_tangent(x)
        # The model function may be a closure, which needs the visited-cache.
        return set_to_zero_internal!!(_cache_for_model_f(x.fields.model.fields.f), x)
    elseif is_dppl_model_tangent(x)
        # A bare Model tangent carries the same `f` field, so it gets the same closure
        # check as the LogDensityFunction branch instead of an unconditional `NoCache`.
        return set_to_zero_internal!!(_cache_for_model_f(x.fields.f), x)
    elseif is_dppl_varinfo_tangent(x) || is_dppl_metadata_tangent(x)
        return set_to_zero_internal!!(NoCache(), x)
    else
        return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
    end
end

# Pick the cache for a model whose function tangent is `model_f_tangent`: an `IdDict` when
# the function looks like a closure, `NoCache` otherwise.
function _cache_for_model_f(model_f_tangent)
    return is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache()
end

end # module
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -45,6 +46,7 @@ JET = "0.9, 0.10"
LogDensityProblems = "2"
MCMCChains = "6.0.4, 7"
MacroTools = "0.5.6"
Mooncake = "0.4.137"
OrderedCollections = "1"
ReverseDiff = "1"
StableRNGs = "1"
Expand Down
193 changes: 190 additions & 3 deletions test/ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,192 @@
using DynamicPPL
using Distributions
using Random
using Test
using StableRNGs
using Mooncake: Mooncake, NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent
using DynamicPPL.TestUtils.AD: @be, median

# Define models globally to avoid closure issues
# Test model with one argument `x`: draws `s` from InverseGamma(2, 3) and `m` from
# Normal(0, sqrt(s)), then relates every element of `x` to Normal(m, sqrt(s)) with a
# broadcasted tilde statement.
@model function test_model1(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
return x .~ Normal(m, sqrt(s))
end

# Test model with two arguments `x` and `y` that share the same likelihood: draws `τ`,
# `σ` and `μ`, then relates every element of both `x` and `y` to Normal(μ, σ) with
# broadcasted tilde statements.
@model function test_model2(x, y)
τ ~ Gamma(1, 1)
σ ~ InverseGamma(2, 3)
μ ~ Normal(0, τ)
x .~ Normal(μ, σ)
return y .~ Normal(μ, σ)
end

@testset "DynamicPPLMooncakeExt" begin
# Run Mooncake's rule-testing utility on `istrans` applied to an empty `VarInfo`.
# The identical call used to appear a second time outside this testset; it is kept only
# here so it runs once and is reported under its own name.
@testset "istrans rule" begin
    Mooncake.TestUtils.test_rule(
        StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true
    )
end

@testset "set_to_zero!! optimization" begin
    # Build the tangent of a LogDensityFunction wrapping a real DynamicPPL model.
    demo = test_model1([1.0, 2.0, 3.0])
    ldf = LogDensityFunction(demo, VarInfo(Random.default_rng(), demo), DefaultContext())
    ldf_tangent = zero_tangent(ldf)

    # Zeroing a copy must hand back a tangent of the same type.
    zeroed = set_to_zero!!(deepcopy(ldf_tangent))
    @test zeroed isa typeof(ldf_tangent)

    # When the varinfo tangent exposes `metadata`, it must not be `nothing`.
    varinfo_fields = ldf_tangent.fields.varinfo.fields
    if hasfield(typeof(varinfo_fields), :metadata)
        @test !isnothing(varinfo_fields.metadata)
    end
end

@testset "NoCache optimization correctness" begin
    # Build the tangent of a LogDensityFunction wrapping a real DynamicPPL model.
    model = test_model1([1.0, 2.0, 3.0])
    vi = VarInfo(Random.default_rng(), model)
    ldf = LogDensityFunction(model, vi, DefaultContext())
    tangent = zero_tangent(ldf)

    # Assert the structure the test relies on instead of silently skipping when it is
    # missing; otherwise the test could pass without checking anything.
    model_fields = tangent.fields.model.fields
    @test hasfield(typeof(model_fields), :args)
    @test hasfield(typeof(model_fields.args), :x)
    x_tangent = model_fields.args.x
    @test !isempty(x_tangent)

    # Dirty one entry, zero the whole tangent, and check the entry was reset.
    x_tangent[1] = 5.0
    set_to_zero!!(tangent)
    @test tangent.fields.model.fields.args.x[1] == 0.0
end

@testset "Performance improvement" begin
    # Prefer the first packaged demo model; fall back to the local test model.
    model = if isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) &&
        !isempty(DynamicPPL.TestUtils.DEMO_MODELS)
        DynamicPPL.TestUtils.DEMO_MODELS[1]
    else
        test_model1([1.0, 2.0, 3.0, 4.0])
    end

    vi = VarInfo(Random.default_rng(), model)
    ldf = LogDensityFunction(model, vi, DefaultContext())
    tangent = zero_tangent(ldf)

    # Baseline: zero with an explicit IdDict cache.
    result_iddict = @be begin
        cache = IdDict{Any,Bool}()
        set_to_zero_internal!!(cache, tangent)
    end

    # Candidate: the public entry point.
    result_nocache = @be set_to_zero!!(tangent)

    time_iddict = median(result_iddict).time
    time_nocache = median(result_nocache).time

    # NOTE(review): a wall-clock threshold can be flaky on loaded CI machines.
    speedup = time_iddict / time_nocache
    @test speedup > 1.5 # Conservative expectation - should be ~4x

    # Sample times are reported in seconds (assuming `@be` is Chairmarks' macro), so
    # scale by 1e6 to log microseconds.
    time_iddict_μs = time_iddict * 1e6
    time_nocache_μs = time_nocache * 1e6
    @info "Performance improvement" speedup time_iddict_μs time_nocache_μs
end

@testset "Aliasing safety" begin
    # Pass the same array as both `x` and `y`.
    shared_data = [1.0, 2.0, 3.0]
    model = test_model2(shared_data, shared_data)
    vi = VarInfo(Random.default_rng(), model)
    ldf = LogDensityFunction(model, vi, DefaultContext())
    tangent = zero_tangent(ldf)

    # Assert the structure the test relies on instead of silently skipping when it is
    # missing; otherwise the test could pass without checking anything.
    model_fields = tangent.fields.model.fields
    @test hasfield(typeof(model_fields), :args)
    args = model_fields.args
    @test hasfield(typeof(args), :x) && hasfield(typeof(args), :y)
    @test args.x === args.y # Aliasing should be preserved
    @test !isempty(args.x)

    # A write through `x` must be visible through `y`.
    args.x[1] = 10.0
    @test args.y[1] == 10.0

    # Zeroing the tangent must reset the shared storage seen through both names.
    set_to_zero!!(tangent)
    @test args.x[1] == 0.0
    @test args.y[1] == 0.0
end

@testset "Closure handling" begin
    # Define a model inside a function so that it is a local (closure) definition.
    function create_closure_model()
        local_var = 42
        @model function closure_model(x)
            s ~ InverseGamma(2, 3)
            m ~ Normal(0, sqrt(s))
            return x .~ Normal(m, sqrt(s))
        end
        return closure_model
    end

    closure_fn = create_closure_model()
    model_closure = closure_fn([1.0, 2.0, 3.0])
    vi_closure = VarInfo(Random.default_rng(), model_closure)
    ldf_closure = LogDensityFunction(model_closure, vi_closure, DefaultContext())
    tangent_closure = zero_tangent(ldf_closure)

    # Test that it works without stack overflow
    @test_nowarn set_to_zero!!(deepcopy(tangent_closure))

    # Same model defined at global scope, for comparison.
    model_global = test_model1([1.0, 2.0, 3.0])
    vi_global = VarInfo(Random.default_rng(), model_global)
    ldf_global = LogDensityFunction(model_global, vi_global, DefaultContext())
    tangent_global = zero_tangent(ldf_global)

    # The tangent of `model.f` must differ between the two definitions.
    f_tangent_closure = tangent_closure.fields.model.fields.f
    f_tangent_global = tangent_global.fields.model.fields.f

    @test f_tangent_global isa Mooncake.NoTangent # Global function
    @test f_tangent_closure isa Mooncake.Tangent # Closure function

    # Warm up both call paths so first-call compilation is not counted in either timing.
    set_to_zero!!(tangent_global)
    set_to_zero!!(tangent_closure)

    time_global = @elapsed for _ in 1:100
        set_to_zero!!(tangent_global)
    end

    time_closure = @elapsed for _ in 1:100
        set_to_zero!!(tangent_closure)
    end

    # Global should be faster (uses NoCache)
    # NOTE(review): a wall-clock comparison can still be flaky on loaded CI machines.
    @test time_global < time_closure
end
end
Loading