@@ -32,14 +32,18 @@ function default_inlining_policy(@nospecialize(src))
32
32
if isa (src, CodeInfo) || isa (src, Vector{UInt8})
33
33
src_inferred = ccall (:jl_ir_flag_inferred , Bool, (Any,), src)
34
34
src_inlineable = ccall (:jl_ir_flag_inlineable , Bool, (Any,), src)
35
- return src_inferred && src_inlineable
35
+ return src_inferred && src_inlineable ? src : nothing
36
36
end
37
- return false
37
+ if isa (src, OptimizationState) && isdefined (src, :ir )
38
+ return src. src. inlineable ? src. ir : nothing
39
+ end
40
+ return nothing
38
41
end
39
42
40
43
mutable struct OptimizationState
41
44
linfo:: MethodInstance
42
45
src:: CodeInfo
46
+ ir:: Any # Union{Nothing, IRCode}
43
47
stmt_info:: Vector{Any}
44
48
mod:: Module
45
49
nargs:: Int
@@ -54,7 +58,7 @@ mutable struct OptimizationState
54
58
WorldView (code_cache (interp), frame. world),
55
59
inlining_policy (interp))
56
60
return new (frame. linfo,
57
- frame. src, frame. stmt_info, frame. mod, frame. nargs,
61
+ frame. src, nothing , frame. stmt_info, frame. mod, frame. nargs,
58
62
frame. sptypes, frame. slottypes, false ,
59
63
inlining)
60
64
end
@@ -88,7 +92,7 @@ mutable struct OptimizationState
88
92
WorldView (code_cache (interp), get_world_counter ()),
89
93
inlining_policy (interp))
90
94
return new (linfo,
91
- src, stmt_info, inmodule, nargs,
95
+ src, nothing , stmt_info, inmodule, nargs,
92
96
sptypes_from_meth_instance (linfo), slottypes, false ,
93
97
inlining)
94
98
end
@@ -100,6 +104,20 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
100
104
return OptimizationState (linfo, src, params, interp)
101
105
end
102
106
107
+ function ir_to_codeinf! (opt:: OptimizationState )
108
+ replace_code_newstyle! (opt. src, opt. ir, opt. nargs - 1 )
109
+ opt. ir = nothing
110
+ let src = opt. src:: CodeInfo
111
+ widen_all_consts! (src)
112
+ src. inferred = true
113
+ # finish updating the result struct
114
+ validate_code_in_debug_mode (opt. linfo, src, " optimized" )
115
+ return src
116
+ end
117
+ end
118
+
119
+ include (" compiler/ssair/driver.jl" )
120
+
103
121
104
122
# ############
105
123
# constants #
@@ -152,7 +170,7 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
152
170
end
153
171
end
154
172
if ! inlineable
155
- inlineable = inline_worthy (me. src . code, me . src, me . sptypes, me . slottypes , params, union_penalties, cost_threshold + bonus)
173
+ inlineable = inline_worthy (me. ir , params, union_penalties, cost_threshold + bonus)
156
174
end
157
175
return inlineable
158
176
end
@@ -175,7 +193,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
175
193
end
176
194
177
195
# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects
178
- function finish (opt:: OptimizationState , params:: OptimizationParams , ir, @nospecialize (result))
196
+ function finish (interp :: AbstractInterpreter , opt:: OptimizationState , params:: OptimizationParams , ir, @nospecialize (result))
179
197
def = opt. linfo. def
180
198
nargs = Int (opt. nargs) - 1
181
199
@@ -225,7 +243,7 @@ function finish(opt::OptimizationState, params::OptimizationParams, ir, @nospeci
225
243
end
226
244
end
227
245
228
- replace_code_newstyle! ( opt. src, ir, nargs)
246
+ opt. ir = ir
229
247
230
248
# determine and cache inlineability
231
249
union_penalties = false
@@ -263,14 +281,15 @@ function finish(opt::OptimizationState, params::OptimizationParams, ir, @nospeci
263
281
opt. src. inlineable = isinlineable (def, opt, params, union_penalties, bonus)
264
282
end
265
283
end
284
+
266
285
nothing
267
286
end
268
287
269
288
# run the optimization work
270
289
function optimize (interp:: AbstractInterpreter , opt:: OptimizationState , params:: OptimizationParams , @nospecialize (result))
271
290
nargs = Int (opt. nargs) - 1
272
291
@timeit " optimizer" ir = run_passes (opt. src, nargs, opt)
273
- finish (opt, params, ir, result)
292
+ finish (interp, opt, params, ir, result)
274
293
end
275
294
276
295
@@ -296,7 +315,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
296
315
# known return type
297
316
isknowntype (@nospecialize T) = (T === Union{}) || isa (T, Const) || isconcretetype (widenconst (T))
298
317
299
- function statement_cost (ex:: Expr , line:: Int , src:: CodeInfo , sptypes:: Vector{Any} ,
318
+ function statement_cost (ex:: Expr , line:: Int , src:: Union{ CodeInfo, IRCode} , sptypes:: Vector{Any} ,
300
319
slottypes:: Vector{Any} , union_penalties:: Bool ,
301
320
params:: OptimizationParams , error_path:: Bool = false )
302
321
head = ex. head
@@ -308,7 +327,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
308
327
if ftyp === IntrinsicFunction && farg isa SSAValue
309
328
# if this comes from code that was already inlined into another function,
310
329
# Consts have been widened. try to recover in simple cases.
311
- farg = src. code[farg. id]
330
+ farg = isa ( src, CodeInfo) ? src . code[farg. id] : src . stmts[farg . id][ :inst ]
312
331
if isa (farg, GlobalRef) || isa (farg, QuoteNode) || isa (farg, IntrinsicFunction) || isexpr (farg, :static_parameter )
313
332
ftyp = argextype (farg, src, sptypes, slottypes)
314
333
end
@@ -350,7 +369,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
350
369
end
351
370
return T_FFUNC_COST[fidx]
352
371
end
353
- extyp = line == - 1 ? Any : src . ssavaluetypes[ line]
372
+ extyp = line == - 1 ? Any : argextype ( SSAValue ( line), src, sptypes, slottypes)
354
373
if extyp === Union{}
355
374
return 0
356
375
end
@@ -361,7 +380,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
361
380
# run-time of the function, we omit them from
362
381
# consideration. This way, non-inlined error branches do not
363
382
# prevent inlining.
364
- extyp = line == - 1 ? Any : src . ssavaluetypes[ line]
383
+ extyp = line == - 1 ? Any : argextype ( SSAValue ( line), src, sptypes, slottypes)
365
384
return extyp === Union{} ? 0 : 20
366
385
elseif head === :(= )
367
386
if ex. args[1 ] isa GlobalRef
@@ -386,31 +405,32 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
386
405
return 0
387
406
end
388
407
389
- function statement_or_branch_cost (@nospecialize (stmt), line:: Int , src:: CodeInfo , sptypes:: Vector{Any} ,
408
+ function statement_or_branch_cost (@nospecialize (stmt), line:: Int , src:: Union{ CodeInfo, IRCode} , sptypes:: Vector{Any} ,
390
409
slottypes:: Vector{Any} , union_penalties:: Bool , params:: OptimizationParams ,
391
410
throw_blocks:: Union{Nothing,BitSet} )
392
411
thiscost = 0
412
+ dst (tgt) = isa (src, IRCode) ? first (src. cfg. blocks[tgt]. stmts) : tgt
393
413
if stmt isa Expr
394
414
thiscost = statement_cost (stmt, line, src, sptypes, slottypes, union_penalties, params,
395
415
throw_blocks != = nothing && line in throw_blocks):: Int
396
416
elseif stmt isa GotoNode
397
417
# loops are generally always expensive
398
418
# but assume that forward jumps are already counted for from
399
419
# summing the cost of the not-taken branch
400
- thiscost = stmt. label < line ? 40 : 0
420
+ thiscost = dst ( stmt. label) < line ? 40 : 0
401
421
elseif stmt isa GotoIfNot
402
- thiscost = stmt. dest < line ? 40 : 0
422
+ thiscost = dst ( stmt. dest) < line ? 40 : 0
403
423
end
404
424
return thiscost
405
425
end
406
426
407
- function inline_worthy (body :: Array{Any,1} , src :: CodeInfo , sptypes :: Vector{Any} , slottypes :: Vector{Any} ,
427
+ function inline_worthy (ir :: IRCode ,
408
428
params:: OptimizationParams , union_penalties:: Bool = false , cost_threshold:: Integer = params. inline_cost_threshold)
409
429
bodycost:: Int = 0
410
- throw_blocks = params. unoptimize_throw_blocks ? find_throw_blocks (body ) : nothing
411
- for line = 1 : length (body )
412
- stmt = body [line]
413
- thiscost = statement_or_branch_cost (stmt, line, src, sptypes, slottypes , union_penalties, params, throw_blocks)
430
+ throw_blocks = params. unoptimize_throw_blocks ? find_throw_blocks (ir . stmts . inst, RefValue (ir) ) : nothing
431
+ for line = 1 : length (ir . stmts )
432
+ stmt = ir . stmts [line][ :inst ]
433
+ thiscost = statement_or_branch_cost (stmt, line, ir, ir . sptypes, ir . argtypes , union_penalties, params, throw_blocks)
414
434
bodycost = plus_saturate (bodycost, thiscost)
415
435
bodycost > cost_threshold && return false
416
436
end
@@ -490,5 +510,3 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
490
510
end
491
511
end
492
512
end
493
-
494
- include (" compiler/ssair/driver.jl" )
0 commit comments