Skip to content

Commit b393c54

Browse files
committed
Handle callable structs
1 parent 2c82159 commit b393c54

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

src/copyable_task.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ function build_callable(sig::Type{<:Tuple})
3939
return mc, refs[end]
4040
end
4141

42-
mutable struct TapedTask{Tdynamic_scope,Targs,Tmc<:MistyClosure}
42+
mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure}
4343
dynamic_scope::Tdynamic_scope
44-
args::Targs
44+
fargs::Tfargs
4545
const mc::Tmc
4646
const position::Base.RefValue{Int32}
4747
end
@@ -165,7 +165,7 @@ julia> consume(t)
165165
function TapedTask(dynamic_scope::Any, fargs...)
166166
seed_id!()
167167
mc, count_ref = build_callable(typeof(fargs))
168-
return TapedTask(dynamic_scope, fargs[2:end], mc, count_ref)
168+
return TapedTask(dynamic_scope, fargs, mc, count_ref)
169169
end
170170

171171
"""
@@ -199,7 +199,7 @@ called, it start execution from the entry point. If `consume` has previously bee
199199
`nothing` will be returned.
200200
"""
201201
@inline function consume(t::TapedTask)
202-
v = with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope)
202+
v = with(() -> t.mc(t.fargs...), dynamic_scope => t.dynamic_scope)
203203
return v isa ProducedValue ? v[] : nothing
204204
end
205205

@@ -287,12 +287,45 @@ end
287287

288288
@inline Base.getindex(x::ProducedValue) = x.x
289289

290+
"""
291+
inc_args(stmt)
292+
293+
Increment by `1` the `n` field of any `Argument`s present in `stmt`.
294+
Used in `make_ad_stmts!`.
295+
"""
296+
inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...)
297+
inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x
298+
inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest)
299+
inc_args(x::IDGotoNode) = x
300+
function inc_args(x::IDPhiNode)
301+
new_values = Vector{Any}(undef, length(x.values))
302+
for n in eachindex(x.values)
303+
if isassigned(x.values, n)
304+
new_values[n] = __inc(x.values[n])
305+
end
306+
end
307+
return IDPhiNode(x.edges, new_values)
308+
end
309+
inc_args(::Nothing) = nothing
310+
inc_args(x::GlobalRef) = x
311+
inc_args(x::Core.PiNode) = Core.PiNode(__inc(x.val), __inc(x.typ))
312+
313+
__inc(x::Argument) = Argument(x.n + 1)
314+
__inc(x) = x
315+
290316
function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
291317

292318
# The location from which all state can be retrieved. Since we're using `OpaqueClosure`s
293319
# to implement `TapedTask`s, this appears via the first argument.
294320
refs_id = Argument(1)
295321

322+
# Increment all arguments by 1.
323+
for bb in ir.blocks, (n, inst) in enumerate(bb.insts)
324+
bb.insts[n] = CC.NewInstruction(
325+
inc_args(inst.stmt), inst.type, inst.info, inst.line, inst.flag
326+
)
327+
end
328+
296329
# Construct map between SSA IDs and their index in the state data structure and back.
297330
# Also construct a map from each ref index to its type. We only construct `Ref`s
298331
# for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful
@@ -778,7 +811,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple}
778811
# rather than nothing at all.
779812
new_argtypes = copy(ir.argtypes)
780813
refs = (_refs..., Ref{Int32}(-1))
781-
new_argtypes[1] = typeof(refs)
814+
new_argtypes = vcat(typeof(refs), copy(ir.argtypes))
782815

783816
# Return BBCode and the `Ref`s.
784817
return BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta), refs
@@ -830,7 +863,7 @@ end
830863

831864
function (l::LazyCallable)(args::Vararg{Any,N}) where {N}
832865
isdefined(l, :mc) || construct_callable!(l)
833-
return l.mc(args[2:end]...)
866+
return l.mc(args...)
834867
end
835868

836869
function construct_callable!(l::LazyCallable{sig}) where {sig}
@@ -853,5 +886,5 @@ function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N}
853886
callable = build_callable(sig)
854887
dynamic_callable.cache[sig] = callable
855888
end
856-
return callable[1](args[2:end]...)
889+
return callable[1](args...)
857890
end

src/test_utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ function test_cases()
8787
(dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)),
8888
[true, 1],
8989
),
90+
Testcase("callable struct", nothing, (CallableStruct(5), 4), [5, 4, 9]),
9091
]
9192
end
9293

@@ -210,4 +211,15 @@ function dynamic_nested_outer_use_produced(f::Ref{Any})
210211
return nothing
211212
end
212213

214+
struct CallableStruct{T}
215+
x::T
216+
end
217+
218+
function (c::CallableStruct)(y)
219+
produce(c.x)
220+
produce(y)
221+
produce(c.x + y)
222+
return nothing
223+
end
224+
213225
end

0 commit comments

Comments
 (0)