Skip to content

Generating Julia function for log density evaluation #278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 52 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
1460ff8
add source_gen code
sunxd3 Feb 27, 2025
274ef06
Merge branch 'master' into sunxd/source_gen
sunxd3 Feb 27, 2025
574db5d
Apply suggestions from code review
sunxd3 Feb 27, 2025
e0bcab0
update CI, cosmic only
sunxd3 Feb 27, 2025
0ffd586
fix some corner cases: `JuliaBUGS.phi` and indices Int casting
sunxd3 Feb 28, 2025
2b6abfb
support missing data
sunxd3 Mar 1, 2025
1478d0e
update `else` handle function
sunxd3 Mar 4, 2025
99249cb
refactor source_gen to not using `model`, also integrate generated fu…
sunxd3 Mar 4, 2025
5e49fc7
Merge branch 'main' into sunxd/source_gen
sunxd3 Mar 4, 2025
98f5757
Apply suggestions from code review
sunxd3 Mar 4, 2025
9bbfb2d
Merge branch 'main' into sunxd/source_gen
sunxd3 Mar 9, 2025
09bb98d
fix some test errors
sunxd3 Mar 9, 2025
fca5470
Merge branch 'main' into sunxd/source_gen
sunxd3 Mar 14, 2025
85736df
update Project.oml
sunxd3 Mar 17, 2025
043c8a0
add preliminary doc
sunxd3 Mar 17, 2025
5be974a
improve benchmark code; fix bugs
sunxd3 Mar 18, 2025
661dcad
fix gibbs tests; use graph for evaluation by default
sunxd3 Mar 18, 2025
81269e2
fix vol3 examples
sunxd3 Mar 18, 2025
165c000
update examples
sunxd3 Mar 19, 2025
0ffb096
Update src/BUGSExamples/Volume_3/06_Hips1.jl
sunxd3 Mar 25, 2025
58787bd
Update src/BUGSExamples/Volume_3/08_Hips3.jl
sunxd3 Mar 25, 2025
333ddf0
Update src/BUGSExamples/Volume_3/08_Hips3.jl
sunxd3 Mar 25, 2025
c28f131
Update src/BUGSExamples/Volume_3/09_Hips4.jl
sunxd3 Mar 25, 2025
b1b2e9c
Update src/BUGSExamples/Volume_3/09_Hips4.jl
sunxd3 Mar 25, 2025
540d834
Update src/model.jl
sunxd3 Mar 25, 2025
382d25f
Update src/compiler_pass.jl
sunxd3 Mar 25, 2025
cd6be9c
remove Enzyme from test env
sunxd3 Mar 31, 2025
5a0b8b9
improve LogDensityProblems.logdensity
sunxd3 Mar 31, 2025
6d4f97a
add tricks
sunxd3 Mar 31, 2025
6599d4b
update to source_gen.md documentation
sunxd3 Mar 31, 2025
b04bdfa
update writing
sunxd3 Apr 1, 2025
8c280fa
Merge branch 'main' into sunxd/source_gen
sunxd3 Apr 1, 2025
852ae32
only use Mooncake for benchmark
sunxd3 Apr 1, 2025
8baa70a
Update benchmark/juliabugs.jl
sunxd3 Apr 1, 2025
8f76754
remove refer to Enzyme
sunxd3 Apr 1, 2025
979d1f0
update a util function
sunxd3 Apr 2, 2025
a12c899
add some descriptions to the code file
sunxd3 Apr 4, 2025
68ba13f
bump patch version -- the generated function is not exposed yet, but …
sunxd3 Apr 4, 2025
5cfd3f7
Update docs/src/source_gen.md
sunxd3 Apr 8, 2025
b5b9651
Update docs/src/source_gen.md
sunxd3 Apr 8, 2025
bbd2d77
make the example more clear
sunxd3 Apr 9, 2025
63d1877
Update src/BUGSExamples/Volume_3/09_Hips4.jl
sunxd3 Apr 9, 2025
2ecea06
applying code review comments
sunxd3 Apr 11, 2025
93b3936
Update src/source_gen.jl
sunxd3 Apr 11, 2025
5ea7b88
Merge branch 'main' into sunxd/source_gen
sunxd3 Apr 11, 2025
adc43c8
Update src/source_gen.jl
sunxd3 Apr 11, 2025
25e6c60
add code to check conflicted names
sunxd3 Apr 11, 2025
a1214bd
Apply suggestions from code review
sunxd3 Apr 11, 2025
1bf6162
switch to using generated function, instead of LogDensityProblems, fo…
sunxd3 Apr 11, 2025
f86d603
fix type error
sunxd3 Apr 11, 2025
e6b1b68
try fix again
sunxd3 Apr 11, 2025
2fc6a20
Update test/source_gen.jl
sunxd3 Apr 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@ jobs:
test:
name: Julia ${{ matrix.version }} on ${{ matrix.os }} (${{ matrix.arch }})
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
continue-on-error: ${{ matrix.version == 'pre' }}
strategy:
fail-fast: false
matrix:
version: ['1', '1.10', 'nightly']
os: [ubuntu-latest, windows-latest]
arch: [x64]
version:
- '1'
- 'min'
- 'pre'
os:
- ubuntu-latest
- windows-latest
arch:
- x64
include:
- version: '1'
- version: 'min'
os: ubuntu-latest
arch: x64
coverage: true
Expand Down Expand Up @@ -60,6 +65,11 @@ jobs:
env:
TEST_GROUP: "log_density"

