Skip to content

Commit 22f84e7

Browse files
authored
Reorganize code to keep code as IRCode slightly longer (#39885)
This moved the primary place where IRCode gets converted into CodeInfo into the transform_result_for_cache call, which is a sensible place for it to be since the primary reason we need to convert back to IRCode is to make it acceptable for storing in the global cache. The reason we might want to not perform the conversion, is that the conversion is slightly lossy, because it drops stmtinfo. However, in Cthulhu, I would like to keep the statement info around such that Cthulhu can present it to the user, even in optimized code (i.e. where inlining decided not to inline, I would still like Cthulhu to be able to introspect inference's original annotations). This change makes that possible. In an ideal world, we wouldn't have to do this at all for uncached code, but of course both code_typed and typeinf_ext do look at the code, even if it is uncached. Unfortunately, at the moment we don't really have a good way to indicate whether or not the code will be looked at, so there is a fallback path that always does the conversion if we decided not to do the caching. Some future refactoring can save some additional time here.
1 parent bf0364b commit 22f84e7

File tree

4 files changed

+88
-64
lines changed

4 files changed

+88
-64
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
3232
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
3333
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
3434
if sv.params.unoptimize_throw_blocks && sv.currpc in sv.throw_blocks
35+
add_remark!(interp, sv, "Skipped call in throw block")
3536
return CallMeta(Any, false)
3637
end
3738
valid_worlds = WorldRange()
@@ -365,7 +366,7 @@ function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::M
365366
if isdefined(code, :inferred) && !cache_inlineable
366367
cache_inf = code.inferred
367368
if !(cache_inf === nothing)
368-
cache_inlineable = inlining_policy(interp)(cache_inf)
369+
cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing
369370
end
370371
end
371372
if !cache_inlineable

base/compiler/optimize.jl

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,18 @@ function default_inlining_policy(@nospecialize(src))
3232
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
3333
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
3434
src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
35-
return src_inferred && src_inlineable
35+
return src_inferred && src_inlineable ? src : nothing
3636
end
37-
return false
37+
if isa(src, OptimizationState) && isdefined(src, :ir)
38+
return src.src.inlineable ? src.ir : nothing
39+
end
40+
return nothing
3841
end
3942

4043
mutable struct OptimizationState
4144
linfo::MethodInstance
4245
src::CodeInfo
46+
ir::Any # Union{Nothing, IRCode}
4347
stmt_info::Vector{Any}
4448
mod::Module
4549
nargs::Int
@@ -54,7 +58,7 @@ mutable struct OptimizationState
5458
WorldView(code_cache(interp), frame.world),
5559
inlining_policy(interp))
5660
return new(frame.linfo,
57-
frame.src, frame.stmt_info, frame.mod, frame.nargs,
61+
frame.src, nothing, frame.stmt_info, frame.mod, frame.nargs,
5862
frame.sptypes, frame.slottypes, false,
5963
inlining)
6064
end
@@ -88,7 +92,7 @@ mutable struct OptimizationState
8892
WorldView(code_cache(interp), get_world_counter()),
8993
inlining_policy(interp))
9094
return new(linfo,
91-
src, stmt_info, inmodule, nargs,
95+
src, nothing, stmt_info, inmodule, nargs,
9296
sptypes_from_meth_instance(linfo), slottypes, false,
9397
inlining)
9498
end
@@ -100,6 +104,20 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
100104
return OptimizationState(linfo, src, params, interp)
101105
end
102106

107+
function ir_to_codeinf!(opt::OptimizationState)
108+
replace_code_newstyle!(opt.src, opt.ir, opt.nargs - 1)
109+
opt.ir = nothing
110+
let src = opt.src::CodeInfo
111+
widen_all_consts!(src)
112+
src.inferred = true
113+
# finish updating the result struct
114+
validate_code_in_debug_mode(opt.linfo, src, "optimized")
115+
return src
116+
end
117+
end
118+
119+
include("compiler/ssair/driver.jl")
120+
103121

104122
#############
105123
# constants #
@@ -152,7 +170,7 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
152170
end
153171
end
154172
if !inlineable
155-
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, params, union_penalties, cost_threshold + bonus)
173+
inlineable = inline_worthy(me.ir, params, union_penalties, cost_threshold + bonus)
156174
end
157175
return inlineable
158176
end
@@ -175,7 +193,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
175193
end
176194

177195
# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects
178-
function finish(opt::OptimizationState, params::OptimizationParams, ir, @nospecialize(result))
196+
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir, @nospecialize(result))
179197
def = opt.linfo.def
180198
nargs = Int(opt.nargs) - 1
181199

