Generating Julia function for log density evaluation #278
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Pull Request Test Coverage Report for Build 14398731557 (Details)
💛 - Coveralls
…nction into BUGSModel
I don't understand why it is related to Mooncake. The above example doesn't use

benchmark code for gradient is added (see the first code block of #278 (comment))
Copilot reviewed 18 out of 35 changed files in this pull request and generated no comments.
Files not reviewed (17)
- benchmark/benchmark.jl: Language not supported
- benchmark/juliabugs.jl: Language not supported
- benchmark/run_benchmarks.jl: Language not supported
- src/BUGSExamples/Volume_3/03_Fire.jl: Language not supported
- src/BUGSExamples/Volume_3/04_Circle.jl: Language not supported
- src/BUGSExamples/Volume_3/04_FunShapes.jl: Language not supported
- src/BUGSExamples/Volume_3/04_HollowSquare.jl: Language not supported
- src/BUGSExamples/Volume_3/04_Parallelogram.jl: Language not supported
- src/BUGSExamples/Volume_3/04_Ring.jl: Language not supported
- src/BUGSExamples/Volume_3/04_SquareMinusCircle.jl: Language not supported
- src/BUGSExamples/Volume_3/05_Hepatitis.jl: Language not supported
- src/BUGSExamples/Volume_3/05_Hepatitis_ME.jl: Language not supported
- src/BUGSExamples/Volume_3/06_Hips1.jl: Language not supported
- src/BUGSExamples/Volume_3/07_Hips2.jl: Language not supported
- src/BUGSExamples/Volume_3/08_Hips3.jl: Language not supported
- src/BUGSExamples/Volume_3/09_Hips4.jl: Language not supported
- src/BUGSExamples/Volume_3/11_PigWeights.jl: Language not supported
@penelopeysm @mhauru, can you review this PR to the best of your ability? I'll take a more careful look later.

Thanks @penelopeysm, just to reiterate: I know the number of lines changed is very large, but only https://github.com/TuringLang/JuliaBUGS.jl/blob/sunxd/source_gen/src/source_gen.jl and its test https://github.com/TuringLang/JuliaBUGS.jl/blob/sunxd/source_gen/test/source_gen.jl are worth looking at. The others are either low risk or updates to benchmarking that are not super relevant.
I didn't look at any of the Julia files because I don't know how to, and I have no comments on the other files! Just kidding. I'll do my best haha |
I haven't yet looked at the code, but I've given the documentation a read through. It's very good! I just put in a couple of comments on places where I still had a bit of confusion. I'll look at the source code another time.
All subsequent iterations follow the same pattern and have the same dependence vector of $(1)$. Because all dependence vectors are lexicographically non-negative, the loop is sequentially valid.

This requires storing the loop variable `i` for each variable, but we already computed this with JuliaBUGS compilation, so not much overhead is required.
To make sure I'm understanding this properly: as part of the compilation, you step through the loop and on each iteration you calculate the dependency vector for the variables in the loop?
Currently, we don't compute the dependence vector, but we store the values of the loop induction variables for each variable in the model.
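For concreteness, here is a minimal sketch of how a dependence vector could be derived from stored loop-induction values; the names `dependence_vector` and `lexicographically_nonnegative` are hypothetical helpers for illustration, not JuliaBUGS API.

```julia
# Given the loop-induction values at which an array element is defined
# and at which it is used, the dependence vector is their difference.
function dependence_vector(def_indices::NTuple{N,Int}, use_indices::NTuple{N,Int}) where {N}
    return def_indices .- use_indices
end

# A loop nest is sequentially valid when every dependence vector is
# lexicographically non-negative (first nonzero entry positive, or all zero).
function lexicographically_nonnegative(v)
    for d in v
        d > 0 && return true
        d < 0 && return false
    end
    return true  # the all-zero vector is loop-independent
end

# x[i] depends on x[i-1]: defined at iteration i, used at iteration i - 1.
v = dependence_vector((2,), (1,))
@assert v == (1,) && lexicographically_nonnegative(v)
```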
We made a simple change to the program to prepare for lowering: we need to distinguish between observations and model parameters (because they correspond to different code). We introduce a new operator, `\eqsim`, into the program to indicate that the left-hand side is an observation.
Personally, I don't think I'm entirely comfortable with eqsim; I think it looks too much like an ordinary equals, and also it's hard to type. I think something like `~=` would be better for my eyes - would you be ok with that?
`~=` is a good proposal, but I don't think the Julia parser allows it:

```julia
julia> :(a ~= Normal(0, 1))
ERROR: ParseError:
# Error @ REPL[1]:1:6
:(a ~= Normal(0, 1))
#    ╙ ── unexpected `=`
Stacktrace:
 [1] top-level scope
   @ REPL:1
```
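For what it's worth, the Unicode operator the PR settled on does parse. A quick sketch checking the parsed form (assuming a recent Julia parser):

```julia
# `~=` is rejected by the parser (see the error above), but `≂`
# (typed as \eqsim<TAB> in the REPL) parses as an ordinary infix operator.
ex = :(a ≂ Normal(0, 1))
@assert ex isa Expr && ex.head == :call
@assert ex.args[1] == :≂   # operator first, then lhs and rhs
@assert ex.args[2] == :a
```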
The only thing I really think is important is the macro hygiene one!
```julia
abstract type EvaluationMode end

struct UseGeneratedLogDensityFunction <: EvaluationMode end
struct UseGraph <: EvaluationMode end
```
I like this pattern very much (it's something like sum types / enums) and it makes me sad that it doesn't get used often enough in Julia
Oh, any reason why you didn't use `@enum` here? Just curious.
No particular reason, just didn't come to my mind at the time.
I don't think `@enum` is more elegant than subtyping for this case (not saying it's worse either). Do you have a preference?
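The trade-off discussed here can be sketched in a standalone example (illustrative code, not the PR's): with singleton subtypes the mode selection happens via multiple dispatch, while the `@enum` version needs an explicit branch.

```julia
abstract type EvaluationMode end
struct UseGeneratedLogDensityFunction <: EvaluationMode end
struct UseGraph <: EvaluationMode end

# Dispatch selects the path; the compiler can specialize each method.
describe(::UseGeneratedLogDensityFunction) = "generated-function path"
describe(::UseGraph) = "graph-traversal path"

# The @enum equivalent carries a tag value and branches on it at runtime.
@enum EvalModeEnum GeneratedFn Graph
function describe(m::EvalModeEnum)
    return m == GeneratedFn ? "generated-function path" : "graph-traversal path"
end

@assert describe(UseGraph()) == describe(Graph)
```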
```diff
@@ -378,6 +429,16 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed))
     return BangBang.setproperty!!(model, :transformed, bool)
 end

+function set_evaluation_mode(model::BUGSModel, mode::EvaluationMode)
```
Should this be `set_evaluation_mode!!`? I'm still not 100% sure when we use `!` or `!!` (I suspect DynamicPPL isn't super consistent on this either).
In this case, because it never modifies the original object, I didn't use any `!`.

In my mind, `!!` means "mutate if possible". Then, if there is no way of mutating anyway, maybe it's okay to not use `!` at all?
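A minimal sketch of the non-mutating convention being described, using a stand-in struct (the real `BUGSModel` uses `BangBang.setproperty!!`, which mutates when it can):

```julia
# Hypothetical stand-in for BUGSModel, for illustration only.
struct Model
    evaluation_mode::Symbol
    data::Vector{Float64}
end

# Returns a new Model; the original is untouched, hence no `!` in the name.
set_evaluation_mode(m::Model, mode::Symbol) = Model(mode, m.data)

m  = Model(:graph, [1.0, 2.0])
m2 = set_evaluation_mode(m, :generated)
@assert m.evaluation_mode == :graph        # original unchanged
@assert m2.evaluation_mode == :generated   # new object carries the new mode
```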
test/source_gen.jl (Outdated)
```julia
for example_name in test_examples
    model, evaluation_env = _create_model(example_name)
    bugs_models[example_name] = model
    evaluation_envs[example_name] = evaluation_env
    lowered_model_def, reconstructed_model_def = _generate_lowered_model_def(
        model.model_def, model.g, evaluation_env
    )
    log_density_computation_expr = _gen_log_density_computation_function_expr(
        lowered_model_def, evaluation_env, gensym(example_name)
    )
    log_density_computation_functions[example_name] = eval(log_density_computation_expr)
    reconstructed_model_defs[example_name] = reconstructed_model_def
end

@testset "source_gen: $example_name" for example_name in test_examples
    model_with_consistent_sorted_nodes = _create_bugsmdoel_with_consistent_sorted_nodes(
        bugs_models[example_name], reconstructed_model_defs[example_name]
    )
    result_with_old_model = JuliaBUGS.evaluate!!(bugs_models[example_name])[2]
    params = JuliaBUGS.getparams(model_with_consistent_sorted_nodes)
    result_with_bugsmodel = JuliaBUGS.evaluate!!(
        model_with_consistent_sorted_nodes, params
    )[2]
    result_with_log_density_computation_function = log_density_computation_functions[example_name](
        evaluation_envs[example_name], params
    )
    @test result_with_old_model ≈ result_with_bugsmodel
    @test result_with_log_density_computation_function ≈ result_with_bugsmodel
end
```
Could the for loop and the testset be combined here? It seems to me that you don't need to construct the dictionaries here since each iteration just processes one symbol at a time.
I don't think so, unfortunately. Because I am generating functions and `eval`ing them, combining the two would cause world-age problems.
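The world-age issue referred to here can be reproduced with a minimal standalone example (illustrative, not the test code):

```julia
# A function created by eval inside a running function lives in a newer
# "world" than its caller, so a direct call fails with a MethodError;
# Base.invokelatest (or calling it from a later top-level expression,
# as the test file does) is required.
function build_and_call()
    f = eval(:(x -> x + 1))
    # f(1)  # would throw: "method is too new to be called from this world context"
    return Base.invokelatest(f, 1)
end

@assert build_and_call() == 2
```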
```julia
MacroTools.@capture(stmt, lhs_ ~ rhs_)
return MacroTools.@q begin
    __dist__ = $rhs
    __b__ = Bijectors.bijector(__dist__)
```
Is there a particular reason why you use the double underscore here? My impression is that it'd be safer to gensym these variables, because otherwise it's leaky: if someone defines a model with a parameter called `__logp__`, for example, it would interfere with this. For example, this is what DynamicPPL does:
Suggested change:

```julia
MacroTools.@capture(stmt, lhs_ ~ rhs_)
@gensym dist b
return MacroTools.@q begin
    $dist = $rhs
    $b = Bijectors.bijector($dist)
```
(The suggestion extends to the other variables in this Expr.)

I did consider that maybe your intention was to expose some of these variables to the user, like `logp`? That's what we do with `__varinfo__` and `__context__` in DynamicPPL. But I guessed that probably not all of these should be exposed haha
(Link to relevant section in DynamicPPL: https://github.com/TuringLang/DynamicPPL.jl/blob/019e41b341ccfb04b8467e50bee6918fbe1d74c0/src/compiler.jl#L436)
Oh, I realised maybe you can't do this with `__logp__` because it needs to be 'carried over' between different statements (or at the least, it'd be annoying because you'd have to gensym it outside of this function and then pass it in, to ensure that it was the same variable being used all the time). But it seems to me that the rest can be gensym'd!

If you don't gensym `logp`, then I'd document somewhere that it's a special/reserved variable name and you shouldn't use it (or, alternatively and maybe better, issue a warning).
Yeah, it was more for ease of programming. Using `gensym` is much safer, but you need to thread all the generated names through to sub-functions, which is a bit annoying.

Regarding the use of double underscores, I always thought it was a convention to reserve underscores for internals, but I just realized this might only apply to function names?

This is a great catch! I'll add a test for this.
this is done now
Co-authored-by: Penelope Yong <[email protected]>
Benchmark results on macOS (aarch64): Stan results and JuliaBUGS Mooncake results (benchmark tables omitted here). BridgeStan was not found at the location specified by the $BRIDGESTAN environment variable; version 2.6.1 was downloaded to /Users/runner/.bridgestan/bridgestan-2.6.1.

Benchmark results on Ubuntu (x64): Stan results and JuliaBUGS Mooncake results (benchmark tables omitted here). BridgeStan version 2.6.1 was downloaded to /home/runner/.bridgestan/bridgestan-2.6.1.
FYI, I switched to directly using the generated function for benchmarking instead of through
Motivation
JuliaBUGS compiles BUGS programs into a directed probabilistic graphical model (PGM), which implicitly defines the dependency structure between variables. While this graph allows for execution strategies like topological traversal and parallelization, a significant challenge arises from the BUGS language semantics: every element within an array can be treated as an individual random variable.
This fine-grained dependency structure means that a naive way to generate Julia source based on the variable-level graph would often require fully unrolling all loops. This approach is infeasible, especially for large datasets or complex models, and poses significant difficulties for automatic differentiation (AD) tools to analyze the program.
Proposed Changes
This PR introduces an initial implementation for generating a specialized Julia function dedicated to computing the log density of the model. The core idea is to operate on a higher level of abstraction than individual variable nodes.
The algorithm classifies each statement by its operator:

- deterministic assignments (`=`),
- stochastic statements for model parameters (`~`),
- observations (`≂`).

The generated function takes a flattened vector of parameter values and reconstructs them using `Bijectors.jl` to handle constraints and compute log-Jacobian adjustments, accumulating the log-prior and log-likelihood terms.

Example: Rats Model
Consider the classic "Rats" example:
Original BUGS Code:
Statement Dependence Graph:
(Note: Mermaid graph slightly adjusted for clarity based on variable dependencies)
Sequential Representation (after Topological Sort & Loop Fission):
This intermediate representation reflects the order determined by the statement graph dependencies and separates the original nested loops.
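Loop fission, as used in this lowering step, can be illustrated with a small standalone sketch (not the actual Rats lowering): a loop whose body holds two statements can be split into two loops when the only dependence between the statements is loop-independent.

```julia
# Before fission: one loop, two statements in the body.
function fused(n)
    mu = zeros(n); y = zeros(n)
    for i in 1:n
        mu[i] = 2.0 * i       # statement 1
        y[i]  = mu[i] + 1.0   # statement 2 depends on statement 1 at the same i
    end
    return y
end

# After fission: two loops, one statement each. Valid because the
# dependence mu[i] -> y[i] has dependence vector (0), i.e. it never
# crosses iterations.
function fissioned(n)
    mu = zeros(n); y = zeros(n)
    for i in 1:n
        mu[i] = 2.0 * i
    end
    for i in 1:n
        y[i] = mu[i] + 1.0
    end
    return y
end

@assert fused(4) == fissioned(4)
```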
Generated Julia Log Density Function:

This function takes the model environment (`__evaluation_env__`) and flattened parameters (`__flattened_values__`), computes the log density (`__logp__`), and handles necessary transformations via `Bijectors.jl`.
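To make the shape concrete, here is a hedged sketch of what such a generated function might look like for a toy model `tau ~ Gamma(0.001, 1000); y[i] ≂ Normal(0, 1/sqrt(tau))`. The names (`__example_logdensity__`, the env fields) are illustrative, not the exact JuliaBUGS output, and the log/exp transform is written out by hand where the real code calls `Bijectors.jl`.

```julia
using Distributions

function __example_logdensity__(__evaluation_env__, __flattened_values__)
    __logp__ = 0.0
    # tau is constrained to (0, Inf); the flattened value z lives on all of R.
    # exp is the inverse of the log bijector, and the log-Jacobian of the
    # inverse transform exp(z) is z itself.
    z = __flattened_values__[1]
    tau = exp(z)
    __logp__ += logpdf(Gamma(0.001, 1000.0), tau) + z   # log-prior + log|J|
    sigma = 1 / sqrt(tau)
    for yi in __evaluation_env__.y                      # observations
        __logp__ += logpdf(Normal(0.0, sigma), yi)      # log-likelihood
    end
    return __logp__
end

env = (y = [0.1, -0.3, 0.5],)
lp = __example_logdensity__(env, [0.0])   # z = 0 => tau = 1, sigma = 1
@assert isfinite(lp)
```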
Performance
The generated function demonstrates significant performance improvements and eliminates allocations compared to evaluating the log density through the generic `LogDensityProblems.logdensity` interface, which involves more overhead.

`LogDensityProblems.logdensity` Benchmark:

Directly Generated Function Benchmark:

Gradient Performance (using Mooncake AD on the generated function):

The generated function structure is also amenable to AD, yielding efficient gradient computations, which are on par with Stan.