- name: Running `source_gen` tests
uses: julia-actions/julia-runtest@v1
env:
TEST_GROUP: "source_gen"

- name: Running `gibbs` tests
uses: nick-fields/retry@v3
with:
Expand Down
25 changes: 19 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.9"
version = "0.9.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -39,11 +39,10 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]
JuliaBUGSGraphPlotExt = "GraphPlot"
JuliaBUGSMCMCChainsExt = "MCMCChains"

[compat]
ADTypes = "1.6"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9, 0.10, 0.11"
Accessors = "0.1"
Expand All @@ -52,6 +51,7 @@ AdvancedMH = "0.8"
BangBang = "0.4.1"
Bijectors = "0.13, 0.14, 0.15.5"
ChainRules = "1"
DifferentiationInterface = "0.6.42"
Distributions = "0.23.8, 0.24, 0.25"
Documenter = "0.27, 1"
GLMakie = "0.10, 0.11"
Expand All @@ -67,6 +67,7 @@ LogExpFunctions = "0.3"
MCMCChains = "6"
MacroTools = "0.5"
MetaGraphsNext = "0.6, 0.7"
Mooncake = "0.4"
OrderedCollections = "1"
PDMats = "0.10, 0.11"
Serialization = "1.10"
Expand All @@ -76,15 +77,27 @@ Statistics = "1.10"
julia = "1.10.8"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AbstractMCMC", "ADTypes", "AdvancedHMC", "AdvancedMH", "ChainRules", "MCMCChains", "LogDensityProblemsAD", "ReverseDiff", "Test"]
test = [
"AbstractMCMC",
"AdvancedHMC",
"AdvancedMH",
"ChainRules",
"DifferentiationInterface",
"LogDensityProblemsAD",
"MCMCChains",
"Mooncake",
"ReverseDiff",
"Test"
]
3 changes: 2 additions & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
10 changes: 4 additions & 6 deletions benchmark/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module Benchmark

using Pkg
Pkg.develop(; path=joinpath(@__DIR__, ".."))

using JuliaBUGS
using ADTypes
using ReverseDiff

using DifferentiationInterface
using Mooncake: Mooncake

using MetaGraphsNext
using BridgeStan
using StanLogDensityProblems
Expand Down Expand Up @@ -96,5 +96,3 @@ function _print_results_table(
backend=backend,
)
end

end
27 changes: 20 additions & 7 deletions benchmark/juliabugs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,28 @@ function _create_JuliaBUGS_model(model_name::Symbol)
return compile(model_def, data, inits)
end