@@ -225,7 +243,7 @@ function finish(opt::OptimizationState, params::OptimizationParams, ir, @nospeci
225243
end
226244
end
227245

228-
replace_code_newstyle!(opt.src, ir, nargs)
246+
opt.ir = ir
229247

230248
# determine and cache inlineability
231249
union_penalties = false
@@ -263,14 +281,15 @@ function finish(opt::OptimizationState, params::OptimizationParams, ir, @nospeci
263281
opt.src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
264282
end
265283
end
284+
266285
nothing
267286
end
268287

269288
# run the optimization work
270289
function optimize(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, @nospecialize(result))
271290
nargs = Int(opt.nargs) - 1
272291
@timeit "optimizer" ir = run_passes(opt.src, nargs, opt)
273-
finish(opt, params, ir, result)
292+
finish(interp, opt, params, ir, result)
274293
end
275294

276295

@@ -296,7 +315,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
296315
# known return type
297316
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
298317

299-
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any},
318+
function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
300319
slottypes::Vector{Any}, union_penalties::Bool,
301320
params::OptimizationParams, error_path::Bool = false)
302321
head = ex.head
@@ -308,7 +327,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
308327
if ftyp === IntrinsicFunction && farg isa SSAValue
309328
# if this comes from code that was already inlined into another function,
310329
# Consts have been widened. try to recover in simple cases.
311-
farg = src.code[farg.id]
330+
farg = isa(src, CodeInfo) ? src.code[farg.id] : src.stmts[farg.id][:inst]
312331
if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
313332
ftyp = argextype(farg, src, sptypes, slottypes)
314333
end
@@ -350,7 +369,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
350369
end
351370
return T_FFUNC_COST[fidx]
352371
end
353-
extyp = line == -1 ? Any : src.ssavaluetypes[line]
372+
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes)
354373
if extyp === Union{}
355374
return 0
356375
end
@@ -361,7 +380,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
361380
# run-time of the function, we omit them from
362381
# consideration. This way, non-inlined error branches do not
363382
# prevent inlining.
364-
extyp = line == -1 ? Any : src.ssavaluetypes[line]
383+
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes)
365384
return extyp === Union{} ? 0 : 20
366385
elseif head === :(=)
367386
if ex.args[1] isa GlobalRef
@@ -386,31 +405,32 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
386405
return 0
387406
end
388407

389-
function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::CodeInfo, sptypes::Vector{Any},
408+
function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
390409
slottypes::Vector{Any}, union_penalties::Bool, params::OptimizationParams,
391410
throw_blocks::Union{Nothing,BitSet})
392411
thiscost = 0
412+
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
393413
if stmt isa Expr
394414
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params,
395415
throw_blocks !== nothing && line in throw_blocks)::Int
396416
elseif stmt isa GotoNode
397417
# loops are generally always expensive
398418
# but assume that forward jumps are already counted for from
399419
# summing the cost of the not-taken branch
400-
thiscost = stmt.label < line ? 40 : 0
420+
thiscost = dst(stmt.label) < line ? 40 : 0
401421
elseif stmt isa GotoIfNot
402-
thiscost = stmt.dest < line ? 40 : 0
422+
thiscost = dst(stmt.dest) < line ? 40 : 0
403423
end
404424
return thiscost
405425
end
406426

407-
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
427+
function inline_worthy(ir::IRCode,
408428
params::OptimizationParams, union_penalties::Bool=false, cost_threshold::Integer=params.inline_cost_threshold)
409429
bodycost::Int = 0
410-
throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(body) : nothing
411-
for line = 1:length(body)
412-
stmt = body[line]
413-
thiscost = statement_or_branch_cost(stmt, line, src, sptypes, slottypes, union_penalties, params, throw_blocks)
430+
throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(ir.stmts.inst, RefValue(ir)) : nothing
431+
for line = 1:length(ir.stmts)
432+
stmt = ir.stmts[line][:inst]
433+
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, ir.argtypes, union_penalties, params, throw_blocks)
414434
bodycost = plus_saturate(bodycost, thiscost)
415435
bodycost > cost_threshold && return false
416436
end
@@ -490,5 +510,3 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
490510
end
491511
end
492512
end
493-
494-
include("compiler/ssair/driver.jl")

base/compiler/ssair/inlining.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -696,13 +696,13 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
696696
isconst, src = false, nothing
697697
if isa(spec.match, InferenceResult)
698698
let inferred_src = spec.match.src
699-
if isa(inferred_src, CodeInfo)
700-
isconst, src = false, inferred_src
701-
elseif isa(inferred_src, Const)
699+
if isa(inferred_src, Const)
702700
if !is_inlineable_constant(inferred_src.val)
703701
return compileable_specialization(state.et, spec.match)
704702
end
705703
isconst, src = true, quoted(inferred_src.val)
704+
else
705+
isconst, src = false, inferred_src
706706
end
707707
end
708708
else
@@ -724,17 +724,15 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
724724
return ConstantCase(src)
725725
end
726726

727-
if src !== nothing && !state.policy(src)
728-
src = nothing
727+
if src !== nothing
728+
src = state.policy(src)
729729
end
730730

731731
if src === nothing
732732
return compileable_specialization(state.et, spec.match)
733733
end
734734

