Skip to content

Conversation

@torfjelde
Copy link
Member

@torfjelde torfjelde commented Jun 11, 2023

In #477 we introduced proper support for keyword arguments in the models.

As it turns out, this completely breaks integration with AdvancedPS.jl: TuringLang/Turing.jl#2001 and TuringLang/Libtask.jl#163.

But there is a way we can both support kwargs as in #477 and simultaneously preserving the current behavior of AdvancedPS.jl (though, as mentioned in the PR above, this still silently doesn't work for @submodel).

The idea is to transform the evaluator (not the constructor) from

function f(args...)
    ...
end

into the @generated

@generated function f(__kwargs__::NamedTuple{__names__}, args...) where {__names__}
    kwargs_unwrapped = [:($n = getproperty(__kwargs__, $n)) for n in __names__]
    return quote
        $(kwargs_unwrapped...)
        ...
    end
end

This is similar to how Base.kwcall but with the difference that:

  1. We change the actual definition of the method to have the keyword arguments inlined.
  2. We perform no checking as to whether the required keywords are present, etc.

(2) shouldn't really be a problem because we perform all of this in the constructor; the evaluator should never be called explicitly anyways.

The result is as follows:

julia> @model function demo(x; y=100)
           z ~ Normal(y, 1)
           x ~ Normal(z, 1)
       end
demo (generic function with 4 methods)

julia> model = demo(1)
Model{typeof(demo), (:x,), (:y,), (), Tuple{Int64}, Tuple{Int64}, DefaultContext}(demo, (x = 1,), (y = 100,), DefaultContext())

julia> model()
1

julia> rand(model)
(z = 98.5004327168935,)

But oh boy the expression is not looking fun:

julia> @macroexpand @model function demo(x; y=100)
           z ~ Normal(y, 1)
           x ~ Normal(z, 1)
       end
quote
    function demo(var"##kwargs#559"::NamedTuple{var"##names#560"}, __model__::Model, __varinfo__::AbstractVarInfo, __context__::AbstractPPL.AbstractContext, x; ) where var"##names#560"
        kwargs_unwrapped = Expr(:block)
        if $(Expr(:generated))
            local var"##tmp#561" = begin
                        kwargs_unwrapped = Expr(:block)
                        for n = var"##names#560"
                            push!(kwargs_unwrapped.args, Expr(:(=), n, Expr(:call, :getproperty, Symbol("##kwargs#559"), QuoteNode(n))))
                        end
                        return Expr(:block, kwargs_unwrapped, $(Expr(:copyast, :($(QuoteNode(quote
    #= REPL[55]:1 =#
    begin
        #= REPL[55]:1 =#
        #= REPL[55]:2 =#
        begin
            var"##dist#551" = Normal(y, 1)
            var"##vn#548" = (DynamicPPL.resolve_varnames)((VarName){:z}(), var"##dist#551")
            var"##isassumption#549" = begin
                    if (DynamicPPL.contextual_isassumption)(__context__, var"##vn#548")
                        if !((DynamicPPL.inargnames)(var"##vn#548", __model__)) || (DynamicPPL.inmissings)(var"##vn#548", __model__)
                            true
                        else
                            z === missing
                        end
                    else
                        false
                    end
                end
            if var"##isassumption#549"
                begin
                    (var"##value#552", __varinfo__) = (DynamicPPL.tilde_assume!!)(__context__, (DynamicPPL.unwrap_right_vn)((DynamicPPL.check_tilde_rhs)(var"##dist#551"), var"##vn#548")..., __varinfo__)
                    z = var"##value#552"
                    var"##value#552"
                end
            else
                if !((DynamicPPL.inargnames)(var"##vn#548", __model__))
                    z = (DynamicPPL.getvalue_nested)(__context__, var"##vn#548")
                end
                (var"##value#550", __varinfo__) = (DynamicPPL.tilde_observe!!)(__context__, (DynamicPPL.check_tilde_rhs)(var"##dist#551"), z, var"##vn#548", __varinfo__)
                var"##value#550"
            end
        end
        #= REPL[55]:3 =#
        begin
            #= /home/tor/Projects/public/DynamicPPL.jl/src/compiler.jl:487 =#
            var"##retval#558" = begin
                    var"##dist#556" = Normal(z, 1)
                    var"##vn#553" = (DynamicPPL.resolve_varnames)((VarName){:x}(), var"##dist#556")
                    var"##isassumption#554" = begin
                            if (DynamicPPL.contextual_isassumption)(__context__, var"##vn#553")
                                if !((DynamicPPL.inargnames)(var"##vn#553", __model__)) || (DynamicPPL.inmissings)(var"##vn#553", __model__)
                                    true
                                else
                                    x === missing
                                end
                            else
                                false
                            end
                        end
                    if var"##isassumption#554"
                        begin
                            (var"##value#557", __varinfo__) = (DynamicPPL.tilde_assume!!)(__context__, (DynamicPPL.unwrap_right_vn)((DynamicPPL.check_tilde_rhs)(var"##dist#556"), var"##vn#553")..., __varinfo__)
                            x = var"##value#557"
                            var"##value#557"
                        end
                    else
                        if !((DynamicPPL.inargnames)(var"##vn#553", __model__))
                            x = (DynamicPPL.getvalue_nested)(__context__, var"##vn#553")
                        end
                        (var"##value#555", __varinfo__) = (DynamicPPL.tilde_observe!!)(__context__, (DynamicPPL.check_tilde_rhs)(var"##dist#556"), x, var"##vn#553", __varinfo__)
                        var"##value#555"
                    end
                end
            #= /home/tor/Projects/public/DynamicPPL.jl/src/compiler.jl:488 =#
            return (var"##retval#558", __varinfo__)
        end
    end
end))))))
                    end
            if var"##tmp#561" isa Core.CodeInfo
                #= expr.jl:913 =#
                return var"##tmp#561"
            else
                #= expr.jl:913 =#
                var"##tmp#561"
            end
        else
            $(Expr(:meta, :generated_only))
            return
        end
    end
    begin
        $(Expr(:meta, :doc))
        function demo(x; y = 100)
            #= REPL[55]:1 =#
            return (Model)(demo, NamedTuple{(:x,)}((x,)); y)
        end
    end