function benchmark_JuliaBUGS_model(model::JuliaBUGS.BUGSModel)
ad_model = ADgradient(AutoReverseDiff(true), model)
# ! writing a _function_ to benchmark all models won't work because of world-age error

function benchmark_JuliaBUGS_model_with_Mooncake(model::JuliaBUGS.BUGSModel)
p = Base.Fix1(LogDensityProblems.logdensity, model)
backend = AutoMooncake(; config=nothing)
dim = LogDensityProblems.dimension(model)
params_values = JuliaBUGS.getparams(model)
density_time = Chairmarks.@be LogDensityProblems.logdensity($ad_model, $params_values)
density_and_gradient_time = Chairmarks.@be LogDensityProblems.logdensity_and_gradient(
$ad_model, $params_values
)
prep = prepare_gradient(p, backend, params_values)
density_time = Chairmarks.@be LogDensityProblems.logdensity($model, $params_values)
density_and_gradient_time = Chairmarks.@be gradient($p, $prep, $backend, $params_values)
return BenchmarkResult(:juliabugs, dim, density_time, density_and_gradient_time)
end

# writing a _function_ to benchmark all models won't work because of worldage error
# function benchmark_JuliaBUGS_model_with_Enzyme(model::JuliaBUGS.BUGSModel)
# f(params, model) = LogDensityProblems.logdensity(model, params)
# backend = AutoEnzyme()
# dim = LogDensityProblems.dimension(model)
# params_values = JuliaBUGS.getparams(model)
# prep = prepare_gradient(f, backend, params_values, Constant(model))
# density_time = Chairmarks.@be LogDensityProblems.logdensity($model, $params_values)
# density_and_gradient_time = Chairmarks.@be gradient(
# $f, $prep, $backend, $params_values, $(Constant(model))
# )
# return BenchmarkResult(:juliabugs_enzyme, dim, density_time, density_and_gradient_time)
# end
31 changes: 22 additions & 9 deletions benchmark/run_benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
include("benchmark.jl")
using OrderedCollections

examples_to_benchmark = [
:rats, :pumps, :bones, :oxford, :epil, :lsat, :schools, :beetles, :air
]

stan_results = Benchmark.benchmark_Stan_models(examples_to_benchmark)
stan_results = benchmark_Stan_models(examples_to_benchmark)

juliabugs_models = [
Benchmark._create_JuliaBUGS_model(model_name) for model_name in examples_to_benchmark
JuliaBUGS.set_evaluation_mode(
_create_JuliaBUGS_model(model_name), JuliaBUGS.UseGeneratedLogDensityFunction()
) for model_name in examples_to_benchmark
]
juliabugs_results = OrderedDict{Symbol,Benchmark.BenchmarkResult}()
juliabugs_results = OrderedDict{Symbol,BenchmarkResult}()
for (model_name, model) in zip(examples_to_benchmark, juliabugs_models)
@info "Benchmarking $model_name"
juliabugs_results[model_name] = Benchmark.benchmark_JuliaBUGS_model(model)
@info "Benchmarking $model_name with Mooncake"
juliabugs_results[model_name] = benchmark_JuliaBUGS_model_with_Mooncake(model)
end

# juliabugs_enzyme_results = OrderedDict{Symbol,BenchmarkResult}()
# for (model_name, model) in zip(examples_to_benchmark, juliabugs_models)
# @info "Benchmarking $model_name with Enzyme"
# try
# juliabugs_enzyme_results[model_name] = benchmark_JuliaBUGS_model_with_Enzyme(model)
# catch e
# @warn "Error benchmarking $model_name with Enzyme: $e"
# end
# end

println("### Stan results:")
Benchmark._print_results_table(stan_results; backend=Val(:markdown))
println("### JuliaBUGS results:")
Benchmark._print_results_table(juliabugs_results; backend=Val(:markdown))
_print_results_table(stan_results; backend=Val(:markdown))
println("### JuliaBUGS Mooncake results:")
_print_results_table(juliabugs_results; backend=Val(:markdown))
# println("### JuliaBUGS Enzyme results:")
# _print_results_table(juliabugs_enzyme_results; backend=Val(:markdown))
Loading