735-
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
736-
737-
elseif isa(src, IRCode)
735+
if isa(src, IRCode)
738736
src = copy(src)
739737
end
740738

@@ -1028,7 +1026,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
10281026
calltype = ir.stmts[idx][:type]
10291027

10301028
if !info.match.fully_covers
1031-
# XXX: We could union split this
1029+
# TODO: We could union split out the signature check and continue on
10321030
return nothing
10331031
end
10341032

base/compiler/typeinfer.jl

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,18 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
210210
end
211211
end
212212

213+
function finish!(interp::AbstractInterpreter, caller::InferenceResult)
214+
# If we didn't transform the src for caching, we may have to transform
215+
# it anyway for users like typeinf_ext. Do that here.
216+
opt = caller.src
217+
if may_optimize(interp) && opt isa OptimizationState
218+
if opt.ir !== nothing
219+
caller.src = ir_to_codeinf!(opt)
220+
end
221+
end
222+
return caller.src
223+
end
224+
213225
function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
214226
typeinf_nocycle(interp, frame) || return false # frame is now part of a higher cycle
215227
# with no active ip's, frame is done
@@ -235,30 +247,26 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
235247
frames[i].cached )
236248
for i in 1:length(frames) ]
237249
empty!(frames)
238-
if may_optimize(interp)
239-
for (caller, _, _) in results
240-
opt = caller.src
241-
if opt isa OptimizationState
242-
result_type = caller.result
243-
@assert !(result_type isa LimitedAccuracy)
244-
optimize(interp, opt, OptimizationParams(interp), result_type)
245-
finish(opt.src, interp)
246-
# finish updating the result struct
247-
validate_code_in_debug_mode(opt.linfo, opt.src, "optimized")
248-
if opt.const_api
249-
if result_type isa Const
250-
caller.src = result_type
251-
else
252-
@assert isconstType(result_type)
253-
caller.src = Const(result_type.parameters[1])
254-
end
255-
elseif opt.src.inferred
256-
caller.src = opt.src::CodeInfo # stash a copy of the code (for inlining)
250+
for (caller, _, _) in results
251+
opt = caller.src
252+
if may_optimize(interp) && opt isa OptimizationState
253+
result_type = caller.result
254+
@assert !(result_type isa LimitedAccuracy)
255+
optimize(interp, opt, OptimizationParams(interp), result_type)
256+
if opt.const_api
257+
# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
258+
# we're doing it is so that code_llvm can return the code
259+
# for the `return ...::Const` (which never runs anyway). We should do this
260+
# as a post processing step instead.
261+
ir_to_codeinf!(opt)
262+
if result_type isa Const
263+
caller.src = result_type
257264
else
258-
caller.src = nothing
265+
@assert isconstType(result_type)
266+
caller.src = Const(result_type.parameters[1])
259267
end
260-
caller.valid_worlds = opt.inlining.et.valid_worlds[]
261268
end
269+
caller.valid_worlds = opt.inlining.et.valid_worlds[]
262270
end
263271
end
264272
for (caller, edges, cached) in results
@@ -271,6 +279,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
271279
if cached
272280
cache_result!(interp, caller)
273281
end
282+
finish!(interp, caller)
274283
end
275284
return true
276285
end
@@ -349,7 +358,12 @@ function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodIn
349358
# If we decided not to optimize, drop the OptimizationState now.
350359
# External interpreters can override as necessary to cache additional information
351360
if inferred_result isa OptimizationState
352-
inferred_result = inferred_result.src
361+
opt = inferred_result
362+
if isa(opt.src, CodeInfo)
363+
inferred_result = ir_to_codeinf!(opt)
364+
else
365+
inferred_result = opt.src
366+
end
353367
end
354368
if inferred_result isa CodeInfo
355369
inferred_result.min_world = first(valid_worlds)
@@ -470,13 +484,6 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
470484
nothing
471485
end
472486

473-
function finish(src::CodeInfo, interp::AbstractInterpreter)
474-
# convert all type information into the form consumed by the cache for inlining and code-generation
475-
widen_all_consts!(src)
476-
src.inferred = true
477-
nothing
478-
end
479-
480487
# record the backedges
481488
function store_backedges(frame::InferenceResult, edges::Vector{Any})
482489
toplevel = !isa(frame.linfo.def, Method)
@@ -839,9 +846,9 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
839846
frame === nothing && return (nothing, Any)
840847
if typeinf(interp, frame) && run_optimizer
841848
opt_params = OptimizationParams(interp)
842-
opt = OptimizationState(frame, opt_params, interp)
843-
optimize(interp, opt, opt_params, ignorelimited(result.result))
844-
opt.src.inferred = true
849+
result.src = OptimizationState(frame, opt_params, interp)
850+
optimize(interp, result.src, opt_params, ignorelimited(result.result))
851+
frame.src = finish!(interp, result)
845852
end
846853
ccall(:jl_typeinf_end, Cvoid, ())
847854
frame.inferred || return (nothing, Any)

0 commit comments

Comments
 (0)