end

Yeaaaah.

And it most certainly complicates the compiler.

All in all, I'm not entirely certain if this is the way to go. The problem is just that

  1. Prior to DPPL 0.23 we're not supporting stuff like kwargs..., which can be quite annoying.
  2. Pushing the current DPPL 0.23 into Turing.jl will completely break usage of kwargs for SMC samplers.
  3. Libtask.jl needs quite a bit of work to be able to overcome this issue.

In conclusion, I don't see a nice solution other than the above 😕

EDIT: There's also the "fun" effect that you can now also use kwargs to condition 🙃

julia> @model function demo(x; y=100, kwargs...)
           z ~ Normal(y, 1)
           x ~ Normal(z, 1)
       end
demo (generic function with 4 methods)

julia> model = demo(1; z=1);

julia> model()
1

julia> rand(model)
NamedTuple()

@torfjelde torfjelde requested a review from devmotion June 11, 2023 21:43
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Copy link
Member Author

And not very unexpectedly, this breaks models involving eval, etc. 😕

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm certainly not a fan of @generated if it can be avoided, due to its limitations and impact on compilation times. Due to all the disadvantages you mentioned I'd really like to not merge this PR. Also partially fixing SMC with kwargs while breaking other use cases doesn't seem convincing.

Since the supported model classes for SMC will be limited even with this PR, maybe it's fine to just not support kwargs with SMC and PG until it is fixed upstream? We should just make sure that it's not silently broken - could we add a check in DynamicPPL that throws a descriptive error if kwargs are used with SMC or PG (I guess this might need a trait for LibTask-samplers)?

@torfjelde
Copy link
Member Author

Oh I'm also very much with you that this is not a great way to go about things. I put it here as a "this is a last resort bandaid but it's hacky as hell" (didn't put it in draft mode because I wanted to run tests).

Since the supported model classes for SMC will be limited even with this PR, maybe it's fine to just not support kwargs with SMC and PG until it is fixed upstream? We should just make sure that it's not silently broken - could we add a check in DynamicPPL that throws a descriptive error if kwargs are used with SMC or PG (I guess this might need a trait for LibTask-samplers)?

I'm honestly fine with this. We can also do it in an easier way: in the construction of the taped task we just check if kwargs are empty or not, and if they arent' we complain.

@yebai
Copy link
Member

yebai commented Jun 12, 2023

I'm honestly fine with this. We can also do it in an easier way: in the construction of the taped task we just check if kwargs are empty or not, and if they arent' we complain.

Also sounds sensible to me! It's a good idea to keep the compiler as simple as possible.

@yebai
Copy link
Member

yebai commented Jun 15, 2023

As discussed, we'd like to fix this from Libtask which avoids introducing complex code into the compiler.

@yebai yebai closed this Jun 15, 2023
@yebai yebai deleted the torfjelde/inline-kwargs-in-evaluator branch June 15, 2023 11